# puruan/main_dreamsim.py
# Snapshot metadata: 2026-02-04 09:54:24 +08:00 — 186 lines, 6.6 KiB, Python
import sys
import os
import torch
import cv2
import numpy as np
import argparse
from dreamsim import dreamsim
from PIL import Image
from tqdm import tqdm
# === Configuration ===
# DreamSim officially recommends the "ensemble" mode: it is a bit slower
# than a single backbone but gives the most accurate similarity scores.
MODEL_TYPE = "ensemble"
# Prefer GPU inference when CUDA is available; fall back to CPU otherwise.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
def init_model():
    """Load the DreamSim model and its preprocessing transform.

    Returns:
        (model, preprocess): the pretrained DreamSim model moved to DEVICE,
        and the callable that turns a PIL image into a model-ready tensor.
    """
    print(f"🚀 [系统] 初始化 DreamSim ({MODEL_TYPE})...")
    if DEVICE != "cuda":
        print("❌ [警告] 未检测到显卡,正在使用 CPU 慢速运行!")
    else:
        print(f"✅ [硬件确认] 正在使用显卡: {torch.cuda.get_device_name(0)}")
        print(f" (显存状态: {torch.cuda.memory_allocated()/1024**2:.2f}MB 已用)")
    # Load the pretrained model + preprocessing pipeline in one call.
    bundle = dreamsim(pretrained=True, dreamsim_type=MODEL_TYPE, device=DEVICE)
    model, preprocess = bundle
    model.to(DEVICE)
    return model, preprocess
def scan_and_draw(model, preprocess, t1_path, t2_path, output_path, patch_size, stride, batch_size, threshold):
    """Sliding-window DreamSim change detection between two images.

    Splits the baseline (T1) and current (T2) images into overlapping square
    patches, scores each patch pair with DreamSim, averages the per-patch
    distances into a pixel heatmap, then writes an annotated visualization
    with bounding boxes around regions whose normalized score exceeds
    ``threshold``.

    Args:
        model, preprocess: pair returned by init_model().
        t1_path: path of the baseline image (T1).
        t2_path: path of the current image (T2); its size is authoritative.
        output_path: file path for the annotated result image.
        patch_size: side length in pixels of each square crop.
        stride: sliding-window step in pixels.
        batch_size: number of patch pairs per inference batch.
        threshold: fraction (0.0-1.0) of the normalized score used to
            binarize the heatmap before contour extraction.

    Returns:
        None. Side effects: writes ``debug_raw_heatmap.png`` and
        ``output_path``; prints progress/statistics to stdout.
    """
    # 1. Read both images with OpenCV (BGR order).
    img1_cv = cv2.imread(t1_path)
    img2_cv = cv2.imread(t2_path)
    if img1_cv is None or img2_cv is None:
        print("❌ 错误: 无法读取图片")
        return
    # Force-resize T1 to T2's size so the two sampling grids align exactly.
    h, w = img2_cv.shape[:2]
    img1_cv = cv2.resize(img1_cv, (w, h))
    print(f"🔪 [切片] DreamSim 扫描... 尺寸: {w}x{h}")
    print(f" - 参数: Crop={patch_size}, Step={stride}, Batch={batch_size}, Thresh={threshold}")
    # 2. Build the sliding-window patch lists.
    # NOTE(review): windows start only at multiples of `stride`, so up to
    # patch_size-1 pixels at the right/bottom edges may never be scanned —
    # confirm this is acceptable before relying on full coverage.
    patches1 = []
    patches2 = []
    coords = []
    for y in range(0, h - patch_size + 1, stride):
        for x in range(0, w - patch_size + 1, stride):
            crop1 = img1_cv[y:y+patch_size, x:x+patch_size]
            crop2 = img2_cv[y:y+patch_size, x:x+patch_size]
            # DreamSim preprocessing expects RGB PIL images.
            p1 = preprocess(Image.fromarray(cv2.cvtColor(crop1, cv2.COLOR_BGR2RGB)))
            p2 = preprocess(Image.fromarray(cv2.cvtColor(crop2, cv2.COLOR_BGR2RGB)))
            # preprocess may return [1, 3, H, W]; torch.stack below needs [3, H, W].
            if p1.ndim == 4:
                p1 = p1.squeeze(0)
            if p2.ndim == 4:
                p2 = p2.squeeze(0)
            patches1.append(p1)
            patches2.append(p2)
            coords.append((x, y))
    if not patches1:
        print("⚠️ 图片太小,无法切片")
        return
    total_patches = len(patches1)
    print(f"🧠 [推理] 共 {total_patches} 个切片,开始计算...")
    all_distances = []
    # 3. Batched inference with a tqdm progress bar; gradients disabled.
    for i in tqdm(range(0, total_patches, batch_size), unit="batch"):
        batch_p1 = torch.stack(patches1[i : i + batch_size]).to(DEVICE)
        batch_p2 = torch.stack(patches2[i : i + batch_size]).to(DEVICE)
        with torch.no_grad():
            # DreamSim forward pass returns one distance per patch pair.
            dist_batch = model(batch_p1, batch_p2)
        all_distances.append(dist_batch.cpu())
    distances = torch.cat(all_distances)
    # 4. Accumulate patch scores into a per-pixel heatmap; overlapping
    # windows are averaged via count_map.
    heatmap = np.zeros((h, w), dtype=np.float32)
    count_map = np.zeros((h, w), dtype=np.float32)
    min_v, max_v = distances.min().item(), distances.max().item()
    print(f"\n📊 [统计] 分数分布: Min={min_v:.4f} | Max={max_v:.4f} | Mean={distances.mean().item():.4f}")
    for idx, score in enumerate(distances):
        val = score.item()
        x, y = coords[idx]
        heatmap[y:y+patch_size, x:x+patch_size] += val
        count_map[y:y+patch_size, x:x+patch_size] += 1
    # Avoid division by zero for pixels never covered by a window.
    count_map[count_map == 0] = 1
    heatmap_avg = heatmap / count_map
    # Save the raw min-max-normalized heatmap for front-end debugging.
    # (+1e-6 guards the degenerate case where all distances are equal.)
    raw_norm = (heatmap_avg - min_v) / (max_v - min_v + 1e-6)
    cv2.imwrite("debug_raw_heatmap.png", (raw_norm * 255).astype(np.uint8))
    print(f"💾 [调试] 原始热力图已保存: debug_raw_heatmap.png")
    # 5. Visualization post-processing.
    # Normalize against max_v, floored at 0.1 so near-identical images do
    # not amplify tiny noise to full intensity.
    norm_factor = max(max_v, 0.1)
    heatmap_vis = (heatmap_avg / norm_factor * 255).clip(0, 255).astype(np.uint8)
    heatmap_color = cv2.applyColorMap(heatmap_vis, cv2.COLORMAP_JET)
    # Blend the colored heatmap over the current image.
    alpha = 0.4
    blended_img = cv2.addWeighted(img2_cv, alpha, heatmap_color, 1.0 - alpha, 0)
    # Binarize the normalized heatmap with the user-supplied threshold and
    # extract outer contours of the surviving regions.
    _, thresh_img = cv2.threshold(heatmap_vis, int(255 * threshold), 255, cv2.THRESH_BINARY)
    contours, _ = cv2.findContours(thresh_img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    result_img = blended_img.copy()
    box_count = 0
    # FIX: min_area is loop-invariant (3% of one patch's area) — hoisted
    # out of the contour loop instead of being recomputed per contour.
    min_area = (patch_size * patch_size) * 0.03
    for cnt in contours:
        area = cv2.contourArea(cnt)
        if area > min_area:
            box_count += 1
            x, y, bw, bh = cv2.boundingRect(cnt)
            # White outline under a red box so it stays visible on any background.
            cv2.rectangle(result_img, (x, y), (x+bw, y+bh), (255, 255, 255), 4)
            cv2.rectangle(result_img, (x, y), (x+bw, y+bh), (0, 0, 255), 2)
            # Label the box with the mean DreamSim distance inside it.
            label = f"{heatmap_avg[y:y+bh, x:x+bw].mean():.2f}"
            cv2.rectangle(result_img, (x, y-25), (x+80, y), (0,0,255), -1)
            cv2.putText(result_img, label, (x+5, y-7), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255,255,255), 2)
    # 6. Persist the annotated result and report.
    cv2.imwrite(output_path, result_img)
    print("="*40)
    print(f"🎯 扫描完成! 发现区域: {box_count}")
    print(f"🖼️ 结果已保存至: {output_path}")
    print("="*40)
if __name__ == "__main__":
    # CLI: baseline image, current image, optional output path, plus the
    # sliding-window and detection parameters.
    cli = argparse.ArgumentParser(description="DreamSim 违建热力图检测 (标准化版)")
    cli.add_argument("t1", help="基准图")
    cli.add_argument("t2", help="现状图")
    cli.add_argument("out", nargs="?", default="heatmap_result.jpg", help="输出图片名")
    # Sliding-window parameters.
    cli.add_argument("-c", "--crop", type=int, default=224, help="切片大小")
    cli.add_argument("-s", "--step", type=int, default=0, help="步长")
    cli.add_argument("-b", "--batch", type=int, default=16, help="批次")
    # Core detection parameter.
    cli.add_argument("--thresh", type=float, default=0.30, help="检测阈值 (0.0-1.0)")
    opts = cli.parse_args()
    # A non-positive step means "auto": half the crop size (50% overlap).
    step_px = opts.crop // 2 if opts.step <= 0 else opts.step
    # Initialize the model, then run the scan end to end.
    model, preprocess = init_model()
    scan_and_draw(model, preprocess, opts.t1, opts.t2, opts.out, opts.crop, step_px, opts.batch, opts.thresh)