186 lines
6.6 KiB
Python
186 lines
6.6 KiB
Python
import sys
|
|
import os
|
|
import torch
|
|
import cv2
|
|
import numpy as np
|
|
import argparse
|
|
from dreamsim import dreamsim
|
|
from PIL import Image
|
|
from tqdm import tqdm
|
|
|
|
# === Configuration ===

# DreamSim variant to load. The official DreamSim recommendation is the
# "ensemble" variant: slightly slower, but more accurate.
MODEL_TYPE = "ensemble"

# Compute device: use the GPU whenever PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
|
|
|
|
def init_model():
    """Load the DreamSim model selected by ``MODEL_TYPE`` onto ``DEVICE``.

    Prints which device is in use; on CUDA it also prints the allocated GPU
    memory, measured *after* the model has been loaded so the figure reflects
    the model's footprint.

    Returns:
        tuple: ``(model, preprocess)`` as returned by ``dreamsim(...)``.
        ``preprocess`` is the callable that turns a PIL image into a model
        input tensor.
    """
    print(f"🚀 [系统] 初始化 DreamSim ({MODEL_TYPE})...")

    if DEVICE == "cuda":
        print(f"✅ [硬件确认] 正在使用显卡: {torch.cuda.get_device_name(0)}")
    else:
        print("❌ [警告] 未检测到显卡,正在使用 CPU 慢速运行!")

    # Load the pretrained model and its matching preprocessing function.
    model, preprocess = dreamsim(pretrained=True, dreamsim_type=MODEL_TYPE, device=DEVICE)
    model.to(DEVICE)
    # This script only runs inference: make sure train-only behaviour is off.
    model.eval()

    if DEVICE == "cuda":
        # Reported after loading; before loading this number is meaningless.
        print(f" (显存状态: {torch.cuda.memory_allocated()/1024**2:.2f}MB 已用)")

    return model, preprocess
|
|
|
|
def _window_starts(length, patch_size, stride):
    """Return the start offsets of sliding windows along one image axis.

    Windows start every ``stride`` pixels. If the last regular window does not
    reach the end of the axis, one extra window flush with the far edge is
    appended so every pixel is covered. Empty when ``length < patch_size``.
    """
    if length < patch_size:
        return []
    last = length - patch_size
    starts = list(range(0, last + 1, stride))
    if starts[-1] != last:
        starts.append(last)
    return starts


def _to_model_input(preprocess, crop_bgr):
    """Convert one BGR OpenCV crop into a ``[3, H, W]`` DreamSim input tensor."""
    tensor = preprocess(Image.fromarray(cv2.cvtColor(crop_bgr, cv2.COLOR_BGR2RGB)))
    # preprocess may return [1, 3, H, W]; torch.stack below needs [3, H, W].
    if tensor.ndim == 4:
        tensor = tensor.squeeze(0)
    return tensor


def _compute_distances(model, preprocess, img1_cv, img2_cv, coords, patch_size, batch_size):
    """Score every window pair with DreamSim; return a 1-D CPU tensor of distances.

    Crops are preprocessed batch by batch, so memory use is bounded by
    ``batch_size`` instead of by the total number of windows.
    """
    all_distances = []
    for start in tqdm(range(0, len(coords), batch_size), unit="batch"):
        batch_coords = coords[start:start + batch_size]
        batch_p1 = torch.stack([
            _to_model_input(preprocess, img1_cv[y:y + patch_size, x:x + patch_size])
            for x, y in batch_coords
        ]).to(DEVICE)
        batch_p2 = torch.stack([
            _to_model_input(preprocess, img2_cv[y:y + patch_size, x:x + patch_size])
            for x, y in batch_coords
        ]).to(DEVICE)

        with torch.no_grad():
            dist_batch = model(batch_p1, batch_p2)
        # reshape(-1) keeps torch.cat working even if a batch yields a 0-d tensor.
        all_distances.append(dist_batch.reshape(-1).cpu())
    return torch.cat(all_distances)


def _accumulate_heatmap(distances, coords, h, w, patch_size):
    """Spread each window's distance over its pixels and average the overlaps."""
    heatmap = np.zeros((h, w), dtype=np.float32)
    count_map = np.zeros((h, w), dtype=np.float32)
    for (x, y), score in zip(coords, distances.tolist()):
        heatmap[y:y + patch_size, x:x + patch_size] += score
        count_map[y:y + patch_size, x:x + patch_size] += 1
    # Guard against division by zero for any pixel no window touched.
    count_map[count_map == 0] = 1
    return heatmap / count_map


def _draw_regions(canvas, heatmap_vis, heatmap_avg, threshold, patch_size):
    """Box every above-threshold region on ``canvas`` (in place); return the box count."""
    _, thresh_img = cv2.threshold(heatmap_vis, int(255 * threshold), 255, cv2.THRESH_BINARY)
    # [-2] selects the contour list on both OpenCV 3.x (3 return values) and 4.x (2).
    contours = cv2.findContours(thresh_img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]

    # Ignore regions smaller than 3% of one patch's area.
    min_area = (patch_size * patch_size) * 0.03
    box_count = 0

    for cnt in contours:
        if cv2.contourArea(cnt) <= min_area:
            continue
        box_count += 1
        x, y, bw, bh = cv2.boundingRect(cnt)

        # White outline under a red box so it stays visible on any background.
        cv2.rectangle(canvas, (x, y), (x + bw, y + bh), (255, 255, 255), 4)
        cv2.rectangle(canvas, (x, y), (x + bw, y + bh), (0, 0, 255), 2)

        # Mean averaged distance inside the box, on a 25-px red tab. The tab sits
        # above the box, or just inside it when the box touches the top edge.
        label = f"{heatmap_avg[y:y + bh, x:x + bw].mean():.2f}"
        tab_top = y - 25 if y >= 25 else y
        cv2.rectangle(canvas, (x, tab_top), (x + 80, tab_top + 25), (0, 0, 255), -1)
        cv2.putText(canvas, label, (x + 5, tab_top + 18), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)

    return box_count


def scan_and_draw(model, preprocess, t1_path, t2_path, output_path, patch_size, stride, batch_size, threshold):
    """Compare two images window-by-window with DreamSim and draw the changed regions.

    T1 is resized to T2's size. Both are cut into ``patch_size`` x ``patch_size``
    windows every ``stride`` pixels, plus a final window flush with the
    right/bottom edge so the whole image is covered. Each window pair is scored
    with ``model``, and the distances are averaged into a per-pixel heatmap.

    Two images are written:

    * ``debug_raw_heatmap.png`` (relative to the current working directory): the
      heatmap min-max normalised to 0-255 grayscale.
    * ``output_path``: T2 blended with a JET-coloured heatmap. Every
      above-threshold region larger than 3% of a patch gets a box and its mean
      score.

    Args:
        model: DreamSim model; called as ``model(batch1, batch2)`` and expected to
            return one distance per image pair.
        preprocess: Callable mapping a PIL RGB image to a ``[3, H, W]`` or
            ``[1, 3, H, W]`` tensor.
        t1_path: Path of the baseline image (T1).
        t2_path: Path of the current image (T2); defines the working resolution.
        output_path: Where the annotated result image is written.
        patch_size: Window side length in pixels (> 0).
        stride: Step between window starts in pixels (> 0).
        batch_size: Number of window pairs per forward pass (> 0).
        threshold: Fraction (0.0-1.0) of the visualisation scale (the max
            distance, floored at 0.1) above which a pixel counts as changed.

    Returns:
        None. Prints a message and returns early if an image cannot be read or
        is smaller than ``patch_size``.

    Raises:
        ValueError: If ``patch_size``, ``stride`` or ``batch_size`` is not positive.
    """
    if patch_size <= 0 or stride <= 0 or batch_size <= 0:
        raise ValueError(
            f"patch_size, stride and batch_size must be positive "
            f"(got {patch_size}, {stride}, {batch_size})"
        )

    # 1. Read both images with OpenCV (BGR).
    img1_cv = cv2.imread(t1_path)
    img2_cv = cv2.imread(t2_path)

    if img1_cv is None or img2_cv is None:
        print("❌ 错误: 无法读取图片")
        return

    # Force-align sizes, using the current image (T2) as the reference.
    h, w = img2_cv.shape[:2]
    img1_cv = cv2.resize(img1_cv, (w, h))

    print(f"🔪 [切片] DreamSim 扫描... 尺寸: {w}x{h}")
    print(f" - 参数: Crop={patch_size}, Step={stride}, Batch={batch_size}, Thresh={threshold}")

    # 2. Sliding-window top-left corners (row-major order), covering all edges.
    xs = _window_starts(w, patch_size, stride)
    ys = _window_starts(h, patch_size, stride)
    coords = [(x, y) for y in ys for x in xs]

    if not coords:
        print("⚠️ 图片太小,无法切片")
        return

    total_patches = len(coords)
    print(f"🧠 [推理] 共 {total_patches} 个切片,开始计算...")

    # 3. Batched inference (tqdm progress bar).
    distances = _compute_distances(model, preprocess, img1_cv, img2_cv, coords, patch_size, batch_size)

    min_v, max_v = distances.min().item(), distances.max().item()
    print(f"\n📊 [统计] 分数分布: Min={min_v:.4f} | Max={max_v:.4f} | Mean={distances.mean().item():.4f}")

    # 4. Per-pixel heatmap: average of all windows covering each pixel.
    heatmap_avg = _accumulate_heatmap(distances, coords, h, w, patch_size)

    # Save the raw grayscale heatmap for frontend debugging. Clip before the
    # uint8 cast so out-of-range values cannot wrap around.
    raw_norm = np.clip((heatmap_avg - min_v) / (max_v - min_v + 1e-6), 0.0, 1.0)
    cv2.imwrite("debug_raw_heatmap.png", (raw_norm * 255).astype(np.uint8))
    print(f"💾 [调试] 原始热力图已保存: debug_raw_heatmap.png")

    # 5. Visualisation: scale by max distance (floored at 0.1 so near-identical
    # image pairs are not stretched into a full-intensity heatmap).
    norm_factor = max(max_v, 0.1)
    heatmap_vis = (heatmap_avg / norm_factor * 255).clip(0, 255).astype(np.uint8)
    heatmap_color = cv2.applyColorMap(heatmap_vis, cv2.COLORMAP_JET)

    alpha = 0.4
    # addWeighted returns a new array, so img2_cv itself is not modified.
    result_img = cv2.addWeighted(img2_cv, alpha, heatmap_color, 1.0 - alpha, 0)

    box_count = _draw_regions(result_img, heatmap_vis, heatmap_avg, threshold, patch_size)

    cv2.imwrite(output_path, result_img)

    print("="*40)
    print(f"🎯 扫描完成! 发现区域: {box_count} 个")
    print(f"🖼️ 结果已保存至: {output_path}")
    print("="*40)
|
|
|
|
def main(argv=None):
    """Parse command-line arguments, load DreamSim and run the scan.

    Args:
        argv: Argument list to parse, without the program name. ``None`` means
            ``sys.argv[1:]``, which is argparse's default.

    Exits with status 2 (via ``parser.error``) on invalid numeric options.
    """
    parser = argparse.ArgumentParser(description="DreamSim 违建热力图检测 (标准化版)")
    parser.add_argument("t1", help="基准图")
    parser.add_argument("t2", help="现状图")
    parser.add_argument("out", nargs="?", default="heatmap_result.jpg", help="输出图片名")

    # Scan parameters
    parser.add_argument("-c", "--crop", type=int, default=224, help="切片大小")
    parser.add_argument("-s", "--step", type=int, default=0, help="步长")
    parser.add_argument("-b", "--batch", type=int, default=16, help="批次")

    # Core parameter
    parser.add_argument("--thresh", type=float, default=0.30, help="检测阈值 (0.0-1.0)")

    args = parser.parse_args(argv)

    # Reject values that would crash range() or silently detect nothing/everything.
    if args.crop <= 0:
        parser.error("--crop must be a positive integer")
    if args.batch <= 0:
        parser.error("--batch must be a positive integer")
    if not 0.0 <= args.thresh <= 1.0:
        parser.error("--thresh must be within [0.0, 1.0]")

    # A non-positive step means "auto": half the crop size, but never below 1 px
    # (crop=1 would otherwise give a zero step).
    stride = args.step if args.step > 0 else max(1, args.crop // 2)

    # Initialise the model and run the scan.
    model, preprocess = init_model()
    scan_and_draw(model, preprocess, args.t1, args.t2, args.out, args.crop, stride, args.batch, args.thresh)


if __name__ == "__main__":
    main()
|