"""Batch cut-out script: keep SAM 3 text-prompted objects, black out the rest."""

import argparse
import os

import torch
import cv2
import numpy as np
from PIL import Image

from sam3 import build_sam3_image_model
from sam3.model.sam3_image_processor import Sam3Processor

# Supported image file extensions (matched against the lower-cased suffix).
IMG_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.webp', '.tiff'}

# Characters replaced by "_" when the prompt text is embedded in a file name.
# Space was already replaced by the original code; the rest are path
# separators or characters that are illegal in file names on common systems.
_FILENAME_UNSAFE = str.maketrans({ch: "_" for ch in ' /\\:*?"<>|'})


def _to_numpy(value):
    """Return *value* as a NumPy array if it is a torch tensor, else unchanged."""
    if not isinstance(value, torch.Tensor):
        return value
    value = value.detach().cpu()
    if value.dtype == torch.bfloat16:
        # NumPy has no bfloat16 dtype, so Tensor.numpy() would raise on it.
        value = value.float()
    return value.numpy()


def process_one_image(processor, img_path, args, output_dir):
    """Segment one image by text prompt and save a black-background cut-out.

    Pixels covered by any kept mask retain their original colour; all other
    pixels are set to black. The result is written as a JPEG into
    ``output_dir``. Read, inference and save failures are reported on stdout
    and the function returns without raising.

    Args:
        processor: Object exposing ``set_image(pil_image)`` and
            ``set_text_prompt(state=..., prompt=...)``; the latter must return
            a mapping with ``"masks"``, ``"scores"`` and ``"boxes"`` entries.
        img_path: Path of the image to process.
        args: Namespace providing ``text``, ``conf``, ``dilate``,
            ``rect_fill``, ``rect_kernel``, ``bbox_pad`` and ``fill_holes``.
        output_dir: Directory the result is written to (created if missing).

    Returns:
        None.
    """
    filename = os.path.basename(img_path)
    print(f"\n[INFO] 正在处理: {filename} ...")

    # --- 1. Load the image ---
    try:
        # The context manager closes the underlying file as soon as the pixel
        # data has been converted, instead of leaving it to garbage collection.
        with Image.open(img_path) as img_file:
            pil_image = img_file.convert("RGB")
        # BGR copy used for post-processing and as the output background.
        orig_cv2_image = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
    except Exception as e:
        print(f"[Error] 无法读取图片 {filename}: {e}")
        return

    # --- 2. Inference ---
    try:
        inference_state = processor.set_image(pil_image)
        results = processor.set_text_prompt(
            state=inference_state,
            prompt=args.text
        )
    except Exception as e:
        print(f"[Error] 推理报错 ({filename}): {e}")
        return

    # --- 3. Filter detections by confidence ---
    masks = _to_numpy(results.get("masks"))
    scores = _to_numpy(results.get("scores"))
    boxes = _to_numpy(results.get("boxes"))

    # Without masks or scores there is nothing that can be filtered or drawn.
    if masks is None or scores is None:
        print(f"[WARN] {filename} 未检测到任何目标。")
        return

    masks = np.asarray(masks)
    keep_indices = np.where(np.asarray(scores) > args.conf)[0]
    if keep_indices.size == 0:
        print(f"[WARN] {filename} 没有目标超过阈值 {args.conf}。")
        return

    masks = masks[keep_indices]
    if boxes is not None:
        boxes = np.asarray(boxes)[keep_indices]

    print(f"[INFO] 发现 {len(masks)} 个目标 (Text: {args.text})")

    # --- 4. Post-process each mask and merge into one union mask ---
    img_h, img_w = orig_cv2_image.shape[:2]
    total_mask = np.zeros((img_h, img_w), dtype=np.uint8)

    # Structuring elements do not depend on the mask, so build them once.
    dilate_kernel = None
    if args.dilate > 0:
        # Round even sizes up to the next odd number.
        k_size = args.dilate if args.dilate % 2 == 1 else args.dilate + 1
        dilate_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (k_size, k_size))

    use_rect_fill = bool(args.rect_fill)
    rect_kernel = None
    if use_rect_fill and boxes is None:
        # Rect fill needs a bounding box per detection; skip it rather than crash.
        print(f"[WARN] {filename} 未返回检测框,跳过矩形填充。")
        use_rect_fill = False
    if use_rect_fill:
        rect_kernel = cv2.getStructuringElement(
            cv2.MORPH_RECT, (args.rect_kernel, args.rect_kernel)
        )

    for i, m in enumerate(masks):
        # A. Normalise to a 2-D uint8 mask at the original image resolution.
        # NOTE(review): assumes each mask is boolean / 0-1 valued — confirm
        # against the processor's output.
        while m.ndim > 2:
            m = m[0]
        if m.shape != (img_h, img_w):
            m = cv2.resize(m.astype(np.uint8), (img_w, img_h),
                           interpolation=cv2.INTER_NEAREST)
        # astype() copies, so the in-place ROI write below never touches `masks`.
        m = m.astype(np.uint8)

        # B1. Global dilation.
        if dilate_kernel is not None:
            m = cv2.dilate(m, dilate_kernel, iterations=1)

        # B2. Morphological closing restricted to the (padded) bounding box.
        if use_rect_fill:
            x1, y1, x2, y2 = map(int, boxes[i])
            pad = args.bbox_pad
            x1, y1 = max(0, x1 - pad), max(0, y1 - pad)
            x2, y2 = min(img_w, x2 + pad), min(img_h, y2 + pad)
            # A negative pad can collapse the box; only process non-empty ROIs.
            if x2 > x1 and y2 > y1:
                roi = m[y1:y2, x1:x2]
                m[y1:y2, x1:x2] = cv2.morphologyEx(roi, cv2.MORPH_CLOSE, rect_kernel)

        # B3. Hole filling: redraw the external contours as filled shapes.
        if args.fill_holes:
            # findContours returns (image, contours, hierarchy) on OpenCV 3.x
            # and (contours, hierarchy) on 4.x; [-2] is the contours in both.
            contours = cv2.findContours(
                m, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
            )[-2]
            m_filled = np.zeros_like(m)
            cv2.drawContours(m_filled, contours, -1, 1, thickness=cv2.FILLED)
            m = m_filled

        # Union: a pixel is kept if any processed mask covers it.
        total_mask = cv2.bitwise_or(total_mask, m)

    # --- 5. Compose the output (blackout effect) ---
    # Any non-zero pixel of the union mask keeps its original colour.
    total_mask_255 = (total_mask > 0).astype(np.uint8) * 255
    final_image = cv2.bitwise_and(orig_cv2_image, orig_cv2_image, mask=total_mask_255)

    # --- 6. Save ---
    tags = ["blackout"]
    if args.dilate > 0:
        tags.append(f"d{args.dilate}")
    if args.rect_fill:
        tags.append(f"rect{args.rect_kernel}")
        # Padding only influences the rect-fill step, so it is tagged with it.
        if args.bbox_pad != 0:
            tags.append(f"pad{args.bbox_pad}")
    if args.fill_holes:
        tags.append("filled")
    tag_str = "_" + "_".join(tags)

    name_no_ext = os.path.splitext(filename)[0]
    # The prompt is free text; neutralise characters that would break the path.
    safe_prompt = args.text.translate(_FILENAME_UNSAFE)
    # JPEG is sufficient because the background is opaque black, not transparent.
    save_filename = f"{name_no_ext}_{safe_prompt}{tag_str}.jpg"
    os.makedirs(output_dir, exist_ok=True)
    save_path = os.path.join(output_dir, save_filename)

    # imwrite signals failure through its return value, so check it before
    # reporting success.
    if not cv2.imwrite(save_path, final_image):
        print(f"[Error] 保存失败: {save_path}")
        return
    print(f"[SUCCESS] 结果已保存至: {save_path}")


def main():
    """Parse CLI arguments, load the SAM 3 model and process every input image.

    ``--input`` may be a single file or a directory; directories are searched
    recursively for files whose extension is in ``IMG_EXTENSIONS``.

    NOTE: results are written flat into ``--output`` and named after the
    source file's stem, so inputs sharing a stem (e.g. in different
    sub-folders) overwrite each other's output.
    """
    parser = argparse.ArgumentParser(description="SAM 3 自动抠图脚本 (黑底保留原色)")
    parser.add_argument("--input", type=str, required=True, help="输入路径")
    parser.add_argument("--output", type=str, default="/app/outputs", help="输出目录")
    parser.add_argument("--model", type=str, default="/app/checkpoints/sam3.pt", help="模型路径")
    parser.add_argument("--text", type=str, required=True, help="提示词")
    parser.add_argument("--conf", type=float, default=0.2, help="置信度")

    # Post-processing options.
    parser.add_argument("--dilate", type=int, default=0, help="全局基础膨胀 (px)")
    parser.add_argument("--fill-holes", action="store_true", help="开启孔洞填充")
    parser.add_argument("--rect-fill", action="store_true", help="开启矩形区域智能填充")
    parser.add_argument("--rect-kernel", type=int, default=20, help="矩形填充的核大小")
    parser.add_argument("--bbox-pad", type=int, default=0, help="检测框范围调整 (px)")

    args = parser.parse_args()

    # A non-positive kernel size makes OpenCV raise in the middle of the batch;
    # reject it up front instead.
    if args.rect_fill and args.rect_kernel < 1:
        parser.error("--rect-kernel 必须 >= 1")

    os.makedirs(args.output, exist_ok=True)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    print("--- 运行配置 (黑底模式) ---")
    print(f"输入: {args.input}")
    print(f"模式: {'矩形填充' if args.rect_fill else '普通'}")
    if args.rect_fill:
        print(f"矩形核: {args.rect_kernel} px")
        print(f"检测框Padding: {args.bbox_pad} px")
    print("----------------")

    image_list = []
    if os.path.isdir(args.input):
        for root, _, files in os.walk(args.input):
            image_list.extend(
                os.path.join(root, file)
                for file in files
                if os.path.splitext(file)[1].lower() in IMG_EXTENSIONS
            )
        # os.walk order is filesystem-dependent; sort for reproducible runs.
        image_list.sort()
    elif os.path.isfile(args.input):
        image_list.append(args.input)

    if not image_list:
        print("[Error] 未找到图片")
        return

    print(f"[INFO] 加载模型: {args.model} ...")
    try:
        sam3_model = build_sam3_image_model(
            checkpoint_path=args.model,
            load_from_HF=False,
            device=device
        )
        processor = Sam3Processor(sam3_model)
    except Exception as e:
        print(f"[Fatal] 模型加载失败: {e}")
        return

    total = len(image_list)
    for idx, img_path in enumerate(image_list, start=1):
        print(f"\n--- [{idx}/{total}] ---")
        process_one_image(processor, img_path, args, args.output)

    print("\n[DONE] 完成。")


if __name__ == "__main__":
    main()