Compare commits
No commits in common. "main" and "sam3" have entirely different histories.
9
.gitignore
vendored
9
.gitignore
vendored
@ -1,8 +1,7 @@
|
|||||||
__pycache__/
|
checkpoints/
|
||||||
hf_cache/
|
outputs/
|
||||||
models/
|
results/
|
||||||
temp_uploads/
|
|
||||||
debug_raw_heatmap.png
|
|
||||||
data/
|
data/
|
||||||
|
__pycache__/
|
||||||
*.pyc
|
*.pyc
|
||||||
.DS_Store
|
.DS_Store
|
||||||
|
|||||||
77
Dockerfile
77
Dockerfile
@ -1,48 +1,75 @@
|
|||||||
# 基础镜像:CUDA 12.4 (匹配 RTX 50 系)
|
# 1. 基础镜像:使用你验证过的 CUDA 12.4 开发版
|
||||||
FROM nvidia/cuda:12.4.1-cudnn-devel-ubuntu22.04
|
FROM nvidia/cuda:12.4.1-cudnn-devel-ubuntu22.04
|
||||||
|
|
||||||
# 环境变量
|
# 环境变量
|
||||||
ENV PYTHONDONTWRITEBYTECODE=1
|
ENV PYTHONDONTWRITEBYTECODE=1
|
||||||
ENV PYTHONUNBUFFERED=1
|
ENV PYTHONUNBUFFERED=1
|
||||||
ENV DEBIAN_FRONTEND=noninteractive
|
ENV DEBIAN_FRONTEND=noninteractive
|
||||||
ENV HF_HOME=/root/.cache/huggingface
|
# 针对 RTX 50 系列的架构设置 (Blackwell = sm_100/sm_120;下方列表未显式包含 12.0,sm_120 依赖 10.0 的 PTX 在运行时 JIT 编译)
|
||||||
|
ENV TORCH_CUDA_ARCH_LIST="9.0;10.0+PTX"
|
||||||
|
|
||||||
WORKDIR /app
|
# 2. 安装系统工具 (增加了 ffmpeg 用于视频处理)
|
||||||
|
|
||||||
# 1. 系统依赖
|
|
||||||
RUN apt-get update && apt-get install -y \
|
RUN apt-get update && apt-get install -y \
|
||||||
python3.10 \
|
python3.10 \
|
||||||
python3-pip \
|
python3-pip \
|
||||||
git \
|
git \
|
||||||
wget \
|
wget \
|
||||||
vim \
|
ffmpeg \
|
||||||
libgl1 \
|
libgl1 \
|
||||||
libglib2.0-0 \
|
libglib2.0-0 \
|
||||||
|
libsm6 \
|
||||||
|
libxext6 \
|
||||||
|
build-essential \
|
||||||
|
ninja-build \
|
||||||
&& rm -rf /var/lib/apt/lists/*
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
# 2. 设置 pip 清华源全局配置 (省去每次敲参数)
|
# 建立 python 软链接,方便直接用 python 命令
|
||||||
RUN pip3 config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple
|
RUN ln -s /usr/bin/python3.10 /usr/bin/python
|
||||||
|
|
||||||
# 3. 🔥 核心:RTX 5060 专用 PyTorch (cu128)
|
WORKDIR /app
|
||||||
# 必须显式指定 index-url 覆盖上面的清华源配置,因为 cu128 只有官方 nightly 有
|
|
||||||
RUN pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128
|
|
||||||
|
|
||||||
# 4. 安装 DiffSim/Diffusion 核心库
|
# 3. 升级 pip (使用清华源加速)
|
||||||
# 包含 tqdm 用于显示进度条
|
RUN python3 -m pip install --upgrade pip -i https://pypi.tuna.tsinghua.edu.cn/simple --default-timeout=100
|
||||||
|
|
||||||
|
# 4. 🔥 核心修复:安装适配 RTX 50 系列的 PyTorch (cu128)
|
||||||
|
# 这是解决 "sm_120 is not compatible" 的关键一步
|
||||||
|
RUN pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128 --default-timeout=100
|
||||||
|
|
||||||
|
# 5. 安装 SAM3 依赖 (包含之前报错缺少的 decord, pycocotools)
|
||||||
|
# 移除了 transformers 版本限制,使用最新版适配 SAM3
|
||||||
RUN pip install \
|
RUN pip install \
|
||||||
"diffusers>=0.24.0" \
|
opencv-python-headless \
|
||||||
"transformers>=4.35.0" \
|
matplotlib \
|
||||||
accelerate \
|
jupyter \
|
||||||
|
jupyterlab \
|
||||||
|
ipympl \
|
||||||
|
pyyaml \
|
||||||
|
tqdm \
|
||||||
|
hydra-core \
|
||||||
|
iopath \
|
||||||
|
pillow \
|
||||||
|
networkx \
|
||||||
scipy \
|
scipy \
|
||||||
safetensors \
|
pandas \
|
||||||
opencv-python \
|
|
||||||
ftfy \
|
|
||||||
regex \
|
|
||||||
timm \
|
timm \
|
||||||
einops \
|
einops \
|
||||||
lpips \
|
transformers \
|
||||||
tqdm \
|
tokenizers \
|
||||||
matplotlib
|
decord \
|
||||||
|
pycocotools \
|
||||||
|
-i https://pypi.tuna.tsinghua.edu.cn/simple \
|
||||||
|
--default-timeout=100
|
||||||
|
|
||||||
# 5. 默认指令:启动 bash,让你进入终端
|
# 6. 拉取 SAM3 代码
|
||||||
CMD ["/bin/bash"]
|
RUN git clone https://github.com/facebookresearch/sam3.git sam3_code
|
||||||
|
|
||||||
|
# 7. 安装 SAM3 包
|
||||||
|
WORKDIR /app/sam3_code
|
||||||
|
RUN pip install -e .
|
||||||
|
|
||||||
|
# 8. 设置工作目录和入口
|
||||||
|
WORKDIR /app
|
||||||
|
EXPOSE 8888
|
||||||
|
|
||||||
|
# 默认启动 JupyterLab (监听 0.0.0.0:8888,token 为 sam3);容器随之保持运行,可另行进入终端调试
|
||||||
|
CMD ["/bin/bash", "-c", "jupyter lab --ip=0.0.0.0 --port=8888 --allow-root --no-browser --NotebookApp.token='sam3'"]
|
||||||
|
|||||||
379
README.md
379
README.md
@ -1,379 +0,0 @@
|
|||||||
结构图(如有误请纠正)
|
|
||||||
```mermaid
|
|
||||||
graph TD
|
|
||||||
%% 用户和网关
|
|
||||||
User((用户)) -->|"上传数据"| Gateway[API Gateway<br/>入口 + 调度器]
|
|
||||||
|
|
||||||
%% Gateway 调度逻辑
|
|
||||||
Gateway -->|"1. 存储原始数据"| Folder1
|
|
||||||
Gateway -->|"2. 创建任务记录"| SQLite[(SQLite<br/>任务状态数据库)]
|
|
||||||
|
|
||||||
%% 共享存储子图
|
|
||||||
subgraph Storage["Docker Shared Volume (/app/data)"]
|
|
||||||
Folder1["/input<br/>(100G 原始数据)"]
|
|
||||||
Folder2["/matting_output<br/>(抠图中间结果)"]
|
|
||||||
Folder3["/final_results<br/>(对比最终结果)"]
|
|
||||||
SQLite
|
|
||||||
end
|
|
||||||
|
|
||||||
%% 任务分发到 Matting
|
|
||||||
Gateway -->|"4. POST /matting {task_id}"| MattingAPI[统一 Matting API]
|
|
||||||
%% Matting 服务集群
|
|
||||||
MattingAPI -->|"5. 分发任务"| MattingCluster
|
|
||||||
subgraph MattingCluster["Matting Service 集群<br/>(共享 GPU)"]
|
|
||||||
M1[Matting Worker 1]
|
|
||||||
M2[Matting Worker 2]
|
|
||||||
M3[Matting Worker N]
|
|
||||||
end
|
|
||||||
|
|
||||||
%% Matting 处理流程
|
|
||||||
MattingCluster -->|"6. 读取分片"| Folder1
|
|
||||||
MattingCluster -->|"7. 更新状态<br/>(processing)"| SQLite
|
|
||||||
MattingCluster -->|"8. 写入结果"| Folder2
|
|
||||||
MattingCluster -->|"9. 更新状态<br/>(matting_done)"| SQLite
|
|
||||||
|
|
||||||
%% Gateway 轮询调度
|
|
||||||
%% SQLite -.->|"10. 轮询检测<br/>matting_done"| Gateway
|
|
||||||
M1 -->|"11. POST /compare {task_id}"| ComparisonAPI[统一 Comparison API]
|
|
||||||
M2 -->|"11. POST /compare {task_id}"| ComparisonAPI[统一 Comparison API]
|
|
||||||
M3 -->|"11. POST /compare {task_id}"| ComparisonAPI[统一 Comparison API]
|
|
||||||
%% Comparison 服务集群
|
|
||||||
ComparisonAPI -->|"12. 分发任务"| ComparisonCluster
|
|
||||||
subgraph ComparisonCluster["Comparison Service 集群<br/>(共享 GPU)"]
|
|
||||||
C1[Comparison Worker 1]
|
|
||||||
C2[Comparison Worker 2]
|
|
||||||
C3[Comparison Worker N]
|
|
||||||
end
|
|
||||||
|
|
||||||
%% Comparison 处理流程
|
|
||||||
ComparisonCluster -->|"13. 读取抠图结果"| Folder2
|
|
||||||
ComparisonCluster -->|"14. 更新状态<br/>(comparing)"| SQLite
|
|
||||||
ComparisonCluster -->|"15. 写入对比结果"| Folder3
|
|
||||||
ComparisonCluster -->|"16. 更新状态<br/>(completed)"| SQLite
|
|
||||||
|
|
||||||
%% 最终返回
|
|
||||||
SQLite -.->|"17. 检测 completed"| Gateway
|
|
||||||
Gateway -->|"18. 返回最终结果"| User
|
|
||||||
|
|
||||||
%% 样式美化 - 黑色背景优化
|
|
||||||
style Gateway fill:#1e88e5,stroke:#64b5f6,stroke-width:3px,color:#fff
|
|
||||||
style SQLite fill:#ab47bc,stroke:#ce93d8,stroke-width:2px,color:#fff
|
|
||||||
style MattingAPI fill:#26c6da,stroke:#4dd0e1,stroke-width:2px,color:#000
|
|
||||||
style ComparisonAPI fill:#ff7043,stroke:#ff8a65,stroke-width:2px,color:#fff
|
|
||||||
style MattingCluster fill:#0277bd,stroke:#0288d1,color:#fff
|
|
||||||
style ComparisonCluster fill:#d84315,stroke:#e64a19,color:#fff
|
|
||||||
style M1 fill:#0288d1,stroke:#4fc3f7,color:#fff
|
|
||||||
style M2 fill:#0288d1,stroke:#4fc3f7,color:#fff
|
|
||||||
style M3 fill:#0288d1,stroke:#4fc3f7,color:#fff
|
|
||||||
style C1 fill:#e64a19,stroke:#ff7043,color:#fff
|
|
||||||
style C2 fill:#e64a19,stroke:#ff7043,color:#fff
|
|
||||||
style C3 fill:#e64a19,stroke:#ff7043,color:#fff
|
|
||||||
style Folder1 fill:#424242,stroke:#9e9e9e,stroke-width:2px,color:#e0e0e0
|
|
||||||
style Folder2 fill:#424242,stroke:#9e9e9e,stroke-width:2px,color:#e0e0e0
|
|
||||||
style Folder3 fill:#424242,stroke:#9e9e9e,stroke-width:2px,color:#e0e0e0
|
|
||||||
style Storage fill:#1b5e20,stroke:#4caf50,stroke-width:2px,color:#fff
|
|
||||||
style User fill:#5e35b1,stroke:#9575cd,stroke-width:2px,color:#fff
|
|
||||||
```
|
|
||||||
|
|
||||||
序列图(如有误请纠正)
|
|
||||||
|
|
||||||
```mermaid
|
|
||||||
%%{init: {'theme':'dark'}}%%
|
|
||||||
sequenceDiagram
|
|
||||||
autonumber
|
|
||||||
participant User as 用户
|
|
||||||
participant GW as Gateway<br/>(入口+调度)
|
|
||||||
participant DB as SQLite<br/>(任务数据库)
|
|
||||||
participant FS as 共享存储<br/>(/app/data)
|
|
||||||
participant MAPI as Matting API
|
|
||||||
participant MW as Matting Worker
|
|
||||||
participant CAPI as Comparison API
|
|
||||||
participant CW as Comparison Worker
|
|
||||||
|
|
||||||
%% 上传阶段
|
|
||||||
rect rgb(30, 60, 90)
|
|
||||||
Note over User,GW: 阶段1: 数据上传与任务创建
|
|
||||||
User->>+GW: POST /upload (100GB 原始数据)
|
|
||||||
GW->>FS: 保存到 /input/raw_data.zip
|
|
||||||
GW->>GW: 数据切片 (chunk_001 ~ chunk_N)
|
|
||||||
GW->>FS: 保存切片到 /input/chunks/
|
|
||||||
loop 为每个切片创建任务
|
|
||||||
GW->>DB: INSERT task (task_id, status='pending')
|
|
||||||
end
|
|
||||||
GW-->>-User: 返回 job_id={uuid}
|
|
||||||
end
|
|
||||||
|
|
||||||
%% Gateway 主动调度 Matting
|
|
||||||
rect rgb(0, 60, 80)
|
|
||||||
Note over GW,MW: 阶段2: Gateway 调度 Matting 任务
|
|
||||||
loop 遍历所有待处理任务
|
|
||||||
GW->>DB: SELECT task WHERE status='pending' LIMIT 1
|
|
||||||
DB-->>GW: 返回 task_id='chunk_001'
|
|
||||||
GW->>DB: UPDATE status='dispatched_matting'
|
|
||||||
|
|
||||||
GW->>+MAPI: POST /matting<br/>{task_id: 'chunk_001', input_path: '/input/chunks/chunk_001.jpg'}
|
|
||||||
MAPI->>MAPI: 负载均衡选择空闲 Worker
|
|
||||||
MAPI->>+MW: 分发任务到 Matting Worker 1
|
|
||||||
|
|
||||||
MW->>DB: UPDATE status='processing', worker_id='matting-1', start_time=now()
|
|
||||||
MW->>FS: 读取 /input/chunks/chunk_001.jpg
|
|
||||||
MW->>MW: GPU 抠图处理 (使用共享显卡)
|
|
||||||
MW->>FS: 写入 /matting_output/chunk_001.png
|
|
||||||
MW->>DB: UPDATE status='matting_done', end_time=now()
|
|
||||||
MW-->>-MAPI: 返回 {status: 'success', output_path: '/matting_output/chunk_001.png'}
|
|
||||||
MAPI-->>-GW: 返回处理成功
|
|
||||||
|
|
||||||
Note over GW: 立即触发下一阶段
|
|
||||||
GW->>GW: 检测到 matting_done
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
%% Gateway 主动调度 Comparison
|
|
||||||
rect rgb(80, 40, 0)
|
|
||||||
Note over GW,CW: 阶段3: Gateway 调度 Comparison 任务
|
|
||||||
GW->>DB: SELECT task WHERE status='matting_done'
|
|
||||||
DB-->>GW: 返回 task_id='chunk_001'
|
|
||||||
GW->>DB: UPDATE status='dispatched_comparison'
|
|
||||||
|
|
||||||
GW->>+CAPI: POST /compare<br/>{task_id: 'chunk_001', original: '/input/chunks/chunk_001.jpg', matting: '/matting_output/chunk_001.png'}
|
|
||||||
CAPI->>CAPI: 负载均衡选择空闲 Worker
|
|
||||||
CAPI->>+CW: 分发任务到 Comparison Worker 1
|
|
||||||
|
|
||||||
CW->>DB: UPDATE status='comparing', worker_id='comparison-1', start_time=now()
|
|
||||||
CW->>FS: 读取 /input/chunks/chunk_001.jpg (原图)
|
|
||||||
CW->>FS: 读取 /matting_output/chunk_001.png (抠图结果)
|
|
||||||
CW->>CW: GPU 对比分析 (使用共享显卡)
|
|
||||||
CW->>FS: 写入 /final_results/chunk_001_compare.jpg
|
|
||||||
CW->>DB: UPDATE status='completed', end_time=now()
|
|
||||||
CW-->>-CAPI: 返回 {status: 'success', result_path: '/final_results/chunk_001_compare.jpg'}
|
|
||||||
CAPI-->>-GW: 返回处理成功
|
|
||||||
end
|
|
||||||
|
|
||||||
%% 并发处理多个任务
|
|
||||||
rect rgb(40, 40, 40)
|
|
||||||
Note over GW,CW: 并发处理 (Gateway 继续调度其他任务)
|
|
||||||
par Gateway 同时调度多个任务
|
|
||||||
GW->>MAPI: POST /matting {task_id: 'chunk_002'}
|
|
||||||
MAPI->>MW: Worker 2 处理
|
|
||||||
and
|
|
||||||
GW->>MAPI: POST /matting {task_id: 'chunk_003'}
|
|
||||||
MAPI->>MW: Worker 3 处理
|
|
||||||
and
|
|
||||||
GW->>CAPI: POST /compare {task_id: 'chunk_001'}
|
|
||||||
CAPI->>CW: Worker 1 处理
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
%% 进度查询
|
|
||||||
rect rgb(20, 60, 40)
|
|
||||||
Note over User,GW: 阶段4: 进度查询与结果下载
|
|
||||||
User->>+GW: GET /status/{job_id}
|
|
||||||
GW->>DB: SELECT status, COUNT(*) FROM tasks WHERE job_id=? GROUP BY status
|
|
||||||
DB-->>GW: {pending:48, processing:3, matting_done:2, comparing:2, completed:145}
|
|
||||||
GW-->>-User: 返回进度 {total:200, completed:145, progress:72.5%}
|
|
||||||
|
|
||||||
alt 所有任务完成
|
|
||||||
User->>+GW: GET /download/{job_id}
|
|
||||||
GW->>DB: SELECT * FROM tasks WHERE job_id=? AND status='completed'
|
|
||||||
GW->>FS: 打包 /final_results/chunk_*.jpg
|
|
||||||
GW-->>-User: 返回压缩包下载链接
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
%% 错误处理
|
|
||||||
rect rgb(80, 20, 20)
|
|
||||||
Note over MW,GW: 异常处理与重试
|
|
||||||
MW->>MW: 处理失败 (OOM 或其他错误)
|
|
||||||
MW->>DB: UPDATE status='failed', error='CUDA out of memory'
|
|
||||||
MW-->>MAPI: 返回 {status: 'error', message: 'OOM'}
|
|
||||||
MAPI-->>GW: 返回错误信息
|
|
||||||
|
|
||||||
GW->>DB: SELECT retry_count WHERE task_id='chunk_001'
|
|
||||||
alt retry_count < 3
|
|
||||||
GW->>DB: UPDATE retry_count=retry_count+1, status='pending'
|
|
||||||
GW->>MAPI: 重新调度任务
|
|
||||||
else retry_count >= 3
|
|
||||||
GW->>DB: UPDATE status='permanently_failed'
|
|
||||||
GW->>User: 发送失败通知
|
|
||||||
end
|
|
||||||
end
|
|
||||||
```
|
|
||||||
|
|
||||||
docker-compose.yml
|
|
||||||
```yaml
|
|
||||||
version: '3.8'
|
|
||||||
|
|
||||||
services:
|
|
||||||
# API Gateway 服务
|
|
||||||
gateway:
|
|
||||||
image: your-registry/gateway:latest
|
|
||||||
container_name: gateway
|
|
||||||
ports:
|
|
||||||
- "8000:8000"
|
|
||||||
volumes:
|
|
||||||
- shared-data:/app/data
|
|
||||||
environment:
|
|
||||||
- MATTING_API_URL=http://matting-api:8001
|
|
||||||
- COMPARISON_API_URL=http://comparison-api:8002
|
|
||||||
- MAX_UPLOAD_SIZE=107374182400 # 100GB
|
|
||||||
depends_on:
|
|
||||||
- matting-api
|
|
||||||
- comparison-api
|
|
||||||
restart: unless-stopped
|
|
||||||
networks:
|
|
||||||
- app-network
|
|
||||||
|
|
||||||
# Matting 统一 API (负载均衡入口)
|
|
||||||
matting-api:
|
|
||||||
image: your-registry/matting-api:latest
|
|
||||||
container_name: matting-api
|
|
||||||
ports:
|
|
||||||
- "8001:8001"
|
|
||||||
volumes:
|
|
||||||
- shared-data:/app/data
|
|
||||||
environment:
|
|
||||||
- WORKER_URLS=http://matting-worker-1:9001,http://matting-worker-2:9001,http://matting-worker-3:9001
|
|
||||||
depends_on:
|
|
||||||
- matting-worker-1
|
|
||||||
- matting-worker-2
|
|
||||||
- matting-worker-3
|
|
||||||
restart: unless-stopped
|
|
||||||
networks:
|
|
||||||
- app-network
|
|
||||||
|
|
||||||
# Matting Worker 1 (共享 GPU)
|
|
||||||
matting-worker-1:
|
|
||||||
image: your-registry/matting-worker:latest
|
|
||||||
container_name: matting-worker-1
|
|
||||||
volumes:
|
|
||||||
- shared-data:/app/data
|
|
||||||
environment:
|
|
||||||
- CUDA_VISIBLE_DEVICES=0
|
|
||||||
- WORKER_ID=1
|
|
||||||
- GPU_MEMORY_FRACTION=0.15 # 限制显存使用比例
|
|
||||||
deploy:
|
|
||||||
resources:
|
|
||||||
reservations:
|
|
||||||
devices:
|
|
||||||
- driver: nvidia
|
|
||||||
count: 1
|
|
||||||
capabilities: [gpu]
|
|
||||||
restart: unless-stopped
|
|
||||||
networks:
|
|
||||||
- app-network
|
|
||||||
|
|
||||||
# Matting Worker 2 (共享 GPU)
|
|
||||||
matting-worker-2:
|
|
||||||
image: your-registry/matting-worker:latest
|
|
||||||
container_name: matting-worker-2
|
|
||||||
volumes:
|
|
||||||
- shared-data:/app/data
|
|
||||||
environment:
|
|
||||||
- CUDA_VISIBLE_DEVICES=0
|
|
||||||
- WORKER_ID=2
|
|
||||||
- GPU_MEMORY_FRACTION=0.15
|
|
||||||
deploy:
|
|
||||||
resources:
|
|
||||||
reservations:
|
|
||||||
devices:
|
|
||||||
- driver: nvidia
|
|
||||||
count: 1
|
|
||||||
capabilities: [gpu]
|
|
||||||
restart: unless-stopped
|
|
||||||
networks:
|
|
||||||
- app-network
|
|
||||||
|
|
||||||
# Matting Worker 3 (共享 GPU)
|
|
||||||
matting-worker-3:
|
|
||||||
image: your-registry/matting-worker:latest
|
|
||||||
container_name: matting-worker-3
|
|
||||||
volumes:
|
|
||||||
- shared-data:/app/data
|
|
||||||
environment:
|
|
||||||
- CUDA_VISIBLE_DEVICES=0
|
|
||||||
- WORKER_ID=3
|
|
||||||
- GPU_MEMORY_FRACTION=0.15
|
|
||||||
deploy:
|
|
||||||
resources:
|
|
||||||
reservations:
|
|
||||||
devices:
|
|
||||||
- driver: nvidia
|
|
||||||
count: 1
|
|
||||||
capabilities: [gpu]
|
|
||||||
restart: unless-stopped
|
|
||||||
networks:
|
|
||||||
- app-network
|
|
||||||
|
|
||||||
# Comparison 统一 API (负载均衡入口)
|
|
||||||
comparison-api:
|
|
||||||
image: your-registry/comparison-api:latest
|
|
||||||
container_name: comparison-api
|
|
||||||
ports:
|
|
||||||
- "8002:8002"
|
|
||||||
volumes:
|
|
||||||
- shared-data:/app/data
|
|
||||||
environment:
|
|
||||||
- WORKER_URLS=http://comparison-worker-1:9002,http://comparison-worker-2:9002
|
|
||||||
depends_on:
|
|
||||||
- comparison-worker-1
|
|
||||||
- comparison-worker-2
|
|
||||||
restart: unless-stopped
|
|
||||||
networks:
|
|
||||||
- app-network
|
|
||||||
|
|
||||||
# Comparison Worker 1 (共享 GPU)
|
|
||||||
comparison-worker-1:
|
|
||||||
image: your-registry/comparison-worker:latest
|
|
||||||
container_name: comparison-worker-1
|
|
||||||
volumes:
|
|
||||||
- shared-data:/app/data
|
|
||||||
environment:
|
|
||||||
- CUDA_VISIBLE_DEVICES=0
|
|
||||||
- WORKER_ID=1
|
|
||||||
- GPU_MEMORY_FRACTION=0.15
|
|
||||||
deploy:
|
|
||||||
resources:
|
|
||||||
reservations:
|
|
||||||
devices:
|
|
||||||
- driver: nvidia
|
|
||||||
count: 1
|
|
||||||
capabilities: [gpu]
|
|
||||||
restart: unless-stopped
|
|
||||||
networks:
|
|
||||||
- app-network
|
|
||||||
|
|
||||||
# Comparison Worker 2 (共享 GPU)
|
|
||||||
comparison-worker-2:
|
|
||||||
image: your-registry/comparison-worker:latest
|
|
||||||
container_name: comparison-worker-2
|
|
||||||
volumes:
|
|
||||||
- shared-data:/app/data
|
|
||||||
environment:
|
|
||||||
- CUDA_VISIBLE_DEVICES=0
|
|
||||||
- WORKER_ID=2
|
|
||||||
- GPU_MEMORY_FRACTION=0.15
|
|
||||||
deploy:
|
|
||||||
resources:
|
|
||||||
reservations:
|
|
||||||
devices:
|
|
||||||
- driver: nvidia
|
|
||||||
count: 1
|
|
||||||
capabilities: [gpu]
|
|
||||||
restart: unless-stopped
|
|
||||||
networks:
|
|
||||||
- app-network
|
|
||||||
|
|
||||||
# 共享数据卷
|
|
||||||
volumes:
|
|
||||||
shared-data:
|
|
||||||
driver: local
|
|
||||||
driver_opts:
|
|
||||||
type: none
|
|
||||||
o: bind
|
|
||||||
device: /data/app-storage # 宿主机路径
|
|
||||||
|
|
||||||
# 网络配置
|
|
||||||
networks:
|
|
||||||
app-network:
|
|
||||||
driver: bridge
|
|
||||||
```
|
|
||||||
376
app.py
376
app.py
@ -1,192 +1,198 @@
|
|||||||
import streamlit as st
|
import argparse
|
||||||
import subprocess
|
import torch
|
||||||
|
import cv2
|
||||||
import os
|
import os
|
||||||
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
from sam3 import build_sam3_image_model
|
||||||
|
from sam3.model.sam3_image_processor import Sam3Processor
|
||||||
|
|
||||||
# === 页面基础配置 ===
|
# 支持的图片扩展名
|
||||||
st.set_page_config(layout="wide", page_title="多核违建检测平台 v3.0")
|
IMG_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.webp', '.tiff'}
|
||||||
|
|
||||||
st.title("🚁 多核无人机违建检测平台 v3.0")
|
def process_one_image(processor, img_path, args, output_dir):
|
||||||
st.caption("集成内核: DINOv2 Giant | DiffSim-Pro v2.1 | DreamSim Ensemble")
|
filename = os.path.basename(img_path)
|
||||||
|
print(f"\n[INFO] 正在处理: {filename} ...")
|
||||||
# ==========================================
|
|
||||||
# 1. 侧边栏:算法核心选择
|
|
||||||
# ==========================================
|
|
||||||
st.sidebar.header("🧠 算法内核 (Core)")
|
|
||||||
|
|
||||||
algo_type = st.sidebar.radio(
|
|
||||||
"选择检测模式",
|
|
||||||
(
|
|
||||||
"DreamSim(语义/感知) 🔥",
|
|
||||||
#"DINOv2 Giant (切片/精度)",
|
|
||||||
#"DiffSim (结构/抗视差)",
|
|
||||||
#"DINOv2 Giant (全图/感知)"
|
|
||||||
),
|
|
||||||
index=0, # 默认选中 DINOv2 切片版,因为它是目前效果最好的
|
|
||||||
help="DINOv2: 几何结构最敏感,抗光照干扰。\nDiffSim: 可调参数多,适合微调。\nDreamSim: 关注整体风格差异。"
|
|
||||||
)
|
|
||||||
|
|
||||||
st.sidebar.markdown("---")
|
|
||||||
st.sidebar.header("🛠️ 参数配置")
|
|
||||||
|
|
||||||
# 初始化默认参数变量
|
|
||||||
script_name = ""
|
|
||||||
cmd_extra_args = []
|
|
||||||
show_slice_params = True # 默认显示切片参数
|
|
||||||
|
|
||||||
# ==========================================
|
|
||||||
# 2. 根据选择配置参数
|
|
||||||
# ==========================================
|
|
||||||
|
|
||||||
if "DINOv2" in algo_type:
|
|
||||||
# === DINOv2 通用配置 ===
|
|
||||||
st.sidebar.info(f"当前内核: {algo_type.split(' ')[0]}")
|
|
||||||
|
|
||||||
# 阈值控制 (DINO 的差异值通常较小,默认给 0.3)
|
|
||||||
thresh = st.sidebar.slider("敏感度阈值 (Thresh)", 0.0, 1.0, 0.30, 0.01, help="值越小越敏感,值越大越只看显著变化")
|
|
||||||
|
|
||||||
# 区分全图 vs 切片
|
|
||||||
if "切片" in algo_type:
|
|
||||||
script_name = "main_dinov2_sliced.py"
|
|
||||||
show_slice_params = True
|
|
||||||
st.sidebar.caption("✅ 适用场景:4K/8K 大图,寻找细微违建。")
|
|
||||||
else:
|
|
||||||
script_name = "main_dinov2.py" # 对应你之前的 main_dinov2_giant.py (全图版)
|
|
||||||
show_slice_params = False # 全图模式不需要调切片
|
|
||||||
st.sidebar.caption("⚡ 适用场景:快速扫描,显存充足,寻找大面积变化。")
|
|
||||||
|
|
||||||
elif "DiffSim" in algo_type:
|
|
||||||
# === DiffSim 配置 ===
|
|
||||||
script_name = "main_finally.py"
|
|
||||||
show_slice_params = True # DiffSim 需要切片
|
|
||||||
|
|
||||||
st.sidebar.subheader("1. 感知权重")
|
|
||||||
w_struct = st.sidebar.slider("结构权重 (Struct)", 0.0, 1.0, 0.3, 0.05)
|
|
||||||
w_sem = st.sidebar.slider("语义权重 (Sem)", 0.0, 1.0, 0.7, 0.05)
|
|
||||||
w_tex = st.sidebar.slider("纹理权重 (Texture)", 0.0, 1.0, 0.0, 0.05)
|
|
||||||
|
|
||||||
st.sidebar.subheader("2. 信号处理")
|
|
||||||
kernel = st.sidebar.number_input("抗视差窗口 (Kernel)", value=5, step=2)
|
|
||||||
gamma = st.sidebar.slider("Gamma 压制", 0.5, 4.0, 1.0, 0.1)
|
|
||||||
thresh = st.sidebar.slider("可视化阈值", 0.0, 1.0, 0.15, 0.01)
|
|
||||||
|
|
||||||
# 封装 DiffSim 独有的参数
|
|
||||||
cmd_extra_args = [
|
|
||||||
"--model", "Manojb/stable-diffusion-2-1-base",
|
|
||||||
"--w_struct", str(w_struct),
|
|
||||||
"--w_sem", str(w_sem),
|
|
||||||
"--w_tex", str(w_tex),
|
|
||||||
"--gamma", str(gamma),
|
|
||||||
"--kernel", str(kernel)
|
|
||||||
]
|
|
||||||
|
|
||||||
else: # DreamSim
|
|
||||||
# === DreamSim 配置 ===
|
|
||||||
script_name = "main_dreamsim.py"
|
|
||||||
show_slice_params = True
|
|
||||||
|
|
||||||
st.sidebar.subheader("阈值控制")
|
|
||||||
thresh = st.sidebar.slider("可视化阈值 (Thresh)", 0.0, 1.0, 0.3, 0.01)
|
|
||||||
|
|
||||||
# ==========================================
|
|
||||||
# 3. 公共参数 (切片策略)
|
|
||||||
# ==========================================
|
|
||||||
if show_slice_params:
|
|
||||||
st.sidebar.subheader("🚀 扫描策略")
|
|
||||||
# DINOv2 切片版建议 Batch 小一点
|
|
||||||
default_batch = 8 if "DINOv2" in algo_type else 16
|
|
||||||
|
|
||||||
crop_size = st.sidebar.number_input("切片大小 (Crop)", value=224)
|
|
||||||
step_size = st.sidebar.number_input("步长 (Step, 0=自动)", value=0)
|
|
||||||
batch_size = st.sidebar.number_input("批次大小 (Batch)", value=default_batch)
|
|
||||||
else:
|
|
||||||
# 全图模式,隐藏参数但保留变量防报错
|
|
||||||
st.sidebar.success("全图模式:无需切片设置 (One-Shot)")
|
|
||||||
crop_size, step_size, batch_size = 224, 0, 1
|
|
||||||
|
|
||||||
# ==========================================
|
|
||||||
# 4. 主界面:图片上传与执行
|
|
||||||
# ==========================================
|
|
||||||
col1, col2 = st.columns(2)
|
|
||||||
with col1:
|
|
||||||
file_t1 = st.file_uploader("上传基准图 (Base / Old)", type=["jpg","png","jpeg"], key="t1")
|
|
||||||
if file_t1: st.image(file_t1, use_column_width=True)
|
|
||||||
with col2:
|
|
||||||
file_t2 = st.file_uploader("上传现状图 (Current / New)", type=["jpg","png","jpeg"], key="t2")
|
|
||||||
if file_t2: st.image(file_t2, use_column_width=True)
|
|
||||||
|
|
||||||
st.markdown("---")
|
|
||||||
|
|
||||||
# 启动按钮
|
|
||||||
if st.button("🚀 启动检测内核", type="primary", use_container_width=True):
|
|
||||||
if not file_t1 or not file_t2:
|
|
||||||
st.error("请先上传两张图片!")
|
|
||||||
else:
|
|
||||||
# 1. 保存临时文件
|
|
||||||
os.makedirs("temp_uploads", exist_ok=True)
|
|
||||||
t1_path = os.path.join("temp_uploads", "t1.jpg")
|
|
||||||
t2_path = os.path.join("temp_uploads", "t2.jpg")
|
|
||||||
|
|
||||||
# 结果文件名根据算法区分,防止缓存混淆
|
|
||||||
result_name = f"result_{script_name.replace('.py', '')}.jpg"
|
|
||||||
out_path = os.path.join("temp_uploads", result_name)
|
|
||||||
|
|
||||||
with open(t1_path, "wb") as f: f.write(file_t1.getbuffer())
|
|
||||||
with open(t2_path, "wb") as f: f.write(file_t2.getbuffer())
|
|
||||||
|
|
||||||
# 2. 构建命令
|
|
||||||
# 基础命令: python3 script.py t1 t2 out
|
|
||||||
cmd = ["python3", script_name, t1_path, t2_path, out_path]
|
|
||||||
|
|
||||||
# 添加通用参数 (所有脚本都兼容 -c -s -b --thresh 格式)
|
|
||||||
# 注意:即使全图模式脚本忽略 -c -s,传进去也不会报错,保持逻辑简单
|
|
||||||
cmd.extend([
|
|
||||||
"--crop", str(crop_size),
|
|
||||||
"--step", str(step_size),
|
|
||||||
"--batch", str(batch_size),
|
|
||||||
"--thresh", str(thresh)
|
|
||||||
])
|
|
||||||
|
|
||||||
# 添加特定算法参数 (DiffSim)
|
|
||||||
if cmd_extra_args:
|
|
||||||
cmd.extend(cmd_extra_args)
|
|
||||||
|
|
||||||
# 3. 显示状态与运行
|
|
||||||
st.info(f"⏳ 正在调用内核: `{script_name}` ...")
|
|
||||||
st.text(f"执行命令: {' '.join(cmd)}") # 方便调试
|
|
||||||
|
|
||||||
|
# --- 1. 读取图片 ---
|
||||||
try:
|
try:
|
||||||
# 实时显示进度条可能比较难,这里用 spinner
|
# 读取原始图片用于推理
|
||||||
with st.spinner('AI 正在进行特征提取与比对... (DINO Giant 可能需要几秒钟)'):
|
pil_image = Image.open(img_path).convert("RGB")
|
||||||
result = subprocess.run(cmd, capture_output=True, text=True)
|
# 读取 OpenCV 格式用于后续处理和最终输出背景
|
||||||
|
orig_cv2_image = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
|
||||||
# 4. 结果处理
|
|
||||||
# 展开日志供查看
|
|
||||||
with st.expander("📄 查看内核运行日志", expanded=(result.returncode != 0)):
|
|
||||||
if result.stdout: st.code(result.stdout, language="bash")
|
|
||||||
if result.stderr: st.error(result.stderr)
|
|
||||||
|
|
||||||
if result.returncode == 0:
|
|
||||||
st.success(f"✅ 检测完成!耗时逻辑已结束。")
|
|
||||||
|
|
||||||
# 结果展示区
|
|
||||||
r_col1, r_col2 = st.columns(2)
|
|
||||||
with r_col1:
|
|
||||||
# 尝试读取调试热力图 (如果有的话)
|
|
||||||
if os.path.exists("debug_raw_heatmap.png"):
|
|
||||||
st.image("debug_raw_heatmap.png", caption="🔍 原始差异热力图 (Debug)", use_column_width=True)
|
|
||||||
else:
|
|
||||||
st.warning("无调试热力图生成")
|
|
||||||
|
|
||||||
with r_col2:
|
|
||||||
if os.path.exists(out_path):
|
|
||||||
# 使用 PIL 打开以强制刷新缓存 (Streamlit 有时会缓存同名图片)
|
|
||||||
res_img = Image.open(out_path)
|
|
||||||
st.image(res_img, caption=f"🎯 最终检测结果 ({algo_type})", use_column_width=True)
|
|
||||||
else:
|
|
||||||
st.error(f"❌ 未找到输出文件: {out_path}")
|
|
||||||
else:
|
|
||||||
st.error("❌ 内核运行出错,请检查上方日志。")
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
st.error(f"系统错误: {e}")
|
print(f"[Error] 无法读取图片 {filename}: {e}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# --- 2. 推理 (逻辑不变) ---
|
||||||
|
try:
|
||||||
|
inference_state = processor.set_image(pil_image)
|
||||||
|
results = processor.set_text_prompt(
|
||||||
|
state=inference_state,
|
||||||
|
prompt=args.text
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[Error] 推理报错 ({filename}): {e}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# --- 3. 结果过滤 (逻辑不变) ---
|
||||||
|
masks = results.get("masks")
|
||||||
|
scores = results.get("scores")
|
||||||
|
boxes = results.get("boxes")
|
||||||
|
|
||||||
|
if masks is None:
|
||||||
|
print(f"[WARN] {filename} 未检测到任何目标。")
|
||||||
|
return
|
||||||
|
|
||||||
|
if isinstance(masks, torch.Tensor): masks = masks.cpu().numpy()
|
||||||
|
if isinstance(scores, torch.Tensor): scores = scores.cpu().numpy()
|
||||||
|
if isinstance(boxes, torch.Tensor): boxes = boxes.cpu().numpy()
|
||||||
|
|
||||||
|
keep_indices = np.where(scores > args.conf)[0]
|
||||||
|
if len(keep_indices) == 0:
|
||||||
|
print(f"[WARN] {filename} 没有目标超过阈值 {args.conf}。")
|
||||||
|
return
|
||||||
|
|
||||||
|
masks = masks[keep_indices]
|
||||||
|
scores = scores[keep_indices]
|
||||||
|
boxes = boxes[keep_indices]
|
||||||
|
print(f"[INFO] 发现 {len(masks)} 个目标 (Text: {args.text})")
|
||||||
|
|
||||||
|
# --- 4. 掩码处理与合并 (核心修改区域) ---
|
||||||
|
img_h, img_w = orig_cv2_image.shape[:2]
|
||||||
|
|
||||||
|
# 【核心修改点 1】创建一个全黑的单通道“总掩码”,用于累积所有检测对象的区域
|
||||||
|
total_mask = np.zeros((img_h, img_w), dtype=np.uint8)
|
||||||
|
|
||||||
|
# 循环处理每个检测到的目标 (保持原有的形态学处理逻辑不变)
|
||||||
|
for i in range(len(masks)):
|
||||||
|
m = masks[i]
|
||||||
|
box = boxes[i]
|
||||||
|
|
||||||
|
# --- A. 基础预处理 (保持不变) ---
|
||||||
|
if m.ndim > 2: m = m[0]
|
||||||
|
# Resize 到原图大小
|
||||||
|
if m.shape != (img_h, img_w):
|
||||||
|
m = cv2.resize(m.astype(np.uint8), (img_w, img_h), interpolation=cv2.INTER_NEAREST)
|
||||||
|
m = m.astype(np.uint8) # 确保是 uint8 (0, 1)
|
||||||
|
|
||||||
|
# --- B. 后处理流程 (保持不变) ---
|
||||||
|
# 1. 基础膨胀 (全局)
|
||||||
|
if args.dilate > 0:
|
||||||
|
k_size = args.dilate if args.dilate % 2 == 1 else args.dilate + 1
|
||||||
|
kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (k_size, k_size))
|
||||||
|
m = cv2.dilate(m, kernel, iterations=1)
|
||||||
|
|
||||||
|
# 2. 矩形区域填充 (Rect Fill) - 含 Padding 逻辑
|
||||||
|
if args.rect_fill:
|
||||||
|
x1, y1, x2, y2 = map(int, box)
|
||||||
|
# 应用检测框范围调整 (Padding)
|
||||||
|
pad = args.bbox_pad
|
||||||
|
x1, y1 = max(0, x1 - pad), max(0, y1 - pad)
|
||||||
|
x2, y2 = min(img_w, x2 + pad), min(img_h, y2 + pad)
|
||||||
|
|
||||||
|
if x2 > x1 and y2 > y1:
|
||||||
|
roi = m[y1:y2, x1:x2]
|
||||||
|
k_rect = args.rect_kernel
|
||||||
|
kernel_rect = cv2.getStructuringElement(cv2.MORPH_RECT, (k_rect, k_rect))
|
||||||
|
roi_closed = cv2.morphologyEx(roi, cv2.MORPH_CLOSE, kernel_rect)
|
||||||
|
m[y1:y2, x1:x2] = roi_closed
|
||||||
|
|
||||||
|
# 3. 孔洞填充 (Hole Filling)
|
||||||
|
if args.fill_holes:
|
||||||
|
contours, _ = cv2.findContours(m, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||||
|
m_filled = np.zeros_like(m)
|
||||||
|
cv2.drawContours(m_filled, contours, -1, 1, thickness=cv2.FILLED)
|
||||||
|
m = m_filled
|
||||||
|
|
||||||
|
# 【核心修改点 2】不再绘制彩色图层和边框,而是将处理好的单个掩码合并到总掩码中
|
||||||
|
# 使用按位或 (OR) 操作:只要任意一个目标的 Mask 在某个像素是白的,总 Mask 在这里就是白的
|
||||||
|
total_mask = cv2.bitwise_or(total_mask, m)
|
||||||
|
|
||||||
|
# --- 5. 生成最终图像 (Blackout 效果) ---
|
||||||
|
# 【核心修改点 3】利用总掩码,将原图中掩码为黑色(0)的区域涂黑,掩码为白色(255)的区域保留原像素
|
||||||
|
# 确保 total_mask 是 0 或 255
|
||||||
|
total_mask_255 = (total_mask * 255).astype(np.uint8)
|
||||||
|
final_image = cv2.bitwise_and(orig_cv2_image, orig_cv2_image, mask=total_mask_255)
|
||||||
|
|
||||||
|
# --- 6. 保存 ---
|
||||||
|
tags = []
|
||||||
|
# 添加一个标记表示是黑底抠图结果
|
||||||
|
tags.append("blackout")
|
||||||
|
if args.dilate > 0: tags.append(f"d{args.dilate}")
|
||||||
|
if args.rect_fill: tags.append(f"rect{args.rect_kernel}")
|
||||||
|
if args.bbox_pad != 0: tags.append(f"pad{args.bbox_pad}")
|
||||||
|
if args.fill_holes: tags.append("filled")
|
||||||
|
|
||||||
|
tag_str = "_" + "_".join(tags) if tags else ""
|
||||||
|
name_no_ext = os.path.splitext(filename)[0]
|
||||||
|
# 保存为 jpg 即可,因为背景是纯黑不是透明
|
||||||
|
save_filename = f"{name_no_ext}_{args.text.replace(' ', '_')}{tag_str}.jpg"
|
||||||
|
save_path = os.path.join(output_dir, save_filename)
|
||||||
|
|
||||||
|
cv2.imwrite(save_path, final_image)
|
||||||
|
print(f"[SUCCESS] 结果已保存至: {save_path}")
|
||||||
|
|
||||||
|
def main():
    """Parse CLI options, load the SAM 3 model and process every input image.

    ``--input`` may be a single file or a directory; directories are walked
    recursively and every file whose extension is in ``IMG_EXTENSIONS`` is
    queued. Each queued image is handed to ``process_one_image`` together
    with the parsed arguments and the output directory.
    """
    cli = argparse.ArgumentParser(description="SAM 3 自动抠图脚本 (黑底保留原色)")
    cli.add_argument("--input", type=str, required=True, help="输入路径")
    cli.add_argument("--output", type=str, default="/app/outputs", help="输出目录")
    cli.add_argument("--model", type=str, default="/app/checkpoints/sam3.pt", help="模型路径")
    cli.add_argument("--text", type=str, required=True, help="提示词")
    cli.add_argument("--conf", type=float, default=0.2, help="置信度")

    # Post-processing options.
    cli.add_argument("--dilate", type=int, default=0, help="全局基础膨胀 (px)")
    cli.add_argument("--fill-holes", action="store_true", help="开启孔洞填充")
    cli.add_argument("--rect-fill", action="store_true", help="开启矩形区域智能填充")
    cli.add_argument("--rect-kernel", type=int, default=20, help="矩形填充的核大小")
    cli.add_argument("--bbox-pad", type=int, default=0, help="检测框范围调整 (px)")

    args = cli.parse_args()

    os.makedirs(args.output, exist_ok=True)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Echo the effective configuration before doing any heavy work.
    mode_name = '矩形填充' if args.rect_fill else '普通'
    print("--- 运行配置 (黑底模式) ---")
    print(f"输入: {args.input}")
    print(f"模式: {mode_name}")
    if args.rect_fill:
        print(f"矩形核: {args.rect_kernel} px")
        # NOTE(review): nesting reconstructed from a whitespace-stripped
        # source; confirm this line belongs inside the rect-fill branch.
        print(f"检测框Padding: {args.bbox_pad} px")
    print("----------------")

    # Gather the images to process.
    if os.path.isdir(args.input):
        image_list = [
            os.path.join(folder, entry)
            for folder, _, entries in os.walk(args.input)
            for entry in entries
            if os.path.splitext(entry)[1].lower() in IMG_EXTENSIONS
        ]
    elif os.path.isfile(args.input):
        image_list = [args.input]
    else:
        image_list = []

    if not image_list:
        print("[Error] 未找到图片")
        return

    print(f"[INFO] 加载模型: {args.model} ...")
    try:
        sam3_model = build_sam3_image_model(
            checkpoint_path=args.model,
            load_from_HF=False,
            device=device,
        )
        processor = Sam3Processor(sam3_model)
    except Exception as e:
        # Top-level boundary: report the failure and stop without a traceback.
        print(f"[Fatal] 模型加载失败: {e}")
        return

    total = len(image_list)
    for position, img_path in enumerate(image_list, start=1):
        print(f"\n--- [{position}/{total}] ---")
        process_one_image(processor, img_path, args, args.output)

    print("\n[DONE] 完成。")


if __name__ == "__main__":
    main()
|
||||||
|
|
||||||
|
|||||||
@ -1,25 +1,8 @@
|
|||||||
services:
|
services:
|
||||||
diffsim-runner:
|
sam3:
|
||||||
build: .
|
build: .
|
||||||
container_name: diffsim_runner
|
container_name: sam3_fixed
|
||||||
# runtime: nvidia
|
# 启用 GPU
|
||||||
|
|
||||||
# 🔥 新增:端口映射
|
|
||||||
# 左边是宿主机端口,右边是容器端口 (Streamlit 默认 8501)
|
|
||||||
ports:
|
|
||||||
- "8601:8501"
|
|
||||||
|
|
||||||
# 挂载目录:代码、数据、模型缓存
|
|
||||||
volumes:
|
|
||||||
- .:/app # 当前目录代码映射到容器 /app
|
|
||||||
- ./data:/app/data # 图片数据映射
|
|
||||||
- ./hf_cache:/root/.cache/huggingface # 模型缓存映射
|
|
||||||
|
|
||||||
environment:
|
|
||||||
- NVIDIA_VISIBLE_DEVICES=all
|
|
||||||
# 🔥 新增:设置国内 HF 镜像,确保每次启动容器都能由镜像站加速
|
|
||||||
- HF_ENDPOINT=https://hf-mirror.com
|
|
||||||
|
|
||||||
deploy:
|
deploy:
|
||||||
resources:
|
resources:
|
||||||
reservations:
|
reservations:
|
||||||
@ -28,8 +11,23 @@ services:
|
|||||||
count: 1
|
count: 1
|
||||||
capabilities: [gpu]
|
capabilities: [gpu]
|
||||||
|
|
||||||
# 让容器启动后不退出,像一个虚拟机一样待命
|
# 共享内存优化
|
||||||
# 你可以随时 docker exec 进去,然后手动运行 streamlit run app.py
|
shm_size: '16gb'
|
||||||
command: tail -f /dev/null
|
|
||||||
|
|
||||||
|
volumes:
|
||||||
|
# 👇 这里是最关键的修正 👇
|
||||||
|
- ./checkpoints:/app/checkpoints # 确保容器能读到你刚刚救出来的 .pt 文件
|
||||||
|
- ./data:/app/data # 数据映射
|
||||||
|
- ./outputs:/app/outputs # 输出结果
|
||||||
|
|
||||||
|
environment:
|
||||||
|
- NVIDIA_VISIBLE_DEVICES=all
|
||||||
|
- NVIDIA_DRIVER_CAPABILITIES=compute,utility,video
|
||||||
|
|
||||||
|
ports:
|
||||||
|
- "7860:7860" # Gradio 端口
|
||||||
|
- "8888:8888" # Jupyter 端口 (保留着也没事)
|
||||||
|
|
||||||
|
stdin_open: true
|
||||||
|
tty: true
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
|
|||||||
278
main.py
278
main.py
@ -1,278 +0,0 @@
|
|||||||
import os
|
|
||||||
# 🔥 强制设置 HF 镜像 (必须放在最前面)
|
|
||||||
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
|
|
||||||
|
|
||||||
import sys
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
import cv2
|
|
||||||
import numpy as np
|
|
||||||
import argparse
|
|
||||||
from PIL import Image
|
|
||||||
from torchvision import transforms
|
|
||||||
from diffusers import StableDiffusionPipeline
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
# === 配置 ===
|
|
||||||
# 使用 SD 1.5,无需鉴权,且对小切片纹理更敏感
|
|
||||||
MODEL_ID = "runwayml/stable-diffusion-v1-5"
|
|
||||||
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
|
||||||
THRESHOLD = 0.35
|
|
||||||
IMG_RESIZE = 224
|
|
||||||
|
|
||||||
# ==========================================
|
|
||||||
# 🔥 核心:DiffSim Pro 模型定义 (修复版)
|
|
||||||
# ==========================================
|
|
||||||
class DiffSimPro(nn.Module):
    """Perceptual distance between image batches using frozen Stable Diffusion features.

    Images are encoded with the pipeline's VAE, pushed once through the UNet
    at timestep 0 with an empty-prompt text embedding, and the activations of
    three decoder ResNet blocks are captured with forward hooks. The captured
    features are compared per layer and combined into a weighted distance.
    """

    def __init__(self, device):
        """Load the pipeline onto *device*, freeze it and register the hooks.

        Args:
            device: Device string/object the pipeline is moved to
                (``"cuda"`` triggers the GPU banner, anything else the
                CPU warning).
        """
        super().__init__()
        print(f"🚀 [系统] 初始化 DiffSim Pro (基于 {MODEL_ID})...")

        if device == "cuda":
            print(f"✅ [硬件确认] 正在使用显卡: {torch.cuda.get_device_name(0)}")
        else:
            print("❌ [警告] 未检测到显卡,正在使用 CPU 慢速运行!")

        # 1. Load the Stable Diffusion pipeline in half precision.
        self.pipe = StableDiffusionPipeline.from_pretrained(MODEL_ID, torch_dtype=torch.float16).to(device)
        self.pipe.set_progress_bar_config(disable=True)

        # Freeze every sub-model; this class is inference-only.
        self.pipe.vae.requires_grad_(False)
        self.pipe.unet.requires_grad_(False)
        self.pipe.text_encoder.requires_grad_(False)

        # The UNet forward call requires ``encoder_hidden_states``, so the
        # embedding of the empty prompt is computed once and reused.
        with torch.no_grad():
            prompt = ""
            text_inputs = self.pipe.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.pipe.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids.to(device)
            # Per the original author's note this is [1, 77, 768] — TODO confirm.
            self.empty_text_embeds = self.pipe.text_encoder(text_input_ids)[0]

        # 2. Container filled by the forward hooks on every UNet pass.
        self.features = {}

        # Hook the third ResNet of up-blocks 1, 2 and 3. The original author
        # labels these texture (1), structure (2) and semantics (3).
        for name, layer in self.pipe.unet.named_modules():
            if "up_blocks.1" in name and name.endswith("resnets.2"):
                layer.register_forward_hook(self.get_hook("feat_high"))
            elif "up_blocks.2" in name and name.endswith("resnets.2"):
                layer.register_forward_hook(self.get_hook("feat_mid"))
            elif "up_blocks.3" in name and name.endswith("resnets.2"):
                layer.register_forward_hook(self.get_hook("feat_low"))

    def get_hook(self, name):
        """Return a forward hook that stores the layer output under *name*."""
        def hook(model, input, output):
            self.features[name] = output
        return hook

    def extract_features(self, images):
        """Run VAE encode -> UNet forward and return the hooked activations.

        Args:
            images: Image batch accepted by ``self.pipe.vae.encode``
                (presumably float16, normalised to [-1, 1] — verify against
                the caller).

        Returns:
            dict mapping hook name to a cloned activation tensor.
        """
        # 1. VAE encoding, scaled the way the UNet expects its latents.
        latents = self.pipe.vae.encode(images).latent_dist.sample() * self.pipe.vae.config.scaling_factor

        # 2. Timestep 0 for every sample. Built on the latents' own device so
        #    the instance works on whatever device it was constructed for,
        #    rather than depending on the module-level DEVICE constant.
        batch_size = latents.shape[0]
        t = torch.zeros(batch_size, device=latents.device, dtype=torch.long)

        # Broadcast the cached empty-prompt embedding to the batch size.
        encoder_hidden_states = self.empty_text_embeds.expand(batch_size, -1, -1)

        # Drop activations from any previous call so a hook that fails to
        # fire surfaces as a missing key instead of silently reusing stale data.
        self.features.clear()

        # 3. UNet forward pass; only the hook side effects are of interest.
        self.pipe.unet(latents, t, encoder_hidden_states=encoder_hidden_states)

        return {k: v.clone() for k, v in self.features.items()}

    def robust_similarity(self, f1, f2, kernel_size=3):
        """Shift-tolerant cosine similarity between two feature maps.

        Each spatial position of *f1* is compared with the
        ``kernel_size x kernel_size`` neighbourhood of the same position in
        *f2*, and the best match is kept.

        Args:
            f1: Tensor of shape (B, C, H, W).
            f2: Tensor of shape (B, C, H, W).
            kernel_size: Neighbourhood size; must be odd so the unfolded
                map keeps H x W positions.

        Returns:
            Tensor of shape (B, H, W) holding the maximum cosine similarity.
        """
        f1 = F.normalize(f1, dim=1)
        f2 = F.normalize(f2, dim=1)

        padding = kernel_size // 2
        b, c, h, w = f2.shape

        # (B, C*k*k, H*W) -> (B, C, k*k, H, W): every neighbour of every pixel.
        f2_unfolded = F.unfold(f2, kernel_size=kernel_size, padding=padding)
        f2_unfolded = f2_unfolded.view(b, c, kernel_size * kernel_size, h, w)

        # Dot product over channels gives one similarity per neighbour offset.
        sim_map = (f1.unsqueeze(2) * f2_unfolded).sum(dim=1)
        max_sim, _ = sim_map.max(dim=1)

        return max_sim

    def compute_batch_distance(self, batch_p1, batch_p2):
        """Return the weighted multi-layer distance for two patch batches.

        Args:
            batch_p1: First image batch (see ``extract_features``).
            batch_p2: Second image batch, same shape as *batch_p1*.

        Returns:
            Tensor of shape (B,) with one distance per pair.
        """
        feat_a = self.extract_features(batch_p1)
        feat_b = self.extract_features(batch_p2)

        total_score = 0
        # The mid ("structure") layer carries the largest weight.
        weights = {"feat_high": 0.2, "feat_mid": 0.5, "feat_low": 0.3}

        for name, w in weights.items():
            fa, fb = feat_a[name].float(), feat_b[name].float()

            if name == "feat_high":
                # Neighbourhood matching tolerates small spatial offsets.
                sim_map = self.robust_similarity(fa, fb, kernel_size=3)
                dist = 1 - sim_map.mean(dim=[1, 2])
            else:
                dist = 1 - F.cosine_similarity(fa.flatten(1), fb.flatten(1))

            total_score += dist * w

        return total_score
|
|
||||||
|
|
||||||
# ==========================================
|
|
||||||
# 🛠️ 辅助函数 & 扫描逻辑 (保持不变)
|
|
||||||
# ==========================================
|
|
||||||
|
|
||||||
def get_transforms():
    """Build the patch preprocessing pipeline.

    Patches are resized to ``IMG_RESIZE`` x ``IMG_RESIZE``, converted to a
    tensor and normalised with mean 0.5 / std 0.5.
    """
    steps = [
        transforms.Resize((IMG_RESIZE, IMG_RESIZE)),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
    return transforms.Compose(steps)
|
|
||||||
|
|
||||||
def scan_and_draw(model, t1_path, t2_path, output_path, patch_size, stride, batch_size):
    """Slide a window over two images, score each patch pair and draw a heatmap.

    The baseline image is resized to the current image's size, both are cut
    into ``patch_size`` squares, ``model.compute_batch_distance`` scores each
    pair, and overlapping scores are averaged into a per-pixel heatmap that
    is blended over the current image with boxes around high-score regions.

    Args:
        model: Object exposing ``compute_batch_distance(batch_a, batch_b)``.
        t1_path: Path of the baseline image.
        t2_path: Path of the current image (defines the output size).
        output_path: Output file; a relative path that does not start with
            ``"."`` is placed under ``/app/data``.
        patch_size: Side length of the square window, in pixels.
        stride: Step between windows, in pixels (must be > 0).
        batch_size: Number of patch pairs scored per model call.

    Returns:
        None. The result is written to disk and a summary is printed.
    """

    def _window_origins(length):
        # Start offsets of the windows along one axis. A final window flush
        # with the far edge is appended when the stride does not land there,
        # so the right/bottom margin is scored instead of staying at zero.
        origins = list(range(0, length - patch_size + 1, stride))
        last = length - patch_size
        if origins and origins[-1] != last:
            origins.append(last)
        return origins

    # 1. Read both images with OpenCV.
    img1_cv = cv2.imread(t1_path)
    img2_cv = cv2.imread(t2_path)

    if img1_cv is None or img2_cv is None:
        print("❌ 错误: 无法读取图片")
        return

    # Force the baseline to the current image's size so patches line up.
    h, w = img2_cv.shape[:2]
    img1_cv = cv2.resize(img1_cv, (w, h))

    preprocess = get_transforms()

    # 2. Build the sliding-window patches.
    print(f"🔪 [切片] 开始扫描... 尺寸: {w}x{h}")
    print(f" - 切片大小: {patch_size}, 步长: {stride}, 批次: {batch_size}")

    patches1 = []
    patches2 = []
    coords = []

    xs = _window_origins(w)
    for y in _window_origins(h):
        for x in xs:
            crop1 = img1_cv[y:y+patch_size, x:x+patch_size]
            crop2 = img2_cv[y:y+patch_size, x:x+patch_size]

            p1 = preprocess(Image.fromarray(cv2.cvtColor(crop1, cv2.COLOR_BGR2RGB)))
            p2 = preprocess(Image.fromarray(cv2.cvtColor(crop2, cv2.COLOR_BGR2RGB)))

            patches1.append(p1)
            patches2.append(p2)
            coords.append((x, y))

    if not patches1:
        print("⚠️ 图片太小,无法切片")
        return

    total_patches = len(patches1)
    print(f"🧠 [推理] 共 {total_patches} 个切片,开始 DiffSim Pro 计算...")

    all_distances = []

    # 3. Batched inference.
    for i in tqdm(range(0, total_patches, batch_size), unit="batch"):
        batch_p1 = torch.stack(patches1[i : i + batch_size]).to(DEVICE, dtype=torch.float16)
        batch_p2 = torch.stack(patches2[i : i + batch_size]).to(DEVICE, dtype=torch.float16)

        with torch.no_grad():
            dist_batch = model.compute_batch_distance(batch_p1, batch_p2)
        all_distances.append(dist_batch.cpu())

    distances = torch.cat(all_distances)

    # 4. Accumulate patch scores into a per-pixel heatmap.
    heatmap = np.zeros((h, w), dtype=np.float32)
    count_map = np.zeros((h, w), dtype=np.float32)
    # Floor at 0, matching the running maximum that started from 0.
    max_score = max(0.0, distances.max().item())

    for (x, y), score in zip(coords, distances):
        val = score.item()
        heatmap[y:y+patch_size, x:x+patch_size] += val
        count_map[y:y+patch_size, x:x+patch_size] += 1

    # Pixels no window touched keep a score of 0 without dividing by zero.
    count_map[count_map == 0] = 1
    heatmap_avg = heatmap / count_map

    # 5. Visualisation. The 0.1 floor keeps tiny differences from being
    #    stretched to full scale.
    norm_factor = max(max_score, 0.1)
    heatmap_vis = (heatmap_avg / norm_factor * 255).clip(0, 255).astype(np.uint8)
    heatmap_color = cv2.applyColorMap(heatmap_vis, cv2.COLORMAP_JET)

    alpha = 0.4
    beta = 1.0 - alpha
    blended_img = cv2.addWeighted(img2_cv, alpha, heatmap_color, beta, 0)

    # Threshold the heatmap and box every sufficiently large region.
    _, thresh = cv2.threshold(heatmap_vis, int(255 * THRESHOLD), 255, cv2.THRESH_BINARY)
    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    result_img = blended_img.copy()
    found_issue = False
    min_area = (patch_size * patch_size) * 0.05

    for cnt in contours:
        if cv2.contourArea(cnt) <= min_area:
            continue

        found_issue = True
        x, y, bw, bh = cv2.boundingRect(cnt)

        # White outer stroke plus red inner stroke for contrast.
        cv2.rectangle(result_img, (x, y), (x+bw, y+bh), (255, 255, 255), 4)
        cv2.rectangle(result_img, (x, y), (x+bw, y+bh), (0, 0, 255), 2)

        roi_score = heatmap_avg[y:y+bh, x:x+bw].mean()
        label = f"Diff: {roi_score:.2f}"

        cv2.rectangle(result_img, (x, y-25), (x+130, y), (0,0,255), -1)
        cv2.putText(result_img, label, (x+5, y-7), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255,255,255), 2)

    # Bare file names are placed under the mounted data directory.
    output_full_path = output_path
    if not os.path.isabs(output_path) and not output_path.startswith("."):
        output_full_path = os.path.join("/app/data", output_path)
    os.makedirs(os.path.dirname(output_full_path) if os.path.dirname(output_full_path) else ".", exist_ok=True)

    cv2.imwrite(output_full_path, result_img)

    print("="*40)
    print(f"🎯 扫描完成! 最大差异分: {max_score:.4f}")
    if found_issue:
        print(f"⚠️ 警告: 检测到潜在违建区域!")
    print(f"🖼️ 热力图结果已保存至: {output_full_path}")
    print("="*40)
|
|
||||||
|
|
||||||
if __name__ == "__main__":
    # Command-line entry point: two positional images, an optional output
    # name, and the sliding-window parameters.
    parser = argparse.ArgumentParser(description="DiffSim Pro 违建检测 (抗视差版)")
    parser.add_argument("t1", help="基准图路径")
    parser.add_argument("t2", help="现状图路径")
    parser.add_argument("out", nargs="?", default="heatmap_diffsim.jpg", help="输出文件名")
    parser.add_argument("-c", "--crop", type=int, default=224, help="切片大小")
    parser.add_argument("-s", "--step", type=int, default=0, help="滑动步长")
    parser.add_argument("-b", "--batch", type=int, default=16, help="批次大小")

    args = parser.parse_args()

    # A non-positive step means "use half the crop size" (50% overlap).
    if args.step > 0:
        stride = args.step
    else:
        stride = args.crop // 2

    # Build the model once, then scan.
    diffsim_model = DiffSimPro(DEVICE)

    print(f"📂 启动热力图扫描: {args.t1} vs {args.t2}")
    scan_and_draw(
        diffsim_model,
        args.t1,
        args.t2,
        args.out,
        args.crop,
        stride,
        args.batch,
    )
|
|
||||||
192
main_dinov2.py
192
main_dinov2.py
@ -1,192 +0,0 @@
|
|||||||
import sys
|
|
||||||
import os
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
import cv2
|
|
||||||
import numpy as np
|
|
||||||
import argparse
|
|
||||||
from PIL import Image
|
|
||||||
from torchvision import transforms
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
# === 配置 ===
|
|
||||||
# 使用 DINOv2 Giant 带寄存器版本 (修复背景伪影,最强版本)
|
|
||||||
# 既然你不在乎显存,我们直接上 1.1B 参数的模型
|
|
||||||
MODEL_NAME = 'dinov2_vitg14_reg'
|
|
||||||
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
|
||||||
|
|
||||||
def init_model():
    """Load the DINOv2 model named by ``MODEL_NAME`` and put it in eval mode.

    The torch.hub checkout under ``/root/.cache/torch/hub`` is used when it
    exists; otherwise the model is fetched from the online hub repository.

    Returns:
        The loaded model, moved to ``DEVICE``.
    """
    print(f"🚀 [系统] 初始化 DINOv2 ({MODEL_NAME})...")
    if DEVICE != "cuda":
        print("❌ [警告] 未检测到显卡,Giant 模型在 CPU 上会非常慢!")
    else:
        print(f"✅ [硬件确认] 正在使用显卡: {torch.cuda.get_device_name(0)}")
        print(f" (显存状态: {torch.cuda.memory_allocated()/1024**2:.2f}MB 已用)")

    # Prefer the local hub checkout so start-up does not need the network.
    local_path = '/root/.cache/torch/hub/facebookresearch_dinov2_main'

    print(f"📂 [系统] 正在从本地缓存加载代码: {local_path}")
    if not os.path.exists(local_path):
        # Fallback: online load (the original author expects this to fail
        # in most deployments).
        print("⚠️ 本地缓存未找到,尝试在线加载...")
        model = torch.hub.load('facebookresearch/dinov2', MODEL_NAME)
    else:
        model = torch.hub.load(local_path, MODEL_NAME, source='local')

    model.to(DEVICE)
    model.eval()

    return model
|
|
||||||
|
|
||||||
def preprocess_for_dino(img_cv):
    """Convert a BGR OpenCV image into a DINOv2-ready batch tensor.

    The image is resized so both sides are multiples of 14 (rounded down),
    converted to RGB, turned into a tensor and normalised with the ImageNet
    mean and standard deviation.

    Args:
        img_cv: BGR image array as returned by ``cv2.imread``.

    Returns:
        Tuple ``(tensor, new_h, new_w)`` where ``tensor`` has a leading
        batch dimension of 1 and lives on ``DEVICE``.
    """
    patch = 14
    src_h, src_w = img_cv.shape[:2]

    # Round each side down to the nearest multiple of the patch size.
    new_h = src_h - src_h % patch
    new_w = src_w - src_w % patch

    resized_bgr = cv2.resize(img_cv, (new_w, new_h))
    rgb_image = Image.fromarray(cv2.cvtColor(resized_bgr, cv2.COLOR_BGR2RGB))

    to_model_input = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    batch = to_model_input(rgb_image).unsqueeze(0).to(DEVICE)
    return batch, new_h, new_w
|
|
||||||
|
|
||||||
def scan_and_draw(model, t1_path, t2_path, output_path, threshold):
    """Compare two images patch-token by patch-token with DINOv2 and draw the result.

    Both images are fed whole through ``model.forward_features``; the cosine
    similarity of corresponding patch tokens becomes a coarse difference
    grid that is upsampled to the current image's size, blended over it and
    annotated with boxes around regions above *threshold*.

    Args:
        model: DINOv2 model exposing ``forward_features``.
        t1_path: Path of the baseline image.
        t2_path: Path of the current image (defines the output size).
        output_path: Where the annotated image is written.
        threshold: Fraction of the 0-255 heatmap range used for binarisation.

    Returns:
        None. Also writes ``debug_raw_heatmap.png`` to the working directory.
    """
    # 1. Load both images.
    baseline = cv2.imread(t1_path)
    img2_cv = cv2.imread(t2_path)
    if baseline is None or img2_cv is None:
        print("❌ 错误: 无法读取图片")
        return

    # The current image (T2) defines the working resolution.
    h_orig, w_orig = img2_cv.shape[:2]
    baseline = cv2.resize(baseline, (w_orig, h_orig))

    print(f"🔪 [处理] DINOv2 扫描... 原始尺寸: {w_orig}x{h_orig}")

    # 2. Whole-image preprocessing; sizes are snapped to multiples of 14.
    t1_tensor, h_align, w_align = preprocess_for_dino(baseline)
    t2_tensor, _, _ = preprocess_for_dino(img2_cv)

    print(f"🧠 [推理] Giant Model 计算中 (Patch网格: {h_align//14}x{w_align//14})...")

    with torch.no_grad():
        # Patch tokens; per the original author's note the Giant model
        # yields [1, N_patches, 1536] — TODO confirm.
        tokens_a = model.forward_features(t1_tensor)["x_norm_patchtokens"]
        tokens_b = model.forward_features(t2_tensor)["x_norm_patchtokens"]
        similarity = F.cosine_similarity(tokens_a, tokens_b, dim=-1)

    # 3. Back to a 2-D grid, then similarity -> difference.
    grid_h = h_align // 14
    grid_w = w_align // 14
    sim_map = similarity.reshape(grid_h, grid_w).cpu().numpy()
    heatmap_raw = 1.0 - sim_map

    # Upsample the coarse grid to the original resolution for overlaying.
    heatmap_avg = cv2.resize(heatmap_raw, (w_orig, h_orig), interpolation=cv2.INTER_CUBIC)

    min_v, max_v = heatmap_avg.min(), heatmap_avg.max()
    print(f"\n📊 [统计] 差异分布: Min={min_v:.4f} | Max={max_v:.4f} | Mean={heatmap_avg.mean():.4f}")

    # Debug artefact: min-max stretched greyscale heatmap.
    raw_norm = (heatmap_avg - min_v) / (max_v - min_v + 1e-6)
    cv2.imwrite("debug_raw_heatmap.png", (raw_norm * 255).astype(np.uint8))
    print(f"💾 [调试] 原始热力图已保存: debug_raw_heatmap.png")

    # 5. Visualisation. The 0.4 floor on the divisor stops noise from being
    #    amplified when the largest difference is small.
    norm_factor = max(max_v, 0.4)
    heatmap_vis = (heatmap_avg / norm_factor * 255).clip(0, 255).astype(np.uint8)
    heatmap_color = cv2.applyColorMap(heatmap_vis, cv2.COLORMAP_JET)

    alpha = 0.4
    blended_img = cv2.addWeighted(img2_cv, alpha, heatmap_color, 1.0 - alpha, 0)

    # Threshold and box.
    _, thresh_img = cv2.threshold(heatmap_vis, int(255 * threshold), 255, cv2.THRESH_BINARY)
    contours, _ = cv2.findContours(thresh_img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    result_img = blended_img.copy()
    box_count = 0

    # Regions smaller than 0.5% of the image area are ignored.
    min_area = (w_orig * h_orig) * 0.005

    for cnt in contours:
        if not cv2.contourArea(cnt) > min_area:
            continue

        box_count += 1
        x, y, bw, bh = cv2.boundingRect(cnt)
        top_left = (x, y)
        bottom_right = (x+bw, y+bh)

        # White thick frame under a thin red frame.
        cv2.rectangle(result_img, top_left, bottom_right, (255, 255, 255), 4)
        cv2.rectangle(result_img, top_left, bottom_right, (0, 0, 255), 2)

        # Mean difference inside the box becomes the label.
        region_score = heatmap_avg[y:y+bh, x:x+bw].mean()
        label = f"{region_score:.2f}"

        cv2.rectangle(result_img, (x, y-25), (x+80, y), (0,0,255), -1)
        cv2.putText(result_img, label, (x+5, y-7), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255,255,255), 2)

    cv2.imwrite(output_path, result_img)

    banner = "="*40
    print(banner)
    print(f"🎯 扫描完成! 发现区域: {box_count} 个")
    print(f"🖼️ 结果已保存至: {output_path}")
    print(banner)
|
|
||||||
|
|
||||||
if __name__ == "__main__":
    # Command-line entry point for the whole-image DINOv2 scan.
    parser = argparse.ArgumentParser(description="DINOv2 Giant 违建热力图检测 (结构敏感版)")
    parser.add_argument("t1", help="基准图")
    parser.add_argument("t2", help="现状图")
    parser.add_argument("out", nargs="?", default="dino_result.jpg", help="输出图片名")

    # crop/step/batch are accepted for command-line compatibility only;
    # their values are never read below.
    for short_flag, long_flag, default_value in (
        ("-c", "--crop", 224),
        ("-s", "--step", 0),
        ("-b", "--batch", 16),
    ):
        parser.add_argument(short_flag, long_flag, type=int, default=default_value, help="(已忽略) DINOv2 全图推理")

    # Detection threshold; the original author suggests roughly 0.25-0.35.
    parser.add_argument("--thresh", type=float, default=0.30, help="检测阈值 (0.0-1.0)")

    args = parser.parse_args()

    # Load the model, then run the scan.
    model = init_model()
    scan_and_draw(model, args.t1, args.t2, args.out, args.thresh)
|
|
||||||
@ -1,165 +0,0 @@
|
|||||||
import sys
|
|
||||||
import os
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
import cv2
|
|
||||||
import numpy as np
|
|
||||||
import argparse
|
|
||||||
from PIL import Image
|
|
||||||
from torchvision import transforms
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
# === 配置 ===
|
|
||||||
MODEL_NAME = 'dinov2_vitg14_reg'
|
|
||||||
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
|
||||||
|
|
||||||
def init_model():
    """Load the DINOv2 model named by ``MODEL_NAME`` and put it in eval mode.

    The local torch.hub checkout is used when present; otherwise the model
    is loaded from the online hub repository.

    Returns:
        The loaded model, moved to ``DEVICE``.
    """
    print(f"🚀 [系统] 初始化 DINOv2 ({MODEL_NAME})...")
    if DEVICE == "cuda":
        print(f"✅ [硬件] 使用设备: {torch.cuda.get_device_name(0)}")

    # Prefer the local hub checkout over a network fetch.
    local_path = '/root/.cache/torch/hub/facebookresearch_dinov2_main'
    if not os.path.exists(local_path):
        print("⚠️ 未找到本地缓存,尝试在线加载...")
        model = torch.hub.load('facebookresearch/dinov2', MODEL_NAME)
    else:
        print(f"📂 [加载] 命中本地缓存: {local_path}")
        model = torch.hub.load(local_path, MODEL_NAME, source='local')

    model.to(DEVICE)
    model.eval()
    return model
|
|
||||||
|
|
||||||
def get_transform():
    """Build the per-patch preprocessing pipeline.

    Patches are resized to 224 x 224, converted to a tensor and normalised
    with the ImageNet mean and standard deviation.
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    steps = [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
    ]
    return transforms.Compose(steps)
|
|
||||||
|
|
||||||
# === 修正:增加 threshold 参数 ===
|
|
||||||
def scan_and_draw(model, t1_path, t2_path, output_path, patch_size, stride, batch_size, threshold):
    """Slide a window over two images, score patch pairs with DINOv2 and draw a heatmap.

    Each patch pair is scored as ``1 - cosine_similarity`` of the model's
    ``x_norm_clstoken`` features. Overlapping scores are averaged into a
    per-pixel heatmap that is blended over the current image and annotated
    with boxes around regions above *threshold*.

    Args:
        model: DINOv2 model exposing ``forward_features``.
        t1_path: Path of the baseline image.
        t2_path: Path of the current image (defines the output size).
        output_path: Where the annotated image is written.
        patch_size: Side length of the square window, in pixels.
        stride: Step between windows, in pixels (must be > 0).
        batch_size: Number of patch pairs scored per model call.
        threshold: Fraction of the 0-255 heatmap range used for binarisation.

    Returns:
        None. The result is written to disk and a summary is printed.
    """

    def _window_origins(length):
        # Start offsets of the windows along one axis. A final window flush
        # with the far edge is appended when the stride does not land there,
        # so the right/bottom margin is scored instead of staying at zero.
        origins = list(range(0, length - patch_size + 1, stride))
        last = length - patch_size
        if origins and origins[-1] != last:
            origins.append(last)
        return origins

    img1_cv = cv2.imread(t1_path)
    img2_cv = cv2.imread(t2_path)

    if img1_cv is None or img2_cv is None:
        print("❌ 错误: 无法读取图片")
        return

    # Force the baseline to the current image's size so patches line up.
    h, w = img2_cv.shape[:2]
    img1_cv = cv2.resize(img1_cv, (w, h))

    print(f"🔪 [切片] DINOv2 扫描... 尺寸: {w}x{h}")
    print(f" - 参数: Crop={patch_size}, Step={stride}, Thresh={threshold}")

    # Cut both images into aligned PIL patches.
    patches1_pil = []
    patches2_pil = []
    coords = []

    xs = _window_origins(w)
    for y in _window_origins(h):
        for x in xs:
            crop1 = img1_cv[y:y+patch_size, x:x+patch_size]
            crop2 = img2_cv[y:y+patch_size, x:x+patch_size]

            patches1_pil.append(Image.fromarray(cv2.cvtColor(crop1, cv2.COLOR_BGR2RGB)))
            patches2_pil.append(Image.fromarray(cv2.cvtColor(crop2, cv2.COLOR_BGR2RGB)))
            coords.append((x, y))

    if not patches1_pil:
        print("⚠️ 图片太小,无法切片")
        return

    total_patches = len(patches1_pil)
    print(f"🧠 [推理] 共 {total_patches} 个切片...")

    all_distances = []
    transform = get_transform()

    # Patches are transformed one batch at a time, right before inference.
    for i in tqdm(range(0, total_patches, batch_size), unit="batch"):
        batch_p1 = torch.stack([transform(p) for p in patches1_pil[i : i + batch_size]]).to(DEVICE)
        batch_p2 = torch.stack([transform(p) for p in patches2_pil[i : i + batch_size]]).to(DEVICE)

        with torch.no_grad():
            feat1 = model.forward_features(batch_p1)["x_norm_clstoken"]
            feat2 = model.forward_features(batch_p2)["x_norm_clstoken"]
            sim_batch = F.cosine_similarity(feat1, feat2, dim=-1)
            dist_batch = 1.0 - sim_batch
        all_distances.append(dist_batch.cpu())

    distances = torch.cat(all_distances)

    # Rebuild the per-pixel heatmap by averaging overlapping patch scores.
    heatmap = np.zeros((h, w), dtype=np.float32)
    count_map = np.zeros((h, w), dtype=np.float32)
    max_score = distances.max().item()

    for (x, y), score in zip(coords, distances):
        val = score.item()
        heatmap[y:y+patch_size, x:x+patch_size] += val
        count_map[y:y+patch_size, x:x+patch_size] += 1

    # Pixels no window touched keep a score of 0 without dividing by zero.
    count_map[count_map == 0] = 1
    heatmap_avg = heatmap / count_map

    # Visualisation. The 0.4 floor keeps small differences from being
    # stretched to full scale.
    norm_denom = max(max_score, 0.4)
    heatmap_vis = (heatmap_avg / norm_denom * 255).clip(0, 255).astype(np.uint8)
    heatmap_color = cv2.applyColorMap(heatmap_vis, cv2.COLORMAP_JET)
    blended_img = cv2.addWeighted(img2_cv, 0.4, heatmap_color, 0.6, 0)

    # Binarise with the caller-supplied threshold and box large regions.
    _, thresh_img = cv2.threshold(heatmap_vis, int(255 * threshold), 255, cv2.THRESH_BINARY)
    contours, _ = cv2.findContours(thresh_img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    result_img = blended_img.copy()
    box_count = 0
    min_area = (patch_size * patch_size) * 0.03

    for cnt in contours:
        if cv2.contourArea(cnt) <= min_area:
            continue

        box_count += 1
        x, y, bw, bh = cv2.boundingRect(cnt)
        # White outer stroke plus red inner stroke for contrast.
        cv2.rectangle(result_img, (x, y), (x+bw, y+bh), (255, 255, 255), 4)
        cv2.rectangle(result_img, (x, y), (x+bw, y+bh), (0, 0, 255), 2)

        score_val = heatmap_avg[y:y+bh, x:x+bw].mean()
        cv2.rectangle(result_img, (x, y-25), (x+130, y), (0,0,255), -1)
        cv2.putText(result_img, f"Diff: {score_val:.2f}", (x+5, y-7),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255,255,255), 2)

    cv2.imwrite(output_path, result_img)
    print("="*40)
    print(f"🎯 检测完成! 最大差异: {max_score:.4f} | 发现区域: {box_count}")
    print(f"🖼️ 结果: {output_path}")
    print("="*40)
|
|
||||||
|
|
||||||
if __name__ == "__main__":
    # Command-line entry point for the sliced DINOv2 scan.
    parser = argparse.ArgumentParser(description="DINOv2 Giant 切片版")
    parser.add_argument("t1", help="基准图")
    parser.add_argument("t2", help="现状图")
    parser.add_argument("out", nargs="?", default="dino_sliced_result.jpg")
    parser.add_argument("-c", "--crop", type=int, default=224, help="切片大小")
    parser.add_argument("-s", "--step", type=int, default=0, help="步长")
    parser.add_argument("-b", "--batch", type=int, default=8, help="批次")

    # Detection threshold forwarded to scan_and_draw.
    parser.add_argument("--thresh", type=float, default=0.30, help="检测阈值")

    args = parser.parse_args()

    # A non-positive step means "use half the crop size" (50% overlap).
    if args.step > 0:
        stride = args.step
    else:
        stride = args.crop // 2

    model = init_model()
    scan_and_draw(
        model,
        args.t1,
        args.t2,
        args.out,
        args.crop,
        stride,
        args.batch,
        args.thresh,
    )
|
|
||||||
185
main_dreamsim.py
185
main_dreamsim.py
@ -1,185 +0,0 @@
|
|||||||
import sys
|
|
||||||
import os
|
|
||||||
import torch
|
|
||||||
import cv2
|
|
||||||
import numpy as np
|
|
||||||
import argparse
|
|
||||||
from dreamsim import dreamsim
|
|
||||||
from PIL import Image
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
# === 配置 ===
|
|
||||||
# DreamSim 官方推荐 ensemble 模式效果最好,虽然慢一点但更准
|
|
||||||
MODEL_TYPE = "ensemble"
|
|
||||||
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
|
||||||
|
|
||||||
def init_model():
    """Load the DreamSim model selected by ``MODEL_TYPE``.

    Returns:
        Tuple ``(model, preprocess)`` as returned by ``dreamsim``, with the
        model moved to ``DEVICE``.
    """
    print(f"🚀 [系统] 初始化 DreamSim ({MODEL_TYPE})...")
    if DEVICE != "cuda":
        print("❌ [警告] 未检测到显卡,正在使用 CPU 慢速运行!")
    else:
        print(f"✅ [硬件确认] 正在使用显卡: {torch.cuda.get_device_name(0)}")
        print(f" (显存状态: {torch.cuda.memory_allocated()/1024**2:.2f}MB 已用)")

    # Load the pretrained weights together with the matching preprocessing.
    loaded = dreamsim(pretrained=True, dreamsim_type=MODEL_TYPE, device=DEVICE)
    model, preprocess = loaded
    model.to(DEVICE)

    return model, preprocess
|
|
||||||
|
|
||||||
def scan_and_draw(model, preprocess, t1_path, t2_path, output_path, patch_size, stride, batch_size, threshold):
    """Compare two images patch-by-patch with DreamSim and draw the result.

    The reference image is resized to the size of the current image, both are
    cut into ``patch_size`` x ``patch_size`` windows every ``stride`` pixels,
    every window pair is scored by ``model``, and the scores are averaged over
    overlapping windows into a per-pixel heat map. Regions whose normalised
    heat exceeds ``threshold`` are boxed on a heat-map overlay of the current
    image.

    Side effects: writes ``debug_raw_heatmap.png`` into the current working
    directory and the annotated image to ``output_path``; prints progress.

    Args:
        model: Callable invoked as ``model(batch_a, batch_b)``; it is expected
            to return one distance per patch pair.
        preprocess: Callable turning a PIL image into a tensor (either 3-D or
            4-D with a leading axis of size 1).
        t1_path: Path of the reference image.
        t2_path: Path of the current image; its size defines the working size.
        output_path: Where the annotated result is written.
        patch_size: Side length of the square window, in pixels.
        stride: Step between windows, in pixels.
        batch_size: Number of patch pairs scored per forward pass.
        threshold: Fraction (0.0-1.0) of the normalised heat range above which
            a pixel counts as changed.

    Returns:
        None. Returns early (after printing a message) when an image cannot be
        read, when no window fits, or when the result cannot be written.
    """

    def _to_tensor(crop_bgr):
        # OpenCV arrays are BGR; the preprocess callable is fed an RGB PIL image.
        tensor = preprocess(Image.fromarray(cv2.cvtColor(crop_bgr, cv2.COLOR_BGR2RGB)))
        # preprocess may return [1, 3, H, W]; drop that axis so torch.stack
        # below builds the batch dimension itself.
        return tensor.squeeze(0) if tensor.ndim == 4 else tensor

    # 1. Read both images with OpenCV.
    img1_cv = cv2.imread(t1_path)
    img2_cv = cv2.imread(t2_path)
    if img1_cv is None or img2_cv is None:
        print("❌ 错误: 无法读取图片")
        return

    # Force both images onto the same grid, using the current image (T2) as
    # the reference size.
    h, w = img2_cv.shape[:2]
    img1_cv = cv2.resize(img1_cv, (w, h))

    print(f"🔪 [切片] DreamSim 扫描... 尺寸: {w}x{h}")
    print(f" - 参数: Crop={patch_size}, Step={stride}, Batch={batch_size}, Thresh={threshold}")

    # 2. Build the sliding-window patch lists.
    patches1 = []
    patches2 = []
    coords = []
    for y in range(0, h - patch_size + 1, stride):
        for x in range(0, w - patch_size + 1, stride):
            patches1.append(_to_tensor(img1_cv[y:y+patch_size, x:x+patch_size]))
            patches2.append(_to_tensor(img2_cv[y:y+patch_size, x:x+patch_size]))
            coords.append((x, y))

    if not patches1:
        print("⚠️ 图片太小,无法切片")
        return

    total_patches = len(patches1)
    print(f"🧠 [推理] 共 {total_patches} 个切片,开始计算...")

    # 3. Batched inference (tqdm shows progress per batch).
    all_distances = []
    for i in tqdm(range(0, total_patches, batch_size), unit="batch"):
        batch_p1 = torch.stack(patches1[i : i + batch_size]).to(DEVICE)
        batch_p2 = torch.stack(patches2[i : i + batch_size]).to(DEVICE)
        with torch.no_grad():
            dist_batch = model(batch_p1, batch_p2)
        all_distances.append(dist_batch.cpu())

    distances = torch.cat(all_distances)

    # 4. Accumulate patch scores into a per-pixel heat map.
    heatmap = np.zeros((h, w), dtype=np.float32)
    count_map = np.zeros((h, w), dtype=np.float32)

    min_v, max_v = distances.min().item(), distances.max().item()
    print(f"\n📊 [统计] 分数分布: Min={min_v:.4f} | Max={max_v:.4f} | Mean={distances.mean().item():.4f}")

    for (x, y), score in zip(coords, distances):
        heatmap[y:y+patch_size, x:x+patch_size] += score.item()
        count_map[y:y+patch_size, x:x+patch_size] += 1

    # Average overlapping windows; pixels no window touched keep a divisor of
    # 1 so they stay at zero instead of dividing by zero.
    count_map[count_map == 0] = 1
    heatmap_avg = heatmap / count_map

    # Save the min/max-normalised grayscale map for debugging by the front end.
    raw_norm = (heatmap_avg - min_v) / (max_v - min_v + 1e-6)
    cv2.imwrite("debug_raw_heatmap.png", (raw_norm * 255).astype(np.uint8))
    print(f"💾 [调试] 原始热力图已保存: debug_raw_heatmap.png")

    # 5. Visualisation. The divisor is floored at 0.1 so a scene with only
    #    tiny distances is not stretched to the full colour range.
    norm_factor = max(max_v, 0.1)
    heatmap_vis = (heatmap_avg / norm_factor * 255).clip(0, 255).astype(np.uint8)
    heatmap_color = cv2.applyColorMap(heatmap_vis, cv2.COLORMAP_JET)

    alpha = 0.4
    blended_img = cv2.addWeighted(img2_cv, alpha, heatmap_color, 1.0 - alpha, 0)

    # Threshold with the caller-supplied value and box the surviving regions.
    _, thresh_img = cv2.threshold(heatmap_vis, int(255 * threshold), 255, cv2.THRESH_BINARY)
    contours, _ = cv2.findContours(thresh_img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    result_img = blended_img.copy()
    box_count = 0
    # Regions smaller than 3% of one window's area are discarded.
    min_area = (patch_size * patch_size) * 0.03
    label_h = 25

    for cnt in contours:
        if cv2.contourArea(cnt) <= min_area:
            continue
        box_count += 1
        x, y, bw, bh = cv2.boundingRect(cnt)

        # White outer stroke plus red inner stroke keeps the box visible on
        # any background colour.
        cv2.rectangle(result_img, (x, y), (x+bw, y+bh), (255, 255, 255), 4)
        cv2.rectangle(result_img, (x, y), (x+bw, y+bh), (0, 0, 255), 2)

        # Score label sits above the box; clamp it so boxes touching the top
        # edge still show their label inside the image.
        label = f"{heatmap_avg[y:y+bh, x:x+bw].mean():.2f}"
        label_top = max(y - label_h, 0)
        cv2.rectangle(result_img, (x, label_top), (x+80, label_top + label_h), (0,0,255), -1)
        cv2.putText(result_img, label, (x+5, label_top + label_h - 7), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255,255,255), 2)

    # cv2.imwrite reports failure through its return value; do not claim
    # success when nothing was written.
    if not cv2.imwrite(output_path, result_img):
        print(f"❌ 错误: 无法保存结果: {output_path}")
        return

    print("="*40)
    print(f"🎯 扫描完成! 发现区域: {box_count} 个")
    print(f"🖼️ 结果已保存至: {output_path}")
    print("="*40)
|
|
||||||
|
|
||||||
if __name__ == "__main__":
    # Command-line entry point: parse arguments, load DreamSim, run the scan.
    parser = argparse.ArgumentParser(description="DreamSim 违建热力图检测 (标准化版)")

    # Positional inputs: reference image, current image, optional output name.
    for arg_name, arg_help in (("t1", "基准图"), ("t2", "现状图")):
        parser.add_argument(arg_name, help=arg_help)
    parser.add_argument("out", nargs="?", default="heatmap_result.jpg", help="输出图片名")

    # Sliding-window options (all integers).
    for short_flag, long_flag, default_value, arg_help in (
        ("-c", "--crop", 224, "切片大小"),
        ("-s", "--step", 0, "步长"),
        ("-b", "--batch", 16, "批次"),
    ):
        parser.add_argument(short_flag, long_flag, type=int, default=default_value, help=arg_help)

    # Detection threshold.
    parser.add_argument("--thresh", type=float, default=0.30, help="检测阈值 (0.0-1.0)")

    args = parser.parse_args()

    # A non-positive step means "auto": use half the window size (50% overlap).
    if args.step > 0:
        stride = args.step
    else:
        stride = args.crop // 2

    # Load the model, then scan.
    model, preprocess = init_model()
    scan_and_draw(
        model,
        preprocess,
        args.t1,
        args.t2,
        args.out,
        args.crop,
        stride,
        args.batch,
        args.thresh,
    )
|
|
||||||
293
main_finally.py
293
main_finally.py
@ -1,293 +0,0 @@
|
|||||||
import os
|
|
||||||
# 🚀 强制使用国内镜像
|
|
||||||
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
import cv2
|
|
||||||
import numpy as np
|
|
||||||
import argparse
|
|
||||||
from PIL import Image
|
|
||||||
from torchvision import transforms
|
|
||||||
from diffusers import StableDiffusionPipeline
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
# =========================================================================
|
|
||||||
# PART 1: DiffSim 官方核心逻辑还原
|
|
||||||
# 基于: https://github.com/showlab/DiffSim/blob/main/diffsim/models/diffsim.py
|
|
||||||
# =========================================================================
|
|
||||||
|
|
||||||
class DiffSim(nn.Module):
    """Perceptual distance built on intermediate Stable Diffusion UNet features.

    Images are encoded by the pipeline's VAE, passed once through the UNet at
    timestep 0 with an empty-prompt embedding, and the outputs of two decoder
    resnet layers are captured with forward hooks. The distance between two
    images is a weighted sum of ``1 - similarity`` over those feature maps.
    """

    def __init__(self, model_id="Manojb/stable-diffusion-2-1-base", device="cuda"):
        """Load the pipeline in float16, freeze it, and register feature hooks.

        Args:
            model_id: Hugging Face model id passed to ``from_pretrained``.
            device: Device the pipeline is moved to.
        """
        super().__init__()
        self.device = device
        print(f"🚀 [Core] Loading Official DiffSim Logic (Backbone: {model_id})...")

        # 1. Load the SD pipeline. The broad except is a deliberate best-effort
        #    fallback to the default checkpoint id; a failure there propagates.
        try:
            self.pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to(device)
        except Exception as e:
            print(f"❌ 模型加载失败,尝试加载默认 ID... Error: {e}")
            self.pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base", torch_dtype=torch.float16).to(device)

        self.pipe.set_progress_bar_config(disable=True)

        # 2. Freeze every sub-model; this class is inference-only.
        self.pipe.vae.requires_grad_(False)
        self.pipe.unet.requires_grad_(False)
        self.pipe.text_encoder.requires_grad_(False)

        # 3. Pre-compute the empty-prompt embedding used for every UNet pass.
        #    The padding length comes from the tokenizer rather than a
        #    hard-coded constant so other tokenizers keep working.
        with torch.no_grad():
            prompt = ""
            text_input = self.pipe.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.pipe.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            self.empty_embeds = self.pipe.text_encoder(text_input.input_ids.to(device))[0]

        self.features = {}
        self._register_official_hooks()

    def _register_official_hooks(self):
        """Attach forward hooks to the two UNet layers used as features.

        ``up_blocks.1.resnets.1`` is stored as ``feat_semantic`` and
        ``up_blocks.2.resnets.1`` as ``feat_structure``. Target layers that do
        not exist in the loaded UNet are reported, because a missing hook
        would otherwise silently remove that term from the distance.
        """
        self.target_layers = {
            "up_blocks.1.resnets.1": "feat_semantic",   # semantic layer
            "up_blocks.2.resnets.1": "feat_structure",  # structure layer
        }

        print(f"🔧 [Hook] Registered Layers: {list(self.target_layers.values())}")

        hooked = set()
        for name, layer in self.pipe.unet.named_modules():
            alias = self.target_layers.get(name)
            if alias is not None:
                layer.register_forward_hook(self._get_hook(alias))
                hooked.add(name)

        missing = sorted(set(self.target_layers) - hooked)
        if missing:
            print(f"⚠️ [Hook] Layers not found in UNet: {missing}")

    def _get_hook(self, name):
        """Return a forward hook that stores the layer output under ``name``."""
        def hook(model, input, output):
            self.features[name] = output
        return hook

    def extract_features(self, images):
        """Run one VAE + UNet pass and return the hooked feature maps.

        Args:
            images: Batch of image tensors accepted by the pipeline's VAE
                encoder (the caller in this file passes float16 tensors).

        Returns:
            dict: alias -> cloned feature tensor for every hook that fired.
        """
        # VAE encoding. NOTE(review): ``sample()`` draws from the latent
        # distribution, so repeated calls on the same image may differ slightly.
        latents = self.pipe.vae.encode(images).latent_dist.sample() * self.pipe.vae.config.scaling_factor

        # UNet inference at timestep 0 with the empty-prompt conditioning.
        t = torch.zeros(latents.shape[0], device=self.device, dtype=torch.long)
        encoder_hidden_states = self.empty_embeds.expand(latents.shape[0], -1, -1)

        self.features = {}  # reset so stale features from a previous pass never leak
        self.pipe.unet(latents, t, encoder_hidden_states=encoder_hidden_states)

        return {k: v.clone() for k, v in self.features.items()}

    def calculate_robust_similarity(self, feat_a, feat_b, kernel_size=3):
        """Spatially robust cosine similarity between two feature maps.

        For every position ``p`` the result is
        ``max_{q in Neighbor(p)} cos(feat_a(p), feat_b(q))`` where the
        neighbourhood is a ``kernel_size`` x ``kernel_size`` window (zero
        padded at the border).

        Args:
            feat_a: Tensor of shape ``[B, C, H, W]``.
            feat_b: Tensor of the same shape as ``feat_a``.
            kernel_size: Window side. ``<= 1`` means strict per-position
                matching; larger values must be odd.

        Returns:
            Tensor of shape ``[B, H, W]`` with the best similarity per position.

        Raises:
            ValueError: If ``kernel_size`` is even and greater than 1. An even
                window with ``kernel_size // 2`` padding yields a different
                spatial size, so the neighbourhood could not be aligned with
                ``feat_a``.
        """
        # Unit-normalise along channels so dot products are cosine similarities.
        feat_a = F.normalize(feat_a, dim=1)
        feat_b = F.normalize(feat_b, dim=1)

        if kernel_size <= 1:
            # Strict alignment: position-wise cosine similarity.
            return (feat_a * feat_b).sum(dim=1)

        if kernel_size % 2 == 0:
            raise ValueError(f"kernel_size must be odd (got {kernel_size})")

        # Neighbourhood search: gather every position's window from feat_b.
        b, c, h, w = feat_b.shape
        padding = kernel_size // 2
        feat_b_unfolded = F.unfold(feat_b, kernel_size=kernel_size, padding=padding)
        feat_b_unfolded = feat_b_unfolded.view(b, c, kernel_size*kernel_size, h, w)

        # Cosine similarity between A and each neighbour of B: [B, K*K, H, W].
        sim_map = (feat_a.unsqueeze(2) * feat_b_unfolded).sum(dim=1)

        # Keep the best match in each window.
        best_sim, _ = sim_map.max(dim=1)
        return best_sim

    def forward(self, batch_t1, batch_t2, w_struct, w_sem, kernel_size):
        """Return one weighted feature distance per image pair.

        Args:
            batch_t1: Batch of reference images.
            batch_t2: Batch of query images (same batch size).
            w_struct: Weight of the ``feat_structure`` term; ``<= 0`` disables it.
            w_sem: Weight of the ``feat_semantic`` term; ``<= 0`` disables it.
            kernel_size: Window passed to ``calculate_robust_similarity``.

        Returns:
            Tensor of shape ``[B]``. When no term contributes (both weights
            non-positive or no feature captured) a zero tensor is returned
            instead of the Python int ``0``, so callers can always use tensor
            methods on the result.
        """
        f1 = self.extract_features(batch_t1)
        f2 = self.extract_features(batch_t2)

        total_dist = 0
        contributed = False

        # Semantic term first, then structure, matching the accumulation order
        # of the weighted sum.
        for key, weight in (("feat_semantic", w_sem), ("feat_structure", w_struct)):
            if weight > 0 and key in f1 and key in f2:
                sim = self.calculate_robust_similarity(f1[key], f2[key], kernel_size)
                dist = 1.0 - sim
                total_dist += dist.mean(dim=[1, 2]) * weight
                contributed = True

        if not contributed:
            return torch.zeros(batch_t1.shape[0], device=batch_t1.device)
        return total_dist
|
|
||||||
|
|
||||||
# =========================================================================
|
|
||||||
# PART 2: 增强后处理逻辑 (Post-Processing)
|
|
||||||
# 这一部分不在 DiffSim 官方库中,是为了实际工程落地增加的去噪模块
|
|
||||||
# =========================================================================
|
|
||||||
|
|
||||||
def engineering_post_process(heatmap_full, img_bg, args, patch_size):
    """Turn an averaged distance map into an annotated overlay image.

    Steps: floor-limited normalisation to 8 bit, Gaussian blur, hard
    threshold, morphological closing, then contour extraction and drawing.

    Side effect: writes the normalised map to ``debug_raw_heatmap.png`` in the
    current working directory.

    Args:
        heatmap_full: 2-D float array of per-pixel distances.
        img_bg: Background image the heat map is blended onto (presumably BGR
            and the same height/width as ``heatmap_full`` - confirm at caller).
        args: Object exposing ``thresh``, the binarisation fraction (0.0-1.0).
        patch_size: Window side in pixels; regions smaller than 3% of one
            window's area are ignored.

    Returns:
        tuple: ``(result_img, box_count)`` - the annotated image and the
        number of regions drawn.
    """
    # 1. Dynamic-range normalisation. The divisor is floored at 0.25 so a
    #    nearly unchanged scene is not stretched to full range, which would
    #    amplify noise.
    safe_max = max(heatmap_full.max(), 0.25)
    heatmap_norm = (heatmap_full / safe_max * 255).clip(0, 255).astype(np.uint8)

    # Keep the raw normalised map for debugging.
    cv2.imwrite("debug_raw_heatmap.png", heatmap_norm)

    # 2. Gaussian blur to remove speckle.
    heatmap_blur = cv2.GaussianBlur(heatmap_norm, (5, 5), 0)

    # 3. Hard threshold.
    _, binary = cv2.threshold(heatmap_blur, int(255 * args.thresh), 255, cv2.THRESH_BINARY)

    # 4. Morphological closing merges fragmented neighbouring regions.
    kernel_morph = cv2.getStructuringElement(cv2.MORPH_RECT, (7, 7))
    binary_closed = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, kernel_morph)

    # 5. Overlay: the colour map uses the un-blurred normalised map.
    heatmap_color = cv2.applyColorMap(heatmap_norm, cv2.COLORMAP_JET)
    result_img = cv2.addWeighted(img_bg, 0.4, heatmap_color, 0.6, 0)

    contours, _ = cv2.findContours(binary_closed, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    box_count = 0
    min_area = (patch_size ** 2) * 0.03
    label_h = 22

    for cnt in contours:
        if cv2.contourArea(cnt) <= min_area:
            continue
        box_count += 1
        x, y, bw, bh = cv2.boundingRect(cnt)

        # White outer stroke plus red inner stroke.
        cv2.rectangle(result_img, (x, y), (x+bw, y+bh), (255, 255, 255), 4)
        cv2.rectangle(result_img, (x, y), (x+bw, y+bh), (0, 0, 255), 2)

        # Score label: mean raw distance inside the box.
        score_val = heatmap_full[y:y+bh, x:x+bw].mean()
        label = f"{score_val:.2f}"

        # The label sits above the box; clamp it so a box touching the top
        # edge still shows its label inside the image.
        label_top = max(y - label_h, 0)
        cv2.rectangle(result_img, (x, label_top), (x+55, label_top + label_h), (0,0,255), -1)
        cv2.putText(result_img, label, (x+5, label_top + label_h - 6), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255,255,255), 2)

    return result_img, box_count
|
|
||||||
|
|
||||||
# =========================================================================
|
|
||||||
# PART 3: 执行脚本
|
|
||||||
# =========================================================================
|
|
||||||
|
|
||||||
def main():
    """Command-line entry point: slice two images, score them, draw the result.

    Reads the reference and query images, resizes the reference to the query
    size, cuts both into overlapping windows, scores every window pair with
    ``DiffSim``, rebuilds a per-pixel heat map and writes the annotated image.
    """
    parser = argparse.ArgumentParser(description="DiffSim Local Implementation")
    parser.add_argument("t1", help="Reference Image")
    parser.add_argument("t2", help="Query Image")
    # nargs="?" makes the positional optional; without it argparse requires
    # the argument and the default can never take effect.
    parser.add_argument("out", nargs="?", default="result.jpg")

    # DiffSim weighting parameters.
    parser.add_argument("--w_struct", type=float, default=0.4)
    parser.add_argument("--w_sem", type=float, default=0.6)
    parser.add_argument("--kernel", type=int, default=3, help="Robust Kernel Size (1, 3, 5)")

    # Engineering parameters.
    parser.add_argument("--gamma", type=float, default=1.0)
    parser.add_argument("--thresh", type=float, default=0.3)
    parser.add_argument("-c", "--crop", type=int, default=224)
    parser.add_argument("-b", "--batch", type=int, default=16)
    parser.add_argument("--model", default="Manojb/stable-diffusion-2-1-base")

    # Accepted only for command-line compatibility; neither value is read below.
    parser.add_argument("--step", type=int, default=0)
    parser.add_argument("--w_tex", type=float, default=0.0)

    args = parser.parse_args()

    # 1. Image IO.
    t1 = cv2.imread(args.t1)
    t2 = cv2.imread(args.t2)
    if t1 is None or t2 is None:
        print("❌ Error reading images.")
        return

    # Resize the reference so both images share the query image's grid.
    h, w = t2.shape[:2]
    t1 = cv2.resize(t1, (w, h))

    # 2. Preprocessing: every crop is resized to 224x224 and mapped to [-1, 1].
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])

    def _to_tensor(crop_bgr):
        # OpenCV arrays are BGR; the transform is fed an RGB PIL image.
        return transform(Image.fromarray(cv2.cvtColor(crop_bgr, cv2.COLOR_BGR2RGB)))

    patches1, patches2, coords = [], [], []
    stride = args.crop // 2  # 50% overlap

    print(f"🔪 Slicing images ({w}x{h}) with stride {stride}...")
    for y in range(0, h - args.crop + 1, stride):
        for x in range(0, w - args.crop + 1, stride):
            patches1.append(_to_tensor(t1[y:y+args.crop, x:x+args.crop]))
            patches2.append(_to_tensor(t2[y:y+args.crop, x:x+args.crop]))
            coords.append((x, y))

    if not patches1:
        # Say why nothing was produced instead of exiting silently.
        print(f"⚠️ Image ({w}x{h}) is smaller than the crop size {args.crop}; nothing to scan.")
        return

    # 3. Model inference.
    model = DiffSim(args.model)
    scores = []

    print(f"🧠 Running DiffSim Inference on {len(patches1)} patches...")
    with torch.no_grad():
        for i in tqdm(range(0, len(patches1), args.batch)):
            # Follow the model's device rather than hard-coding it.
            b1 = torch.stack(patches1[i:i+args.batch]).to(model.device, dtype=torch.float16)
            b2 = torch.stack(patches2[i:i+args.batch]).to(model.device, dtype=torch.float16)

            batch_dist = model(b1, b2, args.w_struct, args.w_sem, args.kernel)
            scores.append(batch_dist.cpu())

    all_scores = torch.cat(scores).float().numpy()

    # 4. Reconstruct the heat map. The tiny initial count avoids division by
    #    zero for pixels that no window covers.
    heatmap_full = np.zeros((h, w), dtype=np.float32)
    count_map = np.zeros((h, w), dtype=np.float32) + 1e-6

    # Apply gamma before merging. Half-precision rounding can push a distance
    # slightly below zero, and a fractional power of a negative number is NaN,
    # so clamp at zero first.
    if args.gamma != 1.0:
        all_scores = np.power(np.clip(all_scores, 0.0, None), args.gamma)

    for (x, y), score in zip(coords, all_scores):
        heatmap_full[y:y+args.crop, x:x+args.crop] += score
        count_map[y:y+args.crop, x:x+args.crop] += 1

    heatmap_avg = heatmap_full / count_map

    # 5. Post-processing and drawing.
    print("🎨 Post-processing results...")
    final_img, count = engineering_post_process(heatmap_avg, t2, args, args.crop)

    # cv2.imwrite signals failure through its return value.
    if not cv2.imwrite(args.out, final_img):
        print(f"❌ Failed to write {args.out}")
        return
    print(f"✅ Done! Found {count} regions. Saved to {args.out}")


if __name__ == "__main__":
    main()
|
|
||||||
204
main_plus.py
204
main_plus.py
@ -1,204 +0,0 @@
|
|||||||
import os
|
|
||||||
# 🔥 强制设置 HF 镜像
|
|
||||||
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
|
|
||||||
|
|
||||||
import sys
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
import cv2
|
|
||||||
import numpy as np
|
|
||||||
import argparse
|
|
||||||
from PIL import Image
|
|
||||||
from torchvision import transforms
|
|
||||||
from diffusers import StableDiffusionPipeline
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
# === Configuration ===
MODEL_ID = "Manojb/stable-diffusion-2-1-base"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
THRESHOLD = 0.40  # raised slightly to filter out more false positives
IMG_RESIZE = 224


class DiffSimSemantic(nn.Module):
    """DiffSim variant that scores only structure- and semantics-level features.

    Hooks ``up_blocks.2.resnets.2`` (stored as ``feat_structure``) and
    ``up_blocks.3.resnets.2`` (stored as ``feat_semantic``) of the Stable
    Diffusion UNet; ``up_blocks.1`` is intentionally not hooked because the
    authors found it too sensitive to lighting and texture.
    """

    def __init__(self, device):
        """Load the pipeline in float16 on ``device`` and register the hooks."""
        super().__init__()
        print(f"🚀 [系统] 初始化 DiffSim (语义增强版)...")

        # Remember the device so later passes do not depend on the module
        # global, which may differ from what the caller asked for.
        self.device = device

        self.pipe = StableDiffusionPipeline.from_pretrained(MODEL_ID, torch_dtype=torch.float16).to(device)
        self.pipe.set_progress_bar_config(disable=True)

        # Freeze every sub-model; this class is inference-only.
        self.pipe.vae.requires_grad_(False)
        self.pipe.unet.requires_grad_(False)
        self.pipe.text_encoder.requires_grad_(False)

        # Pre-compute the empty-prompt embedding used for every UNet pass.
        with torch.no_grad():
            prompt = ""
            text_inputs = self.pipe.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.pipe.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids.to(device)
            self.empty_text_embeds = self.pipe.text_encoder(text_input_ids)[0]

        self.features = {}

        # Hook only the deeper decoder layers.
        for name, layer in self.pipe.unet.named_modules():
            if "up_blocks.2" in name and name.endswith("resnets.2"):
                layer.register_forward_hook(self.get_hook("feat_structure"))
            elif "up_blocks.3" in name and name.endswith("resnets.2"):
                layer.register_forward_hook(self.get_hook("feat_semantic"))

    def get_hook(self, name):
        """Return a forward hook that stores the layer output under ``name``."""
        def hook(model, input, output):
            self.features[name] = output
        return hook

    def extract_features(self, images):
        """Run one VAE + UNet pass and return clones of the hooked features.

        Args:
            images: Batch of image tensors accepted by the pipeline's VAE
                encoder (the caller in this file passes float16 tensors).

        Returns:
            dict: alias -> cloned feature tensor for every hook that fired.
        """
        latents = self.pipe.vae.encode(images).latent_dist.sample() * self.pipe.vae.config.scaling_factor
        batch_size = latents.shape[0]
        # Fall back to the module default for instances built without __init__.
        device = getattr(self, "device", DEVICE)
        t = torch.zeros(batch_size, device=device, dtype=torch.long)
        encoder_hidden_states = self.empty_text_embeds.expand(batch_size, -1, -1)

        # Reset first so a hook that does not fire cannot leave the previous
        # batch's features in the result.
        self.features = {}
        self.pipe.unet(latents, t, encoder_hidden_states=encoder_hidden_states)
        return {k: v.clone() for k, v in self.features.items()}

    def robust_similarity(self, f1, f2, kernel_size=5):
        """Parallax-tolerant cosine similarity between two feature maps.

        Each position of ``f1`` is compared with every position of ``f2``
        inside a ``kernel_size`` x ``kernel_size`` window (zero padded at the
        border) and the best match is kept. The default of 5 tolerates larger
        geometric misalignment than a 3x3 window.

        Args:
            f1: Tensor of shape ``[B, C, H, W]``.
            f2: Tensor of the same shape as ``f1``.
            kernel_size: Odd window side.

        Returns:
            Tensor of shape ``[B, H, W]``.

        Raises:
            ValueError: If ``kernel_size`` is not a positive odd number; an
                even window with ``kernel_size // 2`` padding changes the
                spatial size and cannot be aligned with ``f1``.
        """
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(f"kernel_size must be a positive odd number (got {kernel_size})")

        f1 = F.normalize(f1, dim=1)
        f2 = F.normalize(f2, dim=1)

        padding = kernel_size // 2
        b, c, h, w = f2.shape

        # Gather each position's neighbourhood: [B, C, K*K, H, W].
        f2_unfolded = F.unfold(f2, kernel_size=kernel_size, padding=padding)
        f2_unfolded = f2_unfolded.view(b, c, kernel_size*kernel_size, h, w)

        sim_map = (f1.unsqueeze(2) * f2_unfolded).sum(dim=1)
        max_sim, _ = sim_map.max(dim=1)
        return max_sim

    def compute_batch_distance(self, batch_p1, batch_p2):
        """Return one weighted distance per patch pair.

        The distance is ``0.4 * structure + 0.6 * semantic``, where each term
        is ``1 - mean(robust_similarity)`` computed in float32 with a 5x5
        window. No texture-level term is included.

        Args:
            batch_p1: Batch of reference patches.
            batch_p2: Batch of query patches (same batch size).

        Returns:
            Tensor of shape ``[B]``.
        """
        feat_a = self.extract_features(batch_p1)
        feat_b = self.extract_features(batch_p2)

        total_score = 0
        weights = {"feat_structure": 0.4, "feat_semantic": 0.6}

        for name, w in weights.items():
            fa, fb = feat_a[name].float(), feat_b[name].float()
            # Parallax-tolerant matching is applied to every layer.
            sim_map = self.robust_similarity(fa, fb, kernel_size=5)
            dist = 1 - sim_map.mean(dim=[1, 2])
            total_score += dist * w

        return total_score
|
|
||||||
|
|
||||||
# ==========================================
|
|
||||||
# 辅助函数 (保持不变)
|
|
||||||
# ==========================================
|
|
||||||
def get_transforms():
    """Build the patch preprocessing pipeline.

    Returns:
        A ``transforms.Compose`` that resizes to ``IMG_RESIZE`` x
        ``IMG_RESIZE``, converts to a tensor, and normalises with mean 0.5 and
        std 0.5.
    """
    side = IMG_RESIZE
    steps = [
        transforms.Resize((side, side)),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
    return transforms.Compose(steps)
|
|
||||||
|
|
||||||
def scan_and_draw(model, t1_path, t2_path, output_path, patch_size, stride, batch_size):
    """Scan two images with a sliding window and draw the detected differences.

    The reference image is resized to the current image's size, both are cut
    into ``patch_size`` x ``patch_size`` windows every ``stride`` pixels, each
    pair is scored with ``model.compute_batch_distance``, and the scores are
    averaged over overlapping windows. Regions whose normalised heat exceeds
    the module-level ``THRESHOLD`` are boxed on a heat-map overlay.

    Args:
        model: Object exposing ``compute_batch_distance(batch_a, batch_b)``,
            expected to return one distance per patch pair.
        t1_path: Path of the reference image.
        t2_path: Path of the current image; its size defines the working size.
        output_path: Result path. A relative path is placed under ``/app/data``.
        patch_size: Side length of the square window, in pixels.
        stride: Step between windows, in pixels.
        batch_size: Number of patch pairs scored per forward pass.

    Returns:
        None. Returns early (after printing a message) when an image cannot be
        read, when no window fits, or when the result cannot be written.
    """
    img1_cv = cv2.imread(t1_path)
    img2_cv = cv2.imread(t2_path)
    if img1_cv is None or img2_cv is None:
        # Report the failure instead of exiting silently.
        print("❌ 错误: 无法读取图片")
        return

    h, w = img2_cv.shape[:2]
    img1_cv = cv2.resize(img1_cv, (w, h))
    preprocess = get_transforms()

    def _to_tensor(crop_bgr):
        # OpenCV arrays are BGR; the transform is fed an RGB PIL image.
        return preprocess(Image.fromarray(cv2.cvtColor(crop_bgr, cv2.COLOR_BGR2RGB)))

    print(f"🔪 [切片] 开始扫描... 尺寸: {w}x{h}, 忽略纹理细节,专注语义差异")
    patches1, patches2, coords = [], [], []

    for y in range(0, h - patch_size + 1, stride):
        for x in range(0, w - patch_size + 1, stride):
            patches1.append(_to_tensor(img1_cv[y:y+patch_size, x:x+patch_size]))
            patches2.append(_to_tensor(img2_cv[y:y+patch_size, x:x+patch_size]))
            coords.append((x, y))

    if not patches1:
        print("⚠️ 图片太小,无法切片")
        return

    all_distances = []
    for i in tqdm(range(0, len(patches1), batch_size), unit="batch"):
        b1 = torch.stack(patches1[i:i+batch_size]).to(DEVICE, dtype=torch.float16)
        b2 = torch.stack(patches2[i:i+batch_size]).to(DEVICE, dtype=torch.float16)
        with torch.no_grad():
            all_distances.append(model.compute_batch_distance(b1, b2).cpu())

    distances = torch.cat(all_distances)

    # Accumulate patch scores into a per-pixel map and track the largest score.
    heatmap = np.zeros((h, w), dtype=np.float32)
    count_map = np.zeros((h, w), dtype=np.float32)
    max_score = 0
    for (x, y), score in zip(coords, distances):
        val = score.item()
        if val > max_score:
            max_score = val
        heatmap[y:y+patch_size, x:x+patch_size] += val
        count_map[y:y+patch_size, x:x+patch_size] += 1

    # Pixels no window touched keep a divisor of 1 so they stay at zero.
    count_map[count_map == 0] = 1
    heatmap_avg = heatmap / count_map

    # Visualisation. The divisor is floored at 0.1 so small scores are not
    # stretched across the whole colour range.
    norm_factor = max(max_score, 0.1)
    heatmap_vis = (heatmap_avg / norm_factor * 255).clip(0, 255).astype(np.uint8)
    heatmap_color = cv2.applyColorMap(heatmap_vis, cv2.COLORMAP_JET)
    blended_img = cv2.addWeighted(img2_cv, 0.4, heatmap_color, 0.6, 0)

    _, thresh = cv2.threshold(heatmap_vis, int(255 * THRESHOLD), 255, cv2.THRESH_BINARY)
    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    result_img = blended_img.copy()
    # Regions smaller than 5% of one window's area are discarded.
    min_area = (patch_size**2)*0.05
    label_h = 25

    for cnt in contours:
        if cv2.contourArea(cnt) <= min_area:
            continue
        x, y, bw, bh = cv2.boundingRect(cnt)
        cv2.rectangle(result_img, (x, y), (x+bw, y+bh), (255, 255, 255), 4)
        cv2.rectangle(result_img, (x, y), (x+bw, y+bh), (0, 0, 255), 2)

        label = f"Diff: {heatmap_avg[y:y+bh, x:x+bw].mean():.2f}"
        # The label sits above the box; clamp it so a box touching the top
        # edge still shows its label inside the image.
        label_top = max(y - label_h, 0)
        cv2.rectangle(result_img, (x, label_top), (x+130, label_top + label_h), (0,0,255), -1)
        cv2.putText(result_img, label, (x+5, label_top + label_h - 7), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255,255,255), 2)

    output_full_path = output_path if os.path.isabs(output_path) else os.path.join("/app/data", output_path)
    os.makedirs(os.path.dirname(output_full_path), exist_ok=True)

    # cv2.imwrite signals failure through its return value.
    if not cv2.imwrite(output_full_path, result_img):
        print(f"❌ 错误: 无法保存结果: {output_full_path}")
        return
    print(f"🎯 完成! 最大差异分: {max_score:.4f}, 结果已保存: {output_full_path}")
|
|
||||||
|
|
||||||
if __name__ == "__main__":
    # Command-line entry point: parse arguments, build the model, run the scan.
    parser = argparse.ArgumentParser()

    # Positional inputs: reference image, current image, optional output name.
    for positional in ("t1", "t2"):
        parser.add_argument(positional)
    parser.add_argument("out", nargs="?", default="result.jpg")

    # Sliding-window options.
    parser.add_argument("-c", "--crop", type=int, default=224)
    parser.add_argument("-s", "--step", type=int, default=0)
    parser.add_argument("-b", "--batch", type=int, default=16)
    args = parser.parse_args()

    detector = DiffSimSemantic(DEVICE)

    # A non-positive step means "auto": use half the window size.
    if args.step > 0:
        window_step = args.step
    else:
        window_step = args.crop // 2

    scan_and_draw(detector, args.t1, args.t2, args.out, args.crop, window_step, args.batch)
|
|
||||||
Loading…
Reference in New Issue
Block a user