"""DINOv2-based sliced change detection between a baseline and a current image."""
# Standard library
import argparse
import os
import sys

# Third party
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
from tqdm import tqdm

# === Configuration ===
MODEL_NAME = 'dinov2_vitg14_reg'
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
def init_model():
    """Load the DINOv2 backbone, preferring the local torch.hub cache.

    Falls back to an online download when the cache checkout is absent.
    Returns the model moved to DEVICE and switched to eval mode.
    """
    print(f"🚀 [系统] 初始化 DINOv2 ({MODEL_NAME})...")
    if DEVICE == "cuda":
        print(f"✅ [硬件] 使用设备: {torch.cuda.get_device_name(0)}")

    # Prefer the pre-downloaded hub checkout so no network access is required.
    cache_dir = '/root/.cache/torch/hub/facebookresearch_dinov2_main'
    if os.path.exists(cache_dir):
        print(f"📂 [加载] 命中本地缓存: {cache_dir}")
        net = torch.hub.load(cache_dir, MODEL_NAME, source='local')
    else:
        print("⚠️ 未找到本地缓存,尝试在线加载...")
        net = torch.hub.load('facebookresearch/dinov2', MODEL_NAME)

    net.to(DEVICE)
    net.eval()
    return net
|
|
|
|
def get_transform():
    """Build the preprocessing pipeline for DINOv2 patch inputs."""
    steps = [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        # Standard ImageNet channel statistics, matching DINOv2 pretraining.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    return transforms.Compose(steps)
|
|
|
|
# === 修正:增加 threshold 参数 ===
|
|
def _collect_patches(img1_cv, img2_cv, patch_size, stride):
    """Slice both aligned BGR images into RGB PIL crops on a stride grid.

    Returns (patches1, patches2, coords), where coords holds each crop's
    (x, y) top-left corner.  NOTE(review): edge strips narrower than
    patch_size are never covered by any crop — same as the original scan.
    """
    h, w = img2_cv.shape[:2]
    patches1, patches2, coords = [], [], []
    for y in range(0, h - patch_size + 1, stride):
        for x in range(0, w - patch_size + 1, stride):
            crop1 = img1_cv[y:y + patch_size, x:x + patch_size]
            crop2 = img2_cv[y:y + patch_size, x:x + patch_size]
            patches1.append(Image.fromarray(cv2.cvtColor(crop1, cv2.COLOR_BGR2RGB)))
            patches2.append(Image.fromarray(cv2.cvtColor(crop2, cv2.COLOR_BGR2RGB)))
            coords.append((x, y))
    return patches1, patches2, coords


def _batched_distances(model, patches1, patches2, batch_size):
    """Compute 1 - cosine similarity between paired patches with DINOv2.

    Features are the CLS tokens from model.forward_features; inference runs
    in batches under no_grad.  Returns a CPU tensor of shape (num_patches,).
    """
    transform = get_transform()
    all_distances = []
    for i in tqdm(range(0, len(patches1), batch_size), unit="batch"):
        # range() guarantees each slice is non-empty, so no emptiness guard
        # is needed here (the original's `if not batch: break` was dead code).
        batch_p1 = torch.stack([transform(p) for p in patches1[i:i + batch_size]]).to(DEVICE)
        batch_p2 = torch.stack([transform(p) for p in patches2[i:i + batch_size]]).to(DEVICE)
        with torch.no_grad():
            feat1 = model.forward_features(batch_p1)["x_norm_clstoken"]
            feat2 = model.forward_features(batch_p2)["x_norm_clstoken"]
            sim_batch = F.cosine_similarity(feat1, feat2, dim=-1)
            all_distances.append((1.0 - sim_batch).cpu())
    return torch.cat(all_distances)


def _build_heatmap(distances, coords, h, w, patch_size):
    """Average per-patch distances into a per-pixel (h, w) float32 map.

    Each pixel gets the mean distance over all patches covering it; pixels
    no patch reached (narrow edges) stay 0.
    """
    heatmap = np.zeros((h, w), dtype=np.float32)
    count_map = np.zeros((h, w), dtype=np.float32)
    for (x, y), score in zip(coords, distances):
        val = score.item()
        heatmap[y:y + patch_size, x:x + patch_size] += val
        count_map[y:y + patch_size, x:x + patch_size] += 1
    count_map[count_map == 0] = 1  # avoid division by zero at uncovered edges
    return heatmap / count_map


def scan_and_draw(model, t1_path, t2_path, output_path, patch_size, stride, batch_size, threshold):
    """Detect changed regions between two images with a sliding DINOv2 window.

    Args:
        model: DINOv2 module whose forward_features returns a dict with an
            "x_norm_clstoken" entry.
        t1_path: path of the baseline (before) image.
        t2_path: path of the current (after) image; defines the output size.
        output_path: where the annotated result image is written.
        patch_size: side length in pixels of each square crop.
        stride: step in pixels between crop origins.
        batch_size: number of patch pairs per inference batch.
        threshold: fraction (0-1) of the normalized heatmap above which a
            pixel is treated as changed when extracting boxes.

    Side effects: writes output_path and prints progress.  Returns None.
    """
    img1_cv = cv2.imread(t1_path)
    img2_cv = cv2.imread(t2_path)

    if img1_cv is None or img2_cv is None:
        print("❌ 错误: 无法读取图片")
        return

    # Force pixel alignment: resize the baseline to the current image's size.
    h, w = img2_cv.shape[:2]
    img1_cv = cv2.resize(img1_cv, (w, h))

    print(f"🔪 [切片] DINOv2 扫描... 尺寸: {w}x{h}")
    print(f"   - 参数: Crop={patch_size}, Step={stride}, Thresh={threshold}")

    patches1_pil, patches2_pil, coords = _collect_patches(img1_cv, img2_cv, patch_size, stride)
    if not patches1_pil:
        print("⚠️ 图片太小,无法切片")
        return

    total_patches = len(patches1_pil)
    print(f"🧠 [推理] 共 {total_patches} 个切片...")

    distances = _batched_distances(model, patches1_pil, patches2_pil, batch_size)
    max_score = distances.max().item()

    heatmap_avg = _build_heatmap(distances, coords, h, w, patch_size)

    # Visualization: normalize against at least 0.4 so small global
    # differences are not stretched to full colormap intensity.
    norm_denom = max(max_score, 0.4)
    heatmap_vis = (heatmap_avg / norm_denom * 255).clip(0, 255).astype(np.uint8)
    heatmap_color = cv2.applyColorMap(heatmap_vis, cv2.COLORMAP_JET)
    blended_img = cv2.addWeighted(img2_cv, 0.4, heatmap_color, 0.6, 0)

    # Threshold the normalized heatmap and outline sufficiently large regions.
    _, thresh_img = cv2.threshold(heatmap_vis, int(255 * threshold), 255, cv2.THRESH_BINARY)
    contours, _ = cv2.findContours(thresh_img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    result_img = blended_img.copy()
    box_count = 0
    min_area = (patch_size * patch_size) * 0.03  # loop-invariant: hoisted

    for cnt in contours:
        if cv2.contourArea(cnt) <= min_area:
            continue
        box_count += 1
        x, y, bw, bh = cv2.boundingRect(cnt)
        # White outer + red inner rectangle so boxes read on any background.
        cv2.rectangle(result_img, (x, y), (x + bw, y + bh), (255, 255, 255), 4)
        cv2.rectangle(result_img, (x, y), (x + bw, y + bh), (0, 0, 255), 2)

        # Label each box with the mean distance inside it.
        score_val = heatmap_avg[y:y + bh, x:x + bw].mean()
        cv2.rectangle(result_img, (x, y - 25), (x + 130, y), (0, 0, 255), -1)
        cv2.putText(result_img, f"Diff: {score_val:.2f}", (x + 5, y - 7),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)

    cv2.imwrite(output_path, result_img)
    print("=" * 40)
    print(f"🎯 检测完成! 最大差异: {max_score:.4f} | 发现区域: {box_count}")
    print(f"🖼️ 结果: {output_path}")
    print("=" * 40)
|
|
|
|
if __name__ == "__main__":
    # CLI entry point: parse arguments, load the model, run the scan.
    cli = argparse.ArgumentParser(description="DINOv2 Giant 切片版")
    cli.add_argument("t1", help="基准图")
    cli.add_argument("t2", help="现状图")
    cli.add_argument("out", nargs="?", default="dino_sliced_result.jpg")
    cli.add_argument("-c", "--crop", type=int, default=224, help="切片大小")
    cli.add_argument("-s", "--step", type=int, default=0, help="步长")
    cli.add_argument("-b", "--batch", type=int, default=8, help="批次")
    cli.add_argument("--thresh", type=float, default=0.30, help="检测阈值")

    opts = cli.parse_args()

    # A step of 0 (the default) means "half-overlapping windows".
    if opts.step > 0:
        step = opts.step
    else:
        step = opts.crop // 2

    net = init_model()
    scan_and_draw(net, opts.t1, opts.t2, opts.out, opts.crop, step, opts.batch, opts.thresh)
|