"""Patch-wise change detection between two images using DINOv2 embeddings.

Both images are cut into overlapping square windows, each window pair is
embedded with DINOv2, and the cosine distance between the two CLS embeddings
is accumulated into a heatmap that is overlaid on the second image together
with bounding boxes around the regions that exceed a threshold.
"""

import argparse
import os
import sys

import torch
import torch.nn.functional as F
import cv2
import numpy as np
from PIL import Image
from torchvision import transforms
from tqdm import tqdm

# === Configuration ===
MODEL_NAME = 'dinov2_vitg14_reg'
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Directory name torch.hub uses for the cached DINOv2 repository checkout.
_HUB_REPO_DIRNAME = 'facebookresearch_dinov2_main'
# Historical hard-coded location, kept as a fallback for existing setups.
_LEGACY_CACHE_PATH = '/root/.cache/torch/hub/facebookresearch_dinov2_main'

# Lower bound of the heatmap normalisation denominator, so that images with
# only tiny differences are not stretched to the full colour range.
_MIN_NORM_DENOM = 0.4
# Contours smaller than this fraction of one patch area are ignored.
_MIN_AREA_RATIO = 0.03
# Height/width in pixels of the filled label drawn next to each box.
_LABEL_HEIGHT = 25
_LABEL_WIDTH = 130


def _find_local_repo():
    """Return the path of a locally cached DINOv2 hub repo, or None."""
    candidates = (
        os.path.join(torch.hub.get_dir(), _HUB_REPO_DIRNAME),
        _LEGACY_CACHE_PATH,
    )
    for candidate in candidates:
        if os.path.isdir(candidate):
            return candidate
    return None


def init_model():
    """Load the DINOv2 model named by MODEL_NAME, move it to DEVICE, set eval mode.

    A locally cached copy of the hub repository is preferred (loaded with
    ``source='local'``); if none is found the model is fetched through
    ``torch.hub`` from ``facebookresearch/dinov2``.

    Returns:
        The loaded model in evaluation mode on DEVICE.
    """
    print(f"🚀 [系统] 初始化 DINOv2 ({MODEL_NAME})...")
    if DEVICE == "cuda":
        print(f"✅ [硬件] 使用设备: {torch.cuda.get_device_name(0)}")

    local_path = _find_local_repo()
    if local_path is not None:
        print(f"📂 [加载] 命中本地缓存: {local_path}")
        model = torch.hub.load(local_path, MODEL_NAME, source='local')
    else:
        print("⚠️ 未找到本地缓存,尝试在线加载...")
        model = torch.hub.load('facebookresearch/dinov2', MODEL_NAME)

    model.to(DEVICE)
    model.eval()
    return model


def get_transform():
    """Return the preprocessing pipeline applied to every patch.

    Each PIL patch is resized to 224x224, converted to a tensor and
    normalised with the mean/std values listed below.
    """
    return transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])


def _axis_offsets(length, patch_size, stride):
    """Return window start offsets along one axis, always reaching the far edge.

    A plain ``range(0, length - patch_size + 1, stride)`` leaves the trailing
    margin unscanned whenever ``length - patch_size`` is not a multiple of
    ``stride``; one extra edge-aligned window is appended in that case.
    """
    last = length - patch_size
    if last < 0:
        return []
    offsets = list(range(0, last + 1, stride))
    if offsets[-1] != last:
        offsets.append(last)
    return offsets


def _crop_to_pil(img_bgr, x, y, patch_size):
    """Cut a square window out of a BGR image and return it as an RGB PIL image."""
    crop = img_bgr[y:y + patch_size, x:x + patch_size]
    return Image.fromarray(cv2.cvtColor(crop, cv2.COLOR_BGR2RGB))


def _patch_distances(model, img1_bgr, img2_bgr, coords, patch_size, batch_size):
    """Return one cosine distance (1 - similarity) per window as a float array.

    Windows are cropped lazily per batch so that only ``batch_size`` patch
    pairs are held in memory at a time.
    """
    transform = get_transform()
    chunks = []
    with torch.no_grad():
        for start in tqdm(range(0, len(coords), batch_size), unit="batch"):
            batch_coords = coords[start:start + batch_size]
            batch1 = torch.stack([
                transform(_crop_to_pil(img1_bgr, x, y, patch_size))
                for x, y in batch_coords
            ]).to(DEVICE)
            batch2 = torch.stack([
                transform(_crop_to_pil(img2_bgr, x, y, patch_size))
                for x, y in batch_coords
            ]).to(DEVICE)

            feat1 = model.forward_features(batch1)["x_norm_clstoken"]
            feat2 = model.forward_features(batch2)["x_norm_clstoken"]
            similarity = F.cosine_similarity(feat1, feat2, dim=-1)
            chunks.append((1.0 - similarity).cpu())
    return torch.cat(chunks).numpy()


def _build_heatmap(distances, coords, shape, patch_size):
    """Average the per-window distances over every pixel the windows cover."""
    heatmap = np.zeros(shape, dtype=np.float32)
    count_map = np.zeros(shape, dtype=np.float32)
    for (x, y), value in zip(coords, distances):
        heatmap[y:y + patch_size, x:x + patch_size] += value
        count_map[y:y + patch_size, x:x + patch_size] += 1
    # Guard against division by zero for any pixel no window touched.
    count_map[count_map == 0] = 1
    return heatmap / count_map


def _draw_regions(canvas, heatmap_vis, heatmap_avg, patch_size, threshold):
    """Draw boxes and score labels for regions above the threshold.

    Returns:
        Tuple ``(annotated_image, box_count)``.
    """
    _, thresh_img = cv2.threshold(
        heatmap_vis, int(255 * threshold), 255, cv2.THRESH_BINARY)
    # findContours returns 3 values on OpenCV 3.x and 2 on 4.x; the contour
    # list is the second-to-last element in both.
    contours = cv2.findContours(
        thresh_img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]

    result_img = canvas.copy()
    min_area = (patch_size * patch_size) * _MIN_AREA_RATIO
    box_count = 0
    for cnt in contours:
        if cv2.contourArea(cnt) <= min_area:
            continue
        box_count += 1
        x, y, bw, bh = cv2.boundingRect(cnt)
        # White outline under a thinner red one keeps the box visible on any
        # heatmap colour.
        cv2.rectangle(result_img, (x, y), (x + bw, y + bh), (255, 255, 255), 4)
        cv2.rectangle(result_img, (x, y), (x + bw, y + bh), (0, 0, 255), 2)

        score_val = heatmap_avg[y:y + bh, x:x + bw].mean()
        # The label normally sits above the box; when the box touches the top
        # of the image it is moved inside the box so it stays on canvas.
        label_top = y - _LABEL_HEIGHT if y >= _LABEL_HEIGHT else y
        cv2.rectangle(result_img, (x, label_top),
                      (x + _LABEL_WIDTH, label_top + _LABEL_HEIGHT),
                      (0, 0, 255), -1)
        cv2.putText(result_img, f"Diff: {score_val:.2f}",
                    (x + 5, label_top + _LABEL_HEIGHT - 7),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
    return result_img, box_count


def scan_and_draw(model, t1_path, t2_path, output_path, patch_size, stride, batch_size, threshold):
    """Compare two images window by window and write an annotated heatmap.

    The first image is resized to the second image's dimensions, both are
    scanned with a sliding square window, and the per-window cosine distance
    between the model's ``x_norm_clstoken`` features is averaged into a
    heatmap. The heatmap is blended over the second image, regions above the
    threshold are boxed and labelled, and the result is written to disk.

    Args:
        model: Object exposing ``forward_features(batch)`` that returns a
            mapping containing ``"x_norm_clstoken"``.
        t1_path: Path of the reference image.
        t2_path: Path of the current image; its size defines the output size.
        output_path: Where the annotated image is written.
        patch_size: Side length of the sliding window in pixels.
        stride: Step between consecutive windows in pixels.
        batch_size: Number of window pairs embedded per forward pass.
        threshold: Fraction in the 0..1 range applied to the *normalised*
            heatmap (distance divided by ``max(max_distance, 0.4)``), not to
            the raw cosine distance.

    Returns:
        None. Failures to read the inputs, to slice the image or to write the
        output are reported on stdout.

    Raises:
        ValueError: If ``patch_size``, ``stride`` or ``batch_size`` is not
            positive.
    """
    if patch_size <= 0 or stride <= 0 or batch_size <= 0:
        raise ValueError(
            "patch_size, stride and batch_size must be positive integers")

    img1_cv = cv2.imread(t1_path)
    img2_cv = cv2.imread(t2_path)
    if img1_cv is None or img2_cv is None:
        print("❌ 错误: 无法读取图片")
        return

    # Force both images onto the same pixel grid.
    h, w = img2_cv.shape[:2]
    img1_cv = cv2.resize(img1_cv, (w, h))

    print(f"🔪 [切片] DINOv2 扫描... 尺寸: {w}x{h}")
    print(f" - 参数: Crop={patch_size}, Step={stride}, Thresh={threshold}")

    # Only the window origins are stored; crops are produced per batch.
    coords = [
        (x, y)
        for y in _axis_offsets(h, patch_size, stride)
        for x in _axis_offsets(w, patch_size, stride)
    ]
    if not coords:
        print("⚠️ 图片太小,无法切片")
        return

    print(f"🧠 [推理] 共 {len(coords)} 个切片...")
    distances = _patch_distances(
        model, img1_cv, img2_cv, coords, patch_size, batch_size)
    max_score = float(distances.max())

    # Rebuild the per-pixel heatmap from the per-window scores.
    heatmap_avg = _build_heatmap(distances, coords, (h, w), patch_size)

    # Visualisation.
    norm_denom = max(max_score, _MIN_NORM_DENOM)
    heatmap_vis = (heatmap_avg / norm_denom * 255).clip(0, 255).astype(np.uint8)
    heatmap_color = cv2.applyColorMap(heatmap_vis, cv2.COLORMAP_JET)
    blended_img = cv2.addWeighted(img2_cv, 0.4, heatmap_color, 0.6, 0)

    result_img, box_count = _draw_regions(
        blended_img, heatmap_vis, heatmap_avg, patch_size, threshold)

    # imwrite signals failure through its return value rather than raising.
    if not cv2.imwrite(output_path, result_img):
        print(f"❌ 错误: 无法写入结果图片: {output_path}")
        return

    print("=" * 40)
    print(f"🎯 检测完成! 最大差异: {max_score:.4f} | 发现区域: {box_count}")
    print(f"🖼️ 结果: {output_path}")
    print("=" * 40)


def main():
    """Parse the command line, load the model and run the comparison."""
    parser = argparse.ArgumentParser(description="DINOv2 Giant 切片版")
    parser.add_argument("t1", help="基准图")
    parser.add_argument("t2", help="现状图")
    parser.add_argument("out", nargs="?", default="dino_sliced_result.jpg")
    parser.add_argument("-c", "--crop", type=int, default=224, help="切片大小")
    parser.add_argument("-s", "--step", type=int, default=0, help="步长")
    parser.add_argument("-b", "--batch", type=int, default=8, help="批次")
    parser.add_argument("--thresh", type=float, default=0.30, help="检测阈值")
    args = parser.parse_args()

    if args.crop <= 0:
        parser.error("--crop must be a positive integer")
    if args.batch <= 0:
        parser.error("--batch must be a positive integer")
    # Validate the inputs before loading the model so a bad path fails fast.
    for path in (args.t1, args.t2):
        if not os.path.isfile(path):
            parser.error(f"input image not found: {path}")

    # A non-positive step selects the default of half a patch (at least 1 px).
    stride = args.step if args.step > 0 else max(1, args.crop // 2)

    model = init_model()
    scan_and_draw(model, args.t1, args.t2, args.out,
                  args.crop, stride, args.batch, args.thresh)


if __name__ == "__main__":
    main()