From 44b339b2534b9d493129a884539aaa2236829898 Mon Sep 17 00:00:00 2001
From: linsan
Date: Wed, 4 Feb 2026 11:05:11 +0800
Subject: [PATCH] Add sam3 project code

---
 .gitignore          |   7 ++
 Dockerfile          |  75 +++++++++++++++++
 app.py              | 198 ++++++++++++++++++++++++++++++++++++++++++++
 docker-compose.yaml |  33 ++++++++
 4 files changed, 313 insertions(+)
 create mode 100644 .gitignore
 create mode 100644 Dockerfile
 create mode 100644 app.py
 create mode 100644 docker-compose.yaml

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..bcec06c
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,7 @@
+checkpoints/
+outputs/
+results/
+data/
+__pycache__/
+*.pyc
+.DS_Store
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000..f36badb
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,75 @@
+# 1. Base image: the CUDA 12.4 devel image already verified to work
+FROM nvidia/cuda:12.4.1-cudnn-devel-ubuntu22.04
+
+# Environment variables
+ENV PYTHONDONTWRITEBYTECODE=1
+ENV PYTHONUNBUFFERED=1
+ENV DEBIAN_FRONTEND=noninteractive
+# Arch tuning for the RTX 50 series (Blackwell = sm_100/sm_120).
+# FIX: the original list ("9.0;10.0+PTX") never emitted sm_120 binaries even
+# though the comment promised Blackwell coverage — "12.0" added explicitly,
+# with +PTX kept for forward compatibility with newer architectures.
+ENV TORCH_CUDA_ARCH_LIST="9.0;10.0;12.0+PTX"
+
+# 2. System tooling (ffmpeg added for video processing)
+RUN apt-get update && apt-get install -y \
+    python3.10 \
+    python3-pip \
+    git \
+    wget \
+    ffmpeg \
+    libgl1 \
+    libglib2.0-0 \
+    libsm6 \
+    libxext6 \
+    build-essential \
+    ninja-build \
+    && rm -rf /var/lib/apt/lists/*
+
+# Provide a plain "python" command.
+# FIX: -f makes the step idempotent, so a rebuild on a warm layer where the
+# link already exists cannot fail the build.
+RUN ln -sf /usr/bin/python3.10 /usr/bin/python
+
+WORKDIR /app
+
+# 3. Upgrade pip (Tsinghua mirror for speed)
+RUN python3 -m pip install --upgrade pip -i https://pypi.tuna.tsinghua.edu.cn/simple --default-timeout=100
+
+# 4. Core fix: install the RTX 50-compatible PyTorch build (cu128).
+# This is the key step resolving "sm_120 is not compatible".
+RUN pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128 --default-timeout=100
+
+# 5.
Install SAM3 dependencies (includes decord and pycocotools, which a
+# previous run reported missing)
+# The transformers version pin was removed; the latest release tracks SAM3.
+RUN pip install \
+    opencv-python-headless \
+    matplotlib \
+    jupyter \
+    jupyterlab \
+    ipympl \
+    pyyaml \
+    tqdm \
+    hydra-core \
+    iopath \
+    pillow \
+    networkx \
+    scipy \
+    pandas \
+    timm \
+    einops \
+    transformers \
+    tokenizers \
+    decord \
+    pycocotools \
+    -i https://pypi.tuna.tsinghua.edu.cn/simple \
+    --default-timeout=100
+
+# 6. Fetch the SAM3 sources
+RUN git clone https://github.com/facebookresearch/sam3.git sam3_code
+
+# 7. Install the SAM3 package (editable)
+WORKDIR /app/sam3_code
+RUN pip install -e .
+
+# 8. Working directory and entrypoint
+WORKDIR /app
+EXPOSE 8888
+
+# Keep the container running with JupyterLab for terminal debugging
+CMD ["/bin/bash", "-c", "jupyter lab --ip=0.0.0.0 --port=8888 --allow-root --no-browser --NotebookApp.token='sam3'"]
diff --git a/app.py b/app.py
new file mode 100644
index 0000000..0a22086
--- /dev/null
+++ b/app.py
@@ -0,0 +1,198 @@
+import argparse
+import torch
+import cv2
+import os
+import numpy as np
+from PIL import Image
+from sam3 import build_sam3_image_model
+from sam3.model.sam3_image_processor import Sam3Processor
+
+# Supported image extensions
+IMG_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.webp', '.tiff'}
+
+def process_one_image(processor, img_path, args, output_dir):
+    filename = os.path.basename(img_path)
+    # FIX: the log messages below printed the literal text "(unknown)";
+    # they now interpolate the actual file name.
+    print(f"\n[INFO] 正在处理: {filename} ...")
+
+    # --- 1. Load the image ---
+    try:
+        # PIL copy used for inference
+        pil_image = Image.open(img_path).convert("RGB")
+        # OpenCV (BGR) copy used for post-processing and as the output background
+        orig_cv2_image = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
+    except Exception as e:
+        print(f"[Error] 无法读取图片 {filename}: {e}")
+        return
+
+    # --- 2. Inference (logic unchanged) ---
+    try:
+        inference_state = processor.set_image(pil_image)
+        results = processor.set_text_prompt(
+            state=inference_state,
+            prompt=args.text
+        )
+    except Exception as e:
+        print(f"[Error] 推理报错 ({filename}): {e}")
+        return
+
+    # --- 3.
Result filtering (logic unchanged) ---
+    masks = results.get("masks")
+    scores = results.get("scores")
+    boxes = results.get("boxes")
+
+    if masks is None:
+        # FIX: this message and the threshold warning below printed the literal
+        # text "(unknown)"; they now interpolate the actual file name.
+        print(f"[WARN] {filename} 未检测到任何目标。")
+        return
+
+    if isinstance(masks, torch.Tensor): masks = masks.cpu().numpy()
+    if isinstance(scores, torch.Tensor): scores = scores.cpu().numpy()
+    if isinstance(boxes, torch.Tensor): boxes = boxes.cpu().numpy()
+
+    keep_indices = np.where(scores > args.conf)[0]
+    if len(keep_indices) == 0:
+        print(f"[WARN] {filename} 没有目标超过阈值 {args.conf}。")
+        return
+
+    masks = masks[keep_indices]
+    scores = scores[keep_indices]
+    boxes = boxes[keep_indices]
+    print(f"[INFO] 发现 {len(masks)} 个目标 (Text: {args.text})")
+
+    # --- 4. Mask post-processing and merging (core change area) ---
+    img_h, img_w = orig_cv2_image.shape[:2]
+
+    # [Core change 1] A single-channel all-black "total mask" that accumulates
+    # the region of every detected object.
+    total_mask = np.zeros((img_h, img_w), dtype=np.uint8)
+
+    # Process each detection (original morphology pipeline preserved)
+    for i in range(len(masks)):
+        m = masks[i]
+        box = boxes[i]
+
+        # --- A. Basic preprocessing (unchanged) ---
+        if m.ndim > 2: m = m[0]
+        # Resize to the original image size
+        if m.shape != (img_h, img_w):
+            m = cv2.resize(m.astype(np.uint8), (img_w, img_h), interpolation=cv2.INTER_NEAREST)
+        m = m.astype(np.uint8)  # ensure uint8 (0, 1)
+
+        # --- B. Post-processing pipeline (unchanged) ---
+        # 1. Base dilation (global)
+        if args.dilate > 0:
+            k_size = args.dilate if args.dilate % 2 == 1 else args.dilate + 1
+            kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (k_size, k_size))
+            m = cv2.dilate(m, kernel, iterations=1)
+
+        # 2. Rect-region fill, with bbox padding
+        if args.rect_fill:
+            x1, y1, x2, y2 = map(int, box)
+            # Apply detection-box padding, clamped to the image
+            pad = args.bbox_pad
+            x1, y1 = max(0, x1 - pad), max(0, y1 - pad)
+            x2, y2 = min(img_w, x2 + pad), min(img_h, y2 + pad)
+
+            if x2 > x1 and y2 > y1:
+                roi = m[y1:y2, x1:x2]
+                k_rect = args.rect_kernel
+                kernel_rect = cv2.getStructuringElement(cv2.MORPH_RECT, (k_rect, k_rect))
+                roi_closed = cv2.morphologyEx(roi, cv2.MORPH_CLOSE, kernel_rect)
+                m[y1:y2, x1:x2] = roi_closed
+
+        # 3.
Hole filling
+        if args.fill_holes:
+            contours, _ = cv2.findContours(m, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+            m_filled = np.zeros_like(m)
+            cv2.drawContours(m_filled, contours, -1, 1, thickness=cv2.FILLED)
+            m = m_filled
+
+        # [Core change 2] Instead of drawing colored overlays/boxes, merge this
+        # object's mask into the total mask with a bitwise OR: a pixel is white
+        # in the total mask as soon as any single object's mask is white there.
+        total_mask = cv2.bitwise_or(total_mask, m)
+
+    # --- 5. Final image (blackout effect) ---
+    # [Core change 3] Keep original pixels where the total mask is set and
+    # black out everything else; scale the 0/1 mask up to 0/255 first.
+    total_mask_255 = (total_mask * 255).astype(np.uint8)
+    final_image = cv2.bitwise_and(orig_cv2_image, orig_cv2_image, mask=total_mask_255)
+
+    # --- 6. Save ---
+    tags = []
+    # Tag marks this as a blackout-style cutout
+    tags.append("blackout")
+    if args.dilate > 0: tags.append(f"d{args.dilate}")
+    if args.rect_fill: tags.append(f"rect{args.rect_kernel}")
+    if args.bbox_pad != 0: tags.append(f"pad{args.bbox_pad}")
+    if args.fill_holes: tags.append("filled")
+
+    tag_str = "_" + "_".join(tags) if tags else ""
+    name_no_ext = os.path.splitext(filename)[0]
+    # JPEG is fine: the background is solid black, not transparent
+    save_filename = f"{name_no_ext}_{args.text.replace(' ', '_')}{tag_str}.jpg"
+    save_path = os.path.join(output_dir, save_filename)
+
+    cv2.imwrite(save_path, final_image)
+    print(f"[SUCCESS] 结果已保存至: {save_path}")
+
+def main():
+    # (argument parsing and model loading unchanged)
+    parser = argparse.ArgumentParser(description="SAM 3 自动抠图脚本 (黑底保留原色)")
+    parser.add_argument("--input", type=str, required=True, help="输入路径")
+    parser.add_argument("--output", type=str, default="/app/outputs", help="输出目录")
+    parser.add_argument("--model", type=str, default="/app/checkpoints/sam3.pt", help="模型路径")
+    parser.add_argument("--text", type=str, required=True, help="提示词")
+    parser.add_argument("--conf", type=float, default=0.2, help="置信度")
+
+    # Post-processing options
+    parser.add_argument("--dilate", type=int, default=0, help="全局基础膨胀 (px)")
+    parser.add_argument("--fill-holes", action="store_true", help="开启孔洞填充")
+    parser.add_argument("--rect-fill", action="store_true",
help="开启矩形区域智能填充")
+    parser.add_argument("--rect-kernel", type=int, default=20, help="矩形填充的核大小")
+    parser.add_argument("--bbox-pad", type=int, default=0, help="检测框范围调整 (px)")
+
+    args = parser.parse_args()
+
+    os.makedirs(args.output, exist_ok=True)
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    print(f"--- 运行配置 (黑底模式) ---")
+    print(f"输入: {args.input}")
+    print(f"模式: {'矩形填充' if args.rect_fill else '普通'}")
+    if args.rect_fill:
+        print(f"矩形核: {args.rect_kernel} px")
+        print(f"检测框Padding: {args.bbox_pad} px")
+    print(f"----------------")
+
+    # Collect input images: recurse into a directory, or accept a single file
+    image_list = []
+    if os.path.isdir(args.input):
+        for root, _, files in os.walk(args.input):
+            for file in files:
+                if os.path.splitext(file)[1].lower() in IMG_EXTENSIONS:
+                    image_list.append(os.path.join(root, file))
+    elif os.path.isfile(args.input):
+        image_list.append(args.input)
+
+    if not image_list:
+        print("[Error] 未找到图片")
+        return
+
+    print(f"[INFO] 加载模型: {args.model} ...")
+    try:
+        sam3_model = build_sam3_image_model(
+            checkpoint_path=args.model,
+            load_from_HF=False,
+            device=device
+        )
+        processor = Sam3Processor(sam3_model)
+    except Exception as e:
+        print(f"[Fatal] 模型加载失败: {e}")
+        return
+
+    total = len(image_list)
+    for idx, img_path in enumerate(image_list):
+        print(f"\n--- [{idx+1}/{total}] ---")
+        process_one_image(processor, img_path, args, args.output)
+
+    print("\n[DONE] 完成。")
+
+if __name__ == "__main__":
+    main()
+
diff --git a/docker-compose.yaml b/docker-compose.yaml
new file mode 100644
index 0000000..a2c8ac4
--- /dev/null
+++ b/docker-compose.yaml
@@ -0,0 +1,33 @@
+services:
+  sam3:
+    build: .
+    container_name: sam3_fixed
+    # Enable the GPU
+    deploy:
+      resources:
+        reservations:
+          devices:
+            - driver: nvidia
+              count: 1
+              capabilities: [gpu]
+
+    # Larger shared memory for data loading
+    shm_size: '16gb'
+
+    volumes:
+      # Key mappings: the container must see the rescued .pt checkpoint
+      - ./checkpoints:/app/checkpoints
+      - ./data:/app/data        # input data
+      - ./outputs:/app/outputs  # results
+
+    environment:
+      - NVIDIA_VISIBLE_DEVICES=all
+      - NVIDIA_DRIVER_CAPABILITIES=compute,utility,video
+
+    ports:
+      - "7860:7860"  # Gradio
+      - "8888:8888"  # Jupyter (harmless to keep open)
+
+    stdin_open: true
+    tty: true
+    restart: unless-stopped