puruan/app.py
2026-02-04 11:05:11 +08:00

199 lines
7.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import argparse
import torch
import cv2
import os
import numpy as np
from PIL import Image
from sam3 import build_sam3_image_model
from sam3.model.sam3_image_processor import Sam3Processor
# Supported image file extensions (lower-case, with leading dot). Directory
# inputs compare each file's lower-cased extension against this set.
# '.tif' is listed alongside '.tiff' because both spellings are common.
IMG_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.webp', '.tif', '.tiff'}
def _to_numpy(x):
    """Convert a torch tensor to a NumPy array; return anything else unchanged."""
    if isinstance(x, torch.Tensor):
        x = x.detach().cpu()
        # NumPy has no bfloat16 dtype, so .numpy() would raise on such tensors.
        if x.dtype == torch.bfloat16:
            x = x.float()
        x = x.numpy()
    return x


def _refine_mask(m, box, args, img_h, img_w, dilate_kernel, close_kernel):
    """Post-process one predicted mask into an (img_h, img_w) uint8 0/1 mask.

    The optional steps are applied in this order:
    1. Global dilation, when ``dilate_kernel`` is not None.
    2. Morphological closing inside the padded detection box, when
       ``close_kernel`` and ``box`` are not None.
    3. Hole filling by redrawing external contours, when ``args.fill_holes``
       is set.
    """
    # NOTE(review): assumes an extra leading axis is a singleton channel, e.g.
    # (1, H, W). Confirm against the Sam3Processor output.
    if m.ndim > 2:
        m = m[0]
    # Nearest-neighbour resize keeps the mask binary.
    if m.shape != (img_h, img_w):
        m = cv2.resize(m.astype(np.uint8), (img_w, img_h), interpolation=cv2.INTER_NEAREST)
    # astype copies, so the in-place ROI write below never touches the caller's
    # array. Assumes the mask is already binary (bool or 0/1); a float
    # probability map would be truncated here.
    m = m.astype(np.uint8)

    # 1. Global dilation.
    if dilate_kernel is not None:
        m = cv2.dilate(m, dilate_kernel, iterations=1)

    # 2. Rect fill: closing restricted to the (padded) detection box.
    #    NOTE(review): assumes box is (x1, y1, x2, y2) in original-image pixels.
    if close_kernel is not None and box is not None:
        x1, y1, x2, y2 = map(int, box)
        pad = args.bbox_pad  # may be negative to shrink the box
        x1, y1 = max(0, x1 - pad), max(0, y1 - pad)
        x2, y2 = min(img_w, x2 + pad), min(img_h, y2 + pad)
        if x2 > x1 and y2 > y1:
            m[y1:y2, x1:x2] = cv2.morphologyEx(m[y1:y2, x1:x2], cv2.MORPH_CLOSE, close_kernel)

    # 3. Hole filling: redraw every external contour as a solid region.
    if args.fill_holes:
        # [-2] selects the contour list under both OpenCV 3.x (3 return values)
        # and 4.x (2 return values).
        contours = cv2.findContours(m, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
        filled = np.zeros_like(m)
        cv2.drawContours(filled, contours, -1, 1, thickness=cv2.FILLED)
        m = filled
    return m


def _build_tag_str(args):
    """Return the filename suffix that encodes the post-processing options used."""
    # "blackout" marks a black-background cutout, so the list is never empty.
    tags = ["blackout"]
    if args.dilate > 0:
        tags.append(f"d{args.dilate}")
    if args.rect_fill:
        tags.append(f"rect{args.rect_kernel}")
    if args.bbox_pad != 0:
        tags.append(f"pad{args.bbox_pad}")
    if args.fill_holes:
        tags.append("filled")
    return "_" + "_".join(tags)


def process_one_image(processor, img_path, args, output_dir):
    """Segment ``args.text`` in one image and save a black-background cutout.

    Pixels covered by any mask whose score exceeds ``args.conf`` keep their
    original colour; every other pixel is set to black. The result is written
    as ``<name>_<text><tags>.jpg`` into ``output_dir``.

    Args:
        processor: object exposing ``set_image(pil_image)`` and
            ``set_text_prompt(state=..., prompt=...)``, which returns a dict
            with "masks", "scores" and "boxes".
        img_path: path of the input image.
        args: parsed CLI namespace. Reads ``text``, ``conf``, ``dilate``,
            ``rect_fill``, ``rect_kernel``, ``bbox_pad`` and ``fill_holes``.
        output_dir: destination directory, expected to exist already.

    Returns:
        None. Read or inference errors, empty results and write failures are
        reported with ``print`` and the image is skipped.

    Note:
        Output names use only the basename. Same-named images from different
        input subdirectories therefore overwrite each other.
    """
    filename = os.path.basename(img_path)
    print(f"\n[INFO] 正在处理: {filename} ...")

    # --- 1. Read the image ---
    try:
        # The context manager closes the underlying file handle deterministically.
        with Image.open(img_path) as im:
            pil_image = im.convert("RGB")
        # The BGR copy is used for post-processing and as the output background.
        orig_cv2_image = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
    except Exception as e:
        print(f"[Error] 无法读取图片 {filename}: {e}")
        return

    # --- 2. Inference ---
    try:
        inference_state = processor.set_image(pil_image)
        results = processor.set_text_prompt(state=inference_state, prompt=args.text)
    except Exception as e:
        print(f"[Error] 推理报错 ({filename}): {e}")
        return

    # --- 3. Confidence filtering ---
    masks = _to_numpy(results.get("masks"))
    scores = _to_numpy(results.get("scores"))
    boxes = _to_numpy(results.get("boxes"))
    if masks is None or scores is None or len(masks) == 0:
        print(f"[WARN] {filename} 未检测到任何目标。")
        return

    keep_indices = np.where(scores > args.conf)[0]
    if len(keep_indices) == 0:
        print(f"[WARN] {filename} 没有目标超过阈值 {args.conf}")
        return
    masks = masks[keep_indices]
    # Boxes are only needed for rect fill; tolerate their absence.
    boxes = boxes[keep_indices] if boxes is not None else None
    print(f"[INFO] 发现 {len(masks)} 个目标 (Text: {args.text})")

    # --- 4. Refine each mask and merge them ---
    img_h, img_w = orig_cv2_image.shape[:2]

    # The structuring elements are loop-invariant, so build them once.
    dilate_kernel = None
    if args.dilate > 0:
        k_size = args.dilate if args.dilate % 2 == 1 else args.dilate + 1  # force odd
        dilate_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (k_size, k_size))
    close_kernel = None
    # A non-positive size makes cv2.getStructuringElement fail, and a 1x1 close
    # is the identity, so skip the closing step for such values.
    if args.rect_fill and args.rect_kernel > 0:
        close_kernel = cv2.getStructuringElement(
            cv2.MORPH_RECT, (args.rect_kernel, args.rect_kernel)
        )

    # Union of all refined masks: a pixel is kept if any target covers it.
    total_mask = np.zeros((img_h, img_w), dtype=np.uint8)
    for i, m in enumerate(masks):
        box = boxes[i] if boxes is not None else None
        total_mask |= _refine_mask(m, box, args, img_h, img_w, dilate_kernel, close_kernel)

    # --- 5. Compose the blackout image ---
    # bitwise_and with mask= keeps pixels where the mask is non-zero and sets
    # the rest to 0 (black).
    total_mask_255 = (total_mask * 255).astype(np.uint8)
    final_image = cv2.bitwise_and(orig_cv2_image, orig_cv2_image, mask=total_mask_255)

    # --- 6. Save ---
    # JPEG is sufficient because the background is solid black, not transparent.
    # Path separators in the prompt are replaced like spaces, so a prompt such as
    # "cat/dog" cannot redirect the output into a nonexistent subdirectory.
    safe_text = args.text.replace(" ", "_").replace(os.sep, "_")
    if os.altsep:
        safe_text = safe_text.replace(os.altsep, "_")
    name_no_ext = os.path.splitext(filename)[0]
    save_filename = f"{name_no_ext}_{safe_text}{_build_tag_str(args)}.jpg"
    save_path = os.path.join(output_dir, save_filename)
    # cv2.imwrite reports failure by returning False rather than raising.
    if cv2.imwrite(save_path, final_image):
        print(f"[SUCCESS] 结果已保存至: {save_path}")
    else:
        print(f"[Error] 保存失败: {save_path}")
def main():
    """Command-line entry point: parse arguments, load SAM 3 and process images.

    ``--input`` may be a single file or a directory. A directory is walked
    recursively, and files whose lower-cased extension is in ``IMG_EXTENSIONS``
    are processed in sorted path order.

    Exit status:
        1 if no image is found or the model fails to load, so batch or
        container callers can detect the failure.
        2 (raised by argparse) on invalid arguments.
    """
    parser = argparse.ArgumentParser(description="SAM 3 自动抠图脚本 (黑底保留原色)")
    parser.add_argument("--input", type=str, required=True, help="输入路径")
    parser.add_argument("--output", type=str, default="/app/outputs", help="输出目录")
    parser.add_argument("--model", type=str, default="/app/checkpoints/sam3.pt", help="模型路径")
    parser.add_argument("--text", type=str, required=True, help="提示词")
    parser.add_argument("--conf", type=float, default=0.2, help="置信度")
    # Post-processing options
    parser.add_argument("--dilate", type=int, default=0, help="全局基础膨胀 (px)")
    parser.add_argument("--fill-holes", action="store_true", help="开启孔洞填充")
    parser.add_argument("--rect-fill", action="store_true", help="开启矩形区域智能填充")
    parser.add_argument("--rect-kernel", type=int, default=20, help="矩形填充的核大小")
    parser.add_argument("--bbox-pad", type=int, default=0, help="检测框范围调整 (px)")
    args = parser.parse_args()
    # cv2.getStructuringElement rejects a non-positive size. Without this check
    # the error would only surface as an uncaught cv2.error mid-batch.
    if args.rect_fill and args.rect_kernel < 1:
        parser.error("--rect-kernel 必须 >= 1")

    os.makedirs(args.output, exist_ok=True)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    print("--- 运行配置 (黑底模式) ---")
    print(f"输入: {args.input}")
    print(f"模式: {'矩形填充' if args.rect_fill else '普通'}")
    if args.rect_fill:
        print(f"矩形核: {args.rect_kernel} px")
        print(f"检测框Padding: {args.bbox_pad} px")
    print("----------------")

    # Collect the input images.
    image_list = []
    if os.path.isdir(args.input):
        for root, _, files in os.walk(args.input):
            image_list.extend(
                os.path.join(root, name)
                for name in files
                if os.path.splitext(name)[1].lower() in IMG_EXTENSIONS
            )
        # os.walk order depends on the filesystem; sort for reproducible runs.
        image_list.sort()
    elif os.path.isfile(args.input):
        # An explicitly named file is accepted regardless of its extension.
        image_list.append(args.input)

    if not image_list:
        print("[Error] 未找到图片")
        raise SystemExit(1)

    print(f"[INFO] 加载模型: {args.model} ...")
    try:
        sam3_model = build_sam3_image_model(
            checkpoint_path=args.model,
            load_from_HF=False,
            device=device,
        )
        processor = Sam3Processor(sam3_model)
    except Exception as e:
        # Top-level boundary: report the error and exit non-zero.
        print(f"[Fatal] 模型加载失败: {e}")
        raise SystemExit(1)

    total = len(image_list)
    for idx, img_path in enumerate(image_list, start=1):
        print(f"\n--- [{idx}/{total}] ---")
        # process_one_image only guards image reading and inference. Catch any
        # other failure (e.g. an OpenCV error during post-processing) here so
        # that one bad image does not abort the whole batch.
        try:
            process_one_image(processor, img_path, args, args.output)
        except Exception as e:
            print(f"[Error] 处理失败 ({img_path}): {e}")
    print("\n[DONE] 完成。")
if __name__ == "__main__":
    # Run the CLI only when this file is executed as a script, not when imported.
    main()