puruan/main_dinov2_sliced.py
2026-02-04 09:54:24 +08:00

166 lines
6.0 KiB
Python

import sys
import os
import torch
import torch.nn.functional as F
import cv2
import numpy as np
import argparse
from PIL import Image
from torchvision import transforms
from tqdm import tqdm
# === Configuration ===
# Torch Hub entry to load: the giant DINOv2 backbone (ViT-g/14) with register tokens,
# per the hub entry name.
MODEL_NAME = 'dinov2_vitg14_reg'
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
def init_model(local_path=None):
    """Load the DINOv2 model ``MODEL_NAME`` via torch.hub, move it to ``DEVICE`` and set eval mode.

    A local checkout of the ``facebookresearch/dinov2`` hub repo is preferred so the
    script can run offline; if none is found, the repo is fetched online.

    Args:
        local_path: Directory of a local dinov2 hub repo. When ``None`` (default) the
            standard torch.hub cache location
            ``<torch.hub.get_dir()>/facebookresearch_dinov2_main`` is used, which is
            where torch.hub stores the repo after an online load.

    Returns:
        The loaded model, on ``DEVICE`` and in eval mode.
    """
    print(f"🚀 [系统] 初始化 DINOv2 ({MODEL_NAME})...")
    if DEVICE == "cuda":
        print(f"✅ [硬件] 使用设备: {torch.cuda.get_device_name(0)}")

    # Prefer the local hub cache. torch.hub.get_dir() honours TORCH_HOME and the current
    # user's home directory instead of assuming the process runs as root.
    if local_path is None:
        local_path = os.path.join(torch.hub.get_dir(), "facebookresearch_dinov2_main")

    if os.path.isdir(local_path):
        print(f"📂 [加载] 命中本地缓存: {local_path}")
        model = torch.hub.load(local_path, MODEL_NAME, source='local')
    else:
        print("⚠️ 未找到本地缓存,尝试在线加载...")
        model = torch.hub.load('facebookresearch/dinov2', MODEL_NAME)

    model.to(DEVICE)
    model.eval()  # inference only: disables dropout / training-mode behaviour
    return model
def get_transform():
    """Build the preprocessing pipeline applied to every crop before it enters the model.

    Each PIL image is resized to 224x224, converted to a float tensor and normalized
    with the ImageNet channel mean/std.
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    steps = [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
    ]
    return transforms.Compose(steps)
def _tile_starts(length, patch_size, stride):
    """Return start offsets of ``patch_size`` windows on a ``stride`` grid, ending flush with ``length``."""
    if length < patch_size:
        return []
    starts = list(range(0, length - patch_size + 1, stride))
    # range() stops short of the far edge whenever (length - patch_size) is not a
    # multiple of stride; add one window flush with the edge so the right/bottom
    # margin is scanned as well.
    if starts[-1] != length - patch_size:
        starts.append(length - patch_size)
    return starts


def _compute_patch_distances(model, img1_cv, img2_cv, coords, patch_size, batch_size):
    """Return a 1-D CPU tensor holding ``1 - cosine_similarity`` of the model features for each crop pair."""
    transform = get_transform()

    def _crop_to_tensor(img, x, y):
        # Crops are cut lazily per batch (instead of materialising every PIL patch up
        # front) to keep peak memory proportional to batch_size.
        crop = img[y:y + patch_size, x:x + patch_size]
        return transform(Image.fromarray(cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)))

    all_distances = []
    for start in tqdm(range(0, len(coords), batch_size), unit="batch"):
        batch_coords = coords[start:start + batch_size]
        batch_p1 = torch.stack([_crop_to_tensor(img1_cv, x, y) for x, y in batch_coords]).to(DEVICE)
        batch_p2 = torch.stack([_crop_to_tensor(img2_cv, x, y) for x, y in batch_coords]).to(DEVICE)
        with torch.no_grad():
            feat1 = model.forward_features(batch_p1)["x_norm_clstoken"]
            feat2 = model.forward_features(batch_p2)["x_norm_clstoken"]
            dist_batch = 1.0 - F.cosine_similarity(feat1, feat2, dim=-1)
        all_distances.append(dist_batch.cpu())
    return torch.cat(all_distances)


def _build_heatmap(distances, coords, patch_size, h, w):
    """Spread each patch distance over its pixels and average overlaps; uncovered pixels stay 0."""
    heatmap = np.zeros((h, w), dtype=np.float32)
    count_map = np.zeros((h, w), dtype=np.float32)
    for (x, y), score in zip(coords, distances.tolist()):
        heatmap[y:y + patch_size, x:x + patch_size] += score
        count_map[y:y + patch_size, x:x + patch_size] += 1
    count_map[count_map == 0] = 1  # avoid division by zero where no patch landed
    return heatmap / count_map


def _draw_detections(base_img, heatmap_avg, heatmap_vis, patch_size, threshold):
    """Blend the colour heatmap onto ``base_img`` and box large above-threshold regions; return (image, box_count)."""
    heatmap_color = cv2.applyColorMap(heatmap_vis, cv2.COLORMAP_JET)
    result_img = cv2.addWeighted(base_img, 0.4, heatmap_color, 0.6, 0)

    _, thresh_img = cv2.threshold(heatmap_vis, int(255 * threshold), 255, cv2.THRESH_BINARY)
    contours, _ = cv2.findContours(thresh_img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # Ignore specks smaller than 3% of one patch's area.
    min_area = (patch_size * patch_size) * 0.03
    box_count = 0
    for cnt in contours:
        if cv2.contourArea(cnt) <= min_area:
            continue
        box_count += 1
        x, y, bw, bh = cv2.boundingRect(cnt)
        # White outer stroke under a red inner stroke keeps the box visible on any background.
        cv2.rectangle(result_img, (x, y), (x + bw, y + bh), (255, 255, 255), 4)
        cv2.rectangle(result_img, (x, y), (x + bw, y + bh), (0, 0, 255), 2)
        score_val = heatmap_avg[y:y + bh, x:x + bw].mean()
        # The 25 px label normally sits above the box; for boxes touching the top
        # edge it would be drawn off-image, so place it inside the box instead.
        label_top = y - 25 if y >= 25 else y
        cv2.rectangle(result_img, (x, label_top), (x + 130, label_top + 25), (0, 0, 255), -1)
        cv2.putText(result_img, f"Diff: {score_val:.2f}", (x + 5, label_top + 18),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
    return result_img, box_count


def scan_and_draw(model, t1_path, t2_path, output_path, patch_size, stride, batch_size, threshold):
    """Find changed regions between two images with DINOv2 patch features and save an annotated image.

    Both images are cut into square ``patch_size`` crops on a ``stride`` grid (plus a
    final row/column flush with the image edge). Each crop pair is scored by the
    cosine distance between the model's ``"x_norm_clstoken"`` features. Scores are
    averaged into a per-pixel heatmap, overlaid on the current image, and regions
    above ``threshold`` are boxed.

    Args:
        model: Model exposing ``forward_features(batch)`` that returns a dict with an
            ``"x_norm_clstoken"`` tensor (as the DINOv2 hub models do).
        t1_path: Path of the baseline image.
        t2_path: Path of the current image; its resolution defines the output size.
        output_path: File the annotated result is written to with ``cv2.imwrite``.
        patch_size: Side length in pixels of each square crop (> 0).
        stride: Step in pixels between crop origins (> 0).
        batch_size: Number of crop pairs per forward pass (> 0).
        threshold: Fraction in [0, 1]. A pixel is flagged when its visualised value,
            ``distance / max(max_score, 0.4)`` scaled to 0-255, exceeds
            ``int(255 * threshold)``.

    Returns:
        None. Progress, errors and a summary are printed.

    Raises:
        ValueError: if ``patch_size``, ``stride`` or ``batch_size`` is not positive.
    """
    for name, value in (("patch_size", patch_size), ("stride", stride), ("batch_size", batch_size)):
        if value <= 0:
            raise ValueError(f"{name} must be a positive integer, got {value}")

    img1_cv = cv2.imread(t1_path)
    img2_cv = cv2.imread(t2_path)
    if img1_cv is None or img2_cv is None:
        print("❌ 错误: 无法读取图片")
        return

    # Force alignment: resize the baseline to the current image so that crops taken at
    # the same (x, y) cover the same area in both images.
    h, w = img2_cv.shape[:2]
    img1_cv = cv2.resize(img1_cv, (w, h))
    print(f"🔪 [切片] DINOv2 扫描... 尺寸: {w}x{h}")
    print(f" - 参数: Crop={patch_size}, Step={stride}, Thresh={threshold}")

    ys = _tile_starts(h, patch_size, stride)
    xs = _tile_starts(w, patch_size, stride)
    coords = [(x, y) for y in ys for x in xs]
    if not coords:
        print("⚠️ 图片太小,无法切片")
        return

    total_patches = len(coords)
    print(f"🧠 [推理] 共 {total_patches} 个切片...")
    distances = _compute_patch_distances(model, img1_cv, img2_cv, coords, patch_size, batch_size)
    heatmap_avg = _build_heatmap(distances, coords, patch_size, h, w)

    max_score = distances.max().item()
    # Normalise by the largest patch score, but never by less than 0.4, so a pair whose
    # maximum distance is small is not stretched to the full 0-255 range.
    norm_denom = max(max_score, 0.4)
    heatmap_vis = (heatmap_avg / norm_denom * 255).clip(0, 255).astype(np.uint8)

    result_img, box_count = _draw_detections(img2_cv, heatmap_avg, heatmap_vis, patch_size, threshold)

    # cv2.imwrite reports failure (e.g. missing directory) by returning False.
    if not cv2.imwrite(output_path, result_img):
        print(f"❌ 错误: 无法写入结果图片: {output_path}")
        return
    print("="*40)
    print(f"🎯 检测完成! 最大差异: {max_score:.4f} | 发现区域: {box_count}")
    print(f"🖼️ 结果: {output_path}")
    print("="*40)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="DINOv2 Giant 切片版")
    parser.add_argument("t1", help="基准图")
    parser.add_argument("t2", help="现状图")
    parser.add_argument("out", nargs="?", default="dino_sliced_result.jpg")
    parser.add_argument("-c", "--crop", type=int, default=224, help="切片大小")
    parser.add_argument("-s", "--step", type=int, default=0, help="步长")
    parser.add_argument("-b", "--batch", type=int, default=8, help="批次")
    # Detection threshold, forwarded to scan_and_draw as a fraction of the 0-255 heatmap.
    parser.add_argument("--thresh", type=float, default=0.30, help="检测阈值")
    args = parser.parse_args()

    # Reject values that would otherwise fail deep inside scanning (zero range() step,
    # empty crops) or make the threshold meaningless.
    if args.crop <= 0:
        parser.error("--crop must be a positive integer")
    if args.batch <= 0:
        parser.error("--batch must be a positive integer")
    if not 0.0 <= args.thresh <= 1.0:
        parser.error("--thresh must be within [0, 1]")

    # A non-positive step means "half a crop" (50% overlap); clamp to 1 so that a
    # 1-pixel crop still yields a valid step.
    stride = args.step if args.step > 0 else max(1, args.crop // 2)
    model = init_model()
    scan_and_draw(model, args.t1, args.t2, args.out, args.crop, stride, args.batch, args.thresh)