Add sam3 project code

This commit is contained in:
linsan 2026-02-04 11:05:11 +08:00
commit 44b339b253
4 changed files with 313 additions and 0 deletions

7
.gitignore vendored Normal file
View File

@@ -0,0 +1,7 @@
checkpoints/
outputs/
results/
data/
__pycache__/
*.pyc
.DS_Store

75
Dockerfile Normal file
View File

@@ -0,0 +1,75 @@
# 1. Base image: CUDA 12.4.1 devel with cuDNN (validated for this project)
FROM nvidia/cuda:12.4.1-cudnn-devel-ubuntu22.04

# Runtime environment
ENV PYTHONDONTWRITEBYTECODE=1
ENV PYTHONUNBUFFERED=1
# Build-time only: do not bake DEBIAN_FRONTEND into the runtime environment
ARG DEBIAN_FRONTEND=noninteractive

# Architecture targets for RTX 50-series (Blackwell = sm_100/sm_120; +PTX covers newer GPUs)
ENV TORCH_CUDA_ARCH_LIST="9.0;10.0+PTX"

# 2. System tools (ffmpeg added for video processing).
#    --no-install-recommends keeps the image lean; apt lists are removed in the same layer.
RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential \
    ffmpeg \
    git \
    libgl1 \
    libglib2.0-0 \
    libsm6 \
    libxext6 \
    ninja-build \
    python3.10 \
    python3-pip \
    wget \
    && rm -rf /var/lib/apt/lists/*

# Provide a plain `python` command
RUN ln -s /usr/bin/python3.10 /usr/bin/python

WORKDIR /app

# 3. Upgrade pip (Tsinghua mirror for speed); --no-cache-dir avoids layer bloat
RUN python3 -m pip install --no-cache-dir --upgrade pip -i https://pypi.tuna.tsinghua.edu.cn/simple --default-timeout=100

# 4. Core fix: install the PyTorch build for RTX 50-series GPUs (cu128).
#    This is the key step that resolves "sm_120 is not compatible".
RUN python3 -m pip install --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128 --default-timeout=100

# 5. SAM3 dependencies (includes decord and pycocotools, which were previously missing).
#    transformers is deliberately unpinned: the latest release is required for SAM3.
RUN python3 -m pip install --no-cache-dir \
    decord \
    einops \
    hydra-core \
    iopath \
    ipympl \
    jupyter \
    jupyterlab \
    matplotlib \
    networkx \
    opencv-python-headless \
    pandas \
    pillow \
    pycocotools \
    pyyaml \
    scipy \
    timm \
    tokenizers \
    tqdm \
    transformers \
    -i https://pypi.tuna.tsinghua.edu.cn/simple \
    --default-timeout=100

# 6. Fetch SAM3 sources
#    NOTE(review): unpinned clone — pin a commit SHA for reproducible builds
RUN git clone https://github.com/facebookresearch/sam3.git sam3_code

# 7. Install the SAM3 package in editable mode
WORKDIR /app/sam3_code
RUN python3 -m pip install --no-cache-dir -e .

# 8. Working directory and runtime contract
WORKDIR /app
EXPOSE 8888

# Keep the container alive with JupyterLab for interactive debugging.
# Exec form (no shell wrapper): jupyter runs as PID 1 and receives SIGTERM from `docker stop`.
CMD ["jupyter", "lab", "--ip=0.0.0.0", "--port=8888", "--allow-root", "--no-browser", "--NotebookApp.token=sam3"]

198
app.py Normal file
View File

@@ -0,0 +1,198 @@
import argparse
import torch
import cv2
import os
import numpy as np
from PIL import Image
from sam3 import build_sam3_image_model
from sam3.model.sam3_image_processor import Sam3Processor
# File extensions accepted as input images (compared lower-cased during the directory walk)
IMG_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.webp', '.tiff'}
def process_one_image(processor, img_path, args, output_dir):
    """Segment one image with a SAM3 text prompt and save a black-background cutout.

    Pixels covered by any accepted mask keep their original color; everything
    else becomes black. The result is written to ``output_dir`` as a JPEG whose
    name encodes the prompt and the post-processing options used.

    Args:
        processor: Sam3Processor wrapping a loaded SAM3 image model.
        img_path: Path of the input image.
        args: Parsed CLI namespace; reads text, conf, dilate, rect_fill,
            rect_kernel, bbox_pad and fill_holes.
        output_dir: Directory the result is saved into.
    """
    filename = os.path.basename(img_path)
    # BUGFIX: these log messages printed the literal text "(unknown)" instead of
    # the file name; interpolate `filename` so logs identify the image.
    print(f"\n[INFO] 正在处理: {filename} ...")
    # --- 1. Read the image ---
    try:
        # RGB PIL image used for inference
        pil_image = Image.open(img_path).convert("RGB")
        # BGR OpenCV copy used for post-processing and as the output background
        orig_cv2_image = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
    except Exception as e:
        print(f"[Error] 无法读取图片 {filename}: {e}")
        return
    # --- 2. Inference (logic unchanged) ---
    try:
        inference_state = processor.set_image(pil_image)
        results = processor.set_text_prompt(
            state=inference_state,
            prompt=args.text
        )
    except Exception as e:
        print(f"[Error] 推理报错 ({filename}): {e}")
        return
    # --- 3. Filter detections by confidence (logic unchanged) ---
    masks = results.get("masks")
    scores = results.get("scores")
    boxes = results.get("boxes")
    if masks is None:
        print(f"[WARN] {filename} 未检测到任何目标。")
        return
    # Move any torch tensors to CPU numpy before numpy-side thresholding/indexing.
    if isinstance(masks, torch.Tensor): masks = masks.cpu().numpy()
    if isinstance(scores, torch.Tensor): scores = scores.cpu().numpy()
    if isinstance(boxes, torch.Tensor): boxes = boxes.cpu().numpy()
    # assumes scores is 1-D (one score per mask) — TODO confirm against Sam3Processor
    keep_indices = np.where(scores > args.conf)[0]
    if len(keep_indices) == 0:
        print(f"[WARN] {filename} 没有目标超过阈值 {args.conf}")
        return
    masks = masks[keep_indices]
    scores = scores[keep_indices]
    boxes = boxes[keep_indices]
    print(f"[INFO] 发现 {len(masks)} 个目标 (Text: {args.text})")
    # --- 4. Mask post-processing and merging ---
    img_h, img_w = orig_cv2_image.shape[:2]
    # Single-channel accumulator: union of every accepted object mask (values 0/1)
    total_mask = np.zeros((img_h, img_w), dtype=np.uint8)
    # Per-object morphology pipeline (kept identical to the original logic)
    for i in range(len(masks)):
        m = masks[i]
        box = boxes[i]
        # --- A. Basic normalization ---
        # Drop a leading channel dimension if present (e.g. (1, H, W) -> (H, W))
        if m.ndim > 2: m = m[0]
        # Resize the mask to the original image size (nearest keeps it binary)
        if m.shape != (img_h, img_w):
            m = cv2.resize(m.astype(np.uint8), (img_w, img_h), interpolation=cv2.INTER_NEAREST)
        m = m.astype(np.uint8)  # assumes masks are binary 0/1 — TODO confirm SAM3 output
        # --- B. Post-processing pipeline ---
        # 1. Global dilation (kernel size forced odd)
        if args.dilate > 0:
            k_size = args.dilate if args.dilate % 2 == 1 else args.dilate + 1
            kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (k_size, k_size))
            m = cv2.dilate(m, kernel, iterations=1)
        # 2. Morphological closing restricted to the (padded) detection box
        if args.rect_fill:
            x1, y1, x2, y2 = map(int, box)
            # Expand the box by --bbox-pad, clamped to the image bounds
            pad = args.bbox_pad
            x1, y1 = max(0, x1 - pad), max(0, y1 - pad)
            x2, y2 = min(img_w, x2 + pad), min(img_h, y2 + pad)
            if x2 > x1 and y2 > y1:
                roi = m[y1:y2, x1:x2]
                k_rect = args.rect_kernel
                kernel_rect = cv2.getStructuringElement(cv2.MORPH_RECT, (k_rect, k_rect))
                roi_closed = cv2.morphologyEx(roi, cv2.MORPH_CLOSE, kernel_rect)
                m[y1:y2, x1:x2] = roi_closed
        # 3. Hole filling: redraw external contours as solid filled shapes
        if args.fill_holes:
            contours, _ = cv2.findContours(m, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            m_filled = np.zeros_like(m)
            cv2.drawContours(m_filled, contours, -1, 1, thickness=cv2.FILLED)
            m = m_filled
        # Merge this object into the union mask: a pixel survives if ANY mask covers it
        total_mask = cv2.bitwise_or(total_mask, m)
    # --- 5. Compose the final blackout image ---
    # Scale 0/1 -> 0/255 so bitwise_and treats it as a proper binary mask
    total_mask_255 = (total_mask * 255).astype(np.uint8)
    final_image = cv2.bitwise_and(orig_cv2_image, orig_cv2_image, mask=total_mask_255)
    # --- 6. Save ---
    tags = []
    # Mark the output as a black-background cutout
    tags.append("blackout")
    if args.dilate > 0: tags.append(f"d{args.dilate}")
    if args.rect_fill: tags.append(f"rect{args.rect_kernel}")
    if args.bbox_pad != 0: tags.append(f"pad{args.bbox_pad}")
    if args.fill_holes: tags.append("filled")
    tag_str = "_" + "_".join(tags) if tags else ""
    name_no_ext = os.path.splitext(filename)[0]
    # JPEG is fine here: the background is solid black, not transparent
    save_filename = f"{name_no_ext}_{args.text.replace(' ', '_')}{tag_str}.jpg"
    save_path = os.path.join(output_dir, save_filename)
    cv2.imwrite(save_path, final_image)
    print(f"[SUCCESS] 结果已保存至: {save_path}")
def main():
    """CLI entry point: parse options, load the SAM3 model once, process every image."""
    parser = argparse.ArgumentParser(description="SAM 3 自动抠图脚本 (黑底保留原色)")
    parser.add_argument("--input", type=str, required=True, help="输入路径")
    parser.add_argument("--output", type=str, default="/app/outputs", help="输出目录")
    parser.add_argument("--model", type=str, default="/app/checkpoints/sam3.pt", help="模型路径")
    parser.add_argument("--text", type=str, required=True, help="提示词")
    parser.add_argument("--conf", type=float, default=0.2, help="置信度")
    # Post-processing knobs
    parser.add_argument("--dilate", type=int, default=0, help="全局基础膨胀 (px)")
    parser.add_argument("--fill-holes", action="store_true", help="开启孔洞填充")
    parser.add_argument("--rect-fill", action="store_true", help="开启矩形区域智能填充")
    parser.add_argument("--rect-kernel", type=int, default=20, help="矩形填充的核大小")
    parser.add_argument("--bbox-pad", type=int, default=0, help="检测框范围调整 (px)")
    args = parser.parse_args()

    os.makedirs(args.output, exist_ok=True)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Echo the effective configuration before doing any work
    print(f"--- 运行配置 (黑底模式) ---")
    print(f"输入: {args.input}")
    print(f"模式: {'矩形填充' if args.rect_fill else '普通'}")
    if args.rect_fill:
        print(f"矩形核: {args.rect_kernel} px")
        print(f"检测框Padding: {args.bbox_pad} px")
    print(f"----------------")

    # Collect inputs: either one file, or a recursive walk of a directory
    images = []
    if os.path.isdir(args.input):
        for dirpath, _, filenames in os.walk(args.input):
            images.extend(
                os.path.join(dirpath, fname)
                for fname in filenames
                if os.path.splitext(fname)[1].lower() in IMG_EXTENSIONS
            )
    elif os.path.isfile(args.input):
        images.append(args.input)
    if not images:
        print("[Error] 未找到图片")
        return

    # Load the checkpoint once and reuse the processor for every image
    print(f"[INFO] 加载模型: {args.model} ...")
    try:
        model = build_sam3_image_model(
            checkpoint_path=args.model,
            load_from_HF=False,
            device=device
        )
        processor = Sam3Processor(model)
    except Exception as e:
        print(f"[Fatal] 模型加载失败: {e}")
        return

    total = len(images)
    for idx, img_path in enumerate(images, start=1):
        print(f"\n--- [{idx}/{total}] ---")
        process_one_image(processor, img_path, args, args.output)
    print("\n[DONE] 完成。")


if __name__ == "__main__":
    main()

33
docker-compose.yaml Normal file
View File

@@ -0,0 +1,33 @@
services:
  sam3:
    build: .
    container_name: sam3_fixed
    # Enable GPU access (requires the NVIDIA Container Toolkit on the host)
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]
    # Larger shared memory for PyTorch dataloaders
    shm_size: '16gb'
    volumes:
      # Key fix: these host paths must line up with what app.py expects inside the container
      - ./checkpoints:/app/checkpoints # model .pt files read by the container
      - ./data:/app/data # input data
      - ./outputs:/app/outputs # results are written here
    environment:
      - NVIDIA_VISIBLE_DEVICES=all
      - NVIDIA_DRIVER_CAPABILITIES=compute,utility,video
    ports:
      - "7860:7860" # Gradio port
      - "8888:8888" # Jupyter port (harmless to keep exposed)
    stdin_open: true
    tty: true
    restart: unless-stopped