199 lines
7.8 KiB
Python
199 lines
7.8 KiB
Python
import argparse
|
||
import torch
|
||
import cv2
|
||
import os
|
||
import numpy as np
|
||
from PIL import Image
|
||
from sam3 import build_sam3_image_model
|
||
from sam3.model.sam3_image_processor import Sam3Processor
|
||
|
||
# Supported image file extensions: lower-case, leading dot included.
# Directory scans lower-case each file's extension before checking it
# against this set.
IMG_EXTENSIONS = {
    ".jpg",
    ".jpeg",
    ".png",
    ".bmp",
    ".webp",
    ".tiff",
}
|
||
|
||
def _refine_mask(mask, box, args, img_w, img_h, dilate_kernel, rect_kernel):
    """Convert one raw detection mask into a full-resolution uint8 mask.

    Args:
        mask: NumPy mask for a single detection; a leading extra dimension
            is dropped by taking index 0.
        box: Four values unpacked as ``x1, y1, x2, y2``, or ``None``.
            NOTE(review): assumed to be pixel coordinates in the original
            image -- confirm against the processor output.
        args: Parsed CLI namespace; reads ``bbox_pad`` and ``fill_holes``.
        img_w: Width of the original image in pixels.
        img_h: Height of the original image in pixels.
        dilate_kernel: Structuring element for the global dilation, or
            ``None`` to skip it.
        rect_kernel: Structuring element for the box-local closing, or
            ``None`` to skip it.

    Returns:
        A new ``uint8`` array of shape ``(img_h, img_w)``.
    """
    if mask.ndim > 2:
        mask = mask[0]
    if mask.shape != (img_h, img_w):
        # Nearest-neighbour keeps the mask binary while rescaling.
        mask = cv2.resize(
            mask.astype(np.uint8), (img_w, img_h), interpolation=cv2.INTER_NEAREST
        )
    # astype always copies, so the in-place ROI write below never touches
    # the caller's array.
    mask = mask.astype(np.uint8)

    # 1. Global dilation.
    if dilate_kernel is not None:
        mask = cv2.dilate(mask, dilate_kernel, iterations=1)

    # 2. Morphological closing restricted to the (padded) detection box.
    if rect_kernel is not None and box is not None:
        x1, y1, x2, y2 = map(int, box)
        pad = args.bbox_pad
        x1, y1 = max(0, x1 - pad), max(0, y1 - pad)
        x2, y2 = min(img_w, x2 + pad), min(img_h, y2 + pad)
        # A negative pad or a degenerate box can leave an empty region.
        if x2 > x1 and y2 > y1:
            mask[y1:y2, x1:x2] = cv2.morphologyEx(
                mask[y1:y2, x1:x2], cv2.MORPH_CLOSE, rect_kernel
            )

    # 3. Hole filling: redraw only the external contours, filled.
    if args.fill_holes:
        contours, _ = cv2.findContours(
            mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )
        filled = np.zeros_like(mask)
        cv2.drawContours(filled, contours, -1, 1, thickness=cv2.FILLED)
        mask = filled

    return mask


def process_one_image(processor, img_path, args, output_dir):
    """Segment one image by text prompt and save a black-background cut-out.

    The image is loaded, passed to ``processor.set_image`` and
    ``processor.set_text_prompt``, and detections whose score exceeds
    ``args.conf`` are kept. Each kept mask is post-processed (optional
    dilation, box-local closing, hole filling) and OR-ed into one combined
    mask. Pixels outside the combined mask are set to black and the result
    is written as a JPEG into ``output_dir``.

    Every failure (unreadable image, inference error, no detections, failed
    write) is reported on stdout and the function returns without raising,
    so a batch run can continue with the next image.

    Args:
        processor: Object exposing ``set_image(pil_image)`` and
            ``set_text_prompt(state=..., prompt=...)``; the latter must
            return a mapping with ``"masks"``, ``"scores"`` and ``"boxes"``.
        img_path: Path of the image to process.
        args: Parsed CLI namespace; reads ``text``, ``conf``, ``dilate``,
            ``rect_fill``, ``rect_kernel``, ``bbox_pad`` and ``fill_holes``.
        output_dir: Existing directory that receives the output file.

    Returns:
        None.
    """
    filename = os.path.basename(img_path)
    print(f"\n[INFO] 正在处理: {filename} ...")

    # --- 1. Load the image ---
    try:
        # The context manager closes the file handle as soon as the pixel
        # data has been converted.
        with Image.open(img_path) as src:
            pil_image = src.convert("RGB")
        # BGR copy used for post-processing and as the output background.
        orig_cv2_image = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
    except Exception as e:
        print(f"[Error] 无法读取图片 {filename}: {e}")
        return

    # --- 2. Inference ---
    try:
        inference_state = processor.set_image(pil_image)
        results = processor.set_text_prompt(
            state=inference_state,
            prompt=args.text,
        )
    except Exception as e:
        print(f"[Error] 推理报错 ({filename}): {e}")
        return

    # --- 3. Filter detections by confidence ---
    masks = results.get("masks")
    scores = results.get("scores")
    boxes = results.get("boxes")

    # Without scores the threshold cannot be applied; treat that the same
    # as "nothing detected" instead of crashing on `None > conf`.
    if masks is None or scores is None:
        print(f"[WARN] {filename} 未检测到任何目标。")
        return

    masks, scores, boxes = (
        v.detach().cpu().numpy() if isinstance(v, torch.Tensor) else v
        for v in (masks, scores, boxes)
    )

    keep_indices = np.where(scores > args.conf)[0]
    if keep_indices.size == 0:
        print(f"[WARN] {filename} 没有目标超过阈值 {args.conf}。")
        return

    masks = masks[keep_indices]
    if boxes is not None:
        boxes = boxes[keep_indices]
    elif args.rect_fill:
        # Boxes are only needed for the rectangle fill; degrade gracefully.
        print(f"[WARN] {filename} 缺少检测框,跳过矩形填充。")
    print(f"[INFO] 发现 {len(masks)} 个目标 (Text: {args.text})")

    # --- 4. Post-process each mask and merge ---
    img_h, img_w = orig_cv2_image.shape[:2]

    # Both structuring elements are loop-invariant, so build them once.
    dilate_kernel = None
    if args.dilate > 0:
        # Force an odd kernel size so the kernel has a centre pixel.
        k_size = args.dilate if args.dilate % 2 == 1 else args.dilate + 1
        dilate_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (k_size, k_size))

    rect_kernel = None
    if args.rect_fill and boxes is not None:
        # A size below 1 is rejected by OpenCV; 1x1 closing is a no-op.
        k_rect = max(1, args.rect_kernel)
        rect_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (k_rect, k_rect))

    per_mask_boxes = boxes if boxes is not None else [None] * len(masks)

    # Single-channel accumulator: non-zero wherever any detection covers.
    total_mask = np.zeros((img_h, img_w), dtype=np.uint8)
    for raw_mask, box in zip(masks, per_mask_boxes):
        refined = _refine_mask(
            raw_mask, box, args, img_w, img_h, dilate_kernel, rect_kernel
        )
        total_mask = cv2.bitwise_or(total_mask, refined)

    # --- 5. Black out everything outside the combined mask ---
    # Compare against zero rather than multiplying, so any non-zero mask
    # value maps to 255 without uint8 wrap-around.
    total_mask_255 = (total_mask > 0).astype(np.uint8) * 255
    final_image = cv2.bitwise_and(orig_cv2_image, orig_cv2_image, mask=total_mask_255)

    # --- 6. Save ---
    tags = ["blackout"]
    if args.dilate > 0:
        tags.append(f"d{args.dilate}")
    if args.rect_fill:
        tags.append(f"rect{args.rect_kernel}")
    if args.bbox_pad != 0:
        tags.append(f"pad{args.bbox_pad}")
    if args.fill_holes:
        tags.append("filled")
    tag_str = "_" + "_".join(tags)

    # The prompt becomes part of the file name: a path separator in it
    # would point the save path at a non-existent sub-directory.
    prompt_slug = args.text.replace(' ', '_')
    for sep in filter(None, (os.sep, os.altsep)):
        prompt_slug = prompt_slug.replace(sep, "_")

    name_no_ext = os.path.splitext(filename)[0]
    # JPEG is sufficient: the background is solid black, not transparent.
    save_filename = f"{name_no_ext}_{prompt_slug}{tag_str}.jpg"
    save_path = os.path.join(output_dir, save_filename)

    # cv2.imwrite signals failure through its return value, not an exception.
    if not cv2.imwrite(save_path, final_image):
        print(f"[Error] 保存失败: {save_path}")
        return
    print(f"[SUCCESS] 结果已保存至: {save_path}")
|
||
|
||
def main():
    """Command-line entry point: batch black-background cut-outs with SAM 3.

    Parses the CLI arguments, collects the input images (a single file, or
    every file under a directory whose lower-cased extension is in
    ``IMG_EXTENSIONS``), loads the model once, and calls
    ``process_one_image`` for each image with ``args.output`` as the
    destination directory.

    Prints an error and returns early when no image is found or the model
    fails to load. Invalid ``--rect-kernel`` values are rejected through
    ``parser.error`` before the model is loaded.
    """
    parser = argparse.ArgumentParser(description="SAM 3 自动抠图脚本 (黑底保留原色)")
    parser.add_argument("--input", type=str, required=True, help="输入路径")
    parser.add_argument("--output", type=str, default="/app/outputs", help="输出目录")
    parser.add_argument("--model", type=str, default="/app/checkpoints/sam3.pt", help="模型路径")
    parser.add_argument("--text", type=str, required=True, help="提示词")
    parser.add_argument("--conf", type=float, default=0.2, help="置信度")

    # Post-processing options.
    parser.add_argument("--dilate", type=int, default=0, help="全局基础膨胀 (px)")
    parser.add_argument("--fill-holes", action="store_true", help="开启孔洞填充")
    parser.add_argument("--rect-fill", action="store_true", help="开启矩形区域智能填充")
    parser.add_argument("--rect-kernel", type=int, default=20, help="矩形填充的核大小")
    parser.add_argument("--bbox-pad", type=int, default=0, help="检测框范围调整 (px)")

    args = parser.parse_args()

    # Fail fast: a kernel size below 1 is rejected by OpenCV, which would
    # otherwise only surface after the model has been loaded.
    if args.rect_fill and args.rect_kernel < 1:
        parser.error("--rect-kernel 必须 >= 1")

    os.makedirs(args.output, exist_ok=True)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    print("--- 运行配置 (黑底模式) ---")
    print(f"输入: {args.input}")
    print(f"模式: {'矩形填充' if args.rect_fill else '普通'}")
    if args.rect_fill:
        print(f"矩形核: {args.rect_kernel} px")
        print(f"检测框Padding: {args.bbox_pad} px")
    print("----------------")

    image_list = []
    if os.path.isdir(args.input):
        for root, _, files in os.walk(args.input):
            image_list.extend(
                os.path.join(root, file)
                for file in files
                if os.path.splitext(file)[1].lower() in IMG_EXTENSIONS
            )
        # os.walk yields entries in arbitrary order; sort so the batch
        # order and the [i/N] progress output are reproducible.
        image_list.sort()
        # NOTE(review): outputs are written flat into args.output, so files
        # with the same base name in different sub-directories overwrite
        # each other.
    elif os.path.isfile(args.input):
        image_list.append(args.input)

    if not image_list:
        print("[Error] 未找到图片")
        return

    print(f"[INFO] 加载模型: {args.model} ...")
    try:
        sam3_model = build_sam3_image_model(
            checkpoint_path=args.model,
            load_from_HF=False,
            device=device,
        )
        processor = Sam3Processor(sam3_model)
    except Exception as e:
        print(f"[Fatal] 模型加载失败: {e}")
        return

    total = len(image_list)
    for idx, img_path in enumerate(image_list, start=1):
        print(f"\n--- [{idx}/{total}] ---")
        process_one_image(processor, img_path, args, args.output)

    print("\n[DONE] 完成。")
|
||
|
||
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
||
|