import os

# Force the Hugging Face mirror endpoint. This must run before diffusers /
# huggingface_hub are imported, because huggingface_hub reads HF_ENDPOINT
# when it is first imported.
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

import argparse
import sys

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
from diffusers import StableDiffusionPipeline
from tqdm import tqdm

# === Configuration ===
MODEL_ID = "Manojb/stable-diffusion-2-1-base"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Box threshold as a fraction of the heatmap normalization factor
# (raised slightly to filter out more false positives).
THRESHOLD = 0.40
IMG_RESIZE = 224  # Side length every patch is resized to before VAE encoding.
OUTPUT_DIR = "/app/data"  # Base directory used for relative output paths.


class DiffSimSemantic(nn.Module):
    """Perceptual distance between image patches using Stable Diffusion UNet features.

    Patches are encoded with the pipeline's VAE, passed once through the UNet
    at timestep 0 conditioned on an empty prompt, and the outputs of two
    up-block ResNet layers (captured with forward hooks) are compared with a
    shift-tolerant cosine similarity.
    """

    # UNet submodule name -> key under which its output is stored in ``features``.
    # NOTE(review): in diffusers' UNet2DConditionModel the up_blocks run from the
    # lowest to the highest spatial resolution, so up_blocks.3 is the finest
    # block; the "structure"/"semantic" labels are the author's naming, not a
    # verified property. up_blocks.1 is deliberately not hooked: the author
    # found it too sensitive to lighting/texture, producing false positives.
    HOOKED_LAYERS = {
        "up_blocks.2.resnets.2": "feat_structure",
        "up_blocks.3.resnets.2": "feat_semantic",
    }

    def __init__(self, device):
        """Load the pipeline on ``device``, freeze it and install feature hooks.

        Raises:
            AttributeError: if the UNet has no submodule named in HOOKED_LAYERS.
        """
        super().__init__()
        print("🚀 [系统] 初始化 DiffSim (语义增强版)...")
        self.device = torch.device(device)
        # fp16 kernels are largely unavailable on CPU; only use half precision on CUDA.
        self.dtype = torch.float16 if self.device.type == "cuda" else torch.float32
        self.pipe = StableDiffusionPipeline.from_pretrained(MODEL_ID, torch_dtype=self.dtype).to(self.device)
        self.pipe.set_progress_bar_config(disable=True)

        # Inference only: freeze every component that gets called.
        self.pipe.vae.requires_grad_(False)
        self.pipe.unet.requires_grad_(False)
        self.pipe.text_encoder.requires_grad_(False)

        # Pre-compute the empty-prompt text embedding once; it is reused as the
        # cross-attention context for every batch.
        with torch.no_grad():
            text_inputs = self.pipe.tokenizer(
                "",
                padding="max_length",
                max_length=self.pipe.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            self.empty_text_embeds = self.pipe.text_encoder(text_inputs.input_ids.to(self.device))[0]

        # Most recent output of each hooked layer; overwritten on every UNet forward.
        self.features = {}
        for module_name, feature_key in self.HOOKED_LAYERS.items():
            # get_submodule fails fast if the UNet layout differs, instead of a
            # KeyError surfacing later in compute_batch_distance.
            layer = self.pipe.unet.get_submodule(module_name)
            layer.register_forward_hook(self.get_hook(feature_key))

    def get_hook(self, name):
        """Return a forward hook that stores the module's output in ``self.features[name]``."""
        def hook(module, inputs, output):
            self.features[name] = output
        return hook

    @torch.no_grad()
    def extract_features(self, images):
        """Run ``images`` through the VAE and UNet and return copies of the hooked features.

        Args:
            images: image batch normalized to [-1, 1] (as produced by
                ``get_transforms``); assumed (B, 3, H, W) already on the model's
                device and dtype -- TODO confirm for callers other than scan_and_draw.

        Returns:
            dict mapping each feature key from HOOKED_LAYERS to a cloned tensor.
        """
        # Use the posterior mode instead of .sample(): sampling injects
        # independent noise into each image, so identical patches would not
        # yield identical features and scores would be non-deterministic.
        latent_dist = self.pipe.vae.encode(images).latent_dist
        latents = latent_dist.mode() * self.pipe.vae.config.scaling_factor
        batch_size = latents.shape[0]
        t = torch.zeros(batch_size, device=latents.device, dtype=torch.long)
        encoder_hidden_states = self.empty_text_embeds.expand(batch_size, -1, -1)
        self.pipe.unet(latents, t, encoder_hidden_states=encoder_hidden_states)
        # Clone: the hooks overwrite self.features on the next forward pass.
        return {k: v.clone() for k, v in self.features.items()}

    def robust_similarity(self, f1, f2, kernel_size=5):
        """Shift-tolerant cosine similarity between two feature maps.

        For every spatial location of ``f1``, take the best cosine similarity
        against ``f2`` within a ``kernel_size`` x ``kernel_size`` neighbourhood
        (zero-padded at the borders). The default of 5 tolerates larger
        geometric misalignment, such as parallax from a ~30 m capture height.

        Args:
            f1, f2: feature maps of identical shape (B, C, H, W).
            kernel_size: odd neighbourhood size.

        Returns:
            (B, H, W) tensor of per-location maximum similarities.

        Raises:
            ValueError: if ``kernel_size`` is not a positive odd integer
                (an even size would change the unfolded spatial size).
        """
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(f"kernel_size must be a positive odd integer, got {kernel_size}")
        f1 = F.normalize(f1, dim=1)
        f2 = F.normalize(f2, dim=1)
        b, c, h, w = f2.shape
        # (B, C*k*k, H*W) -> (B, C, k*k, H, W): all neighbours of every location.
        neighbours = F.unfold(f2, kernel_size=kernel_size, padding=kernel_size // 2)
        neighbours = neighbours.view(b, c, kernel_size * kernel_size, h, w)
        sim_map = (f1.unsqueeze(2) * neighbours).sum(dim=1)
        return sim_map.max(dim=1).values

    def compute_batch_distance(self, batch_p1, batch_p2):
        """Weighted feature distance for each pair in two aligned patch batches.

        Returns:
            (B,) tensor where each hooked layer contributes
            ``weight * (1 - mean shift-tolerant similarity)``.
        """
        feat_a = self.extract_features(batch_p1)
        feat_b = self.extract_features(batch_p2)

        # Weighting focuses purely on structure and semantics:
        #   texture (up_blocks.1): 0.0 -- not hooked at all (ignores shading/colour)
        #   feat_structure:        0.4 -- intended to capture shape changes
        #   feat_semantic:         0.6 -- intended to capture object presence
        weights = {"feat_structure": 0.4, "feat_semantic": 0.6}
        total_score = 0
        for name, weight in weights.items():
            fa, fb = feat_a[name].float(), feat_b[name].float()
            # Shift-tolerant matching on every layer for robustness to misalignment.
            sim_map = self.robust_similarity(fa, fb, kernel_size=5)
            total_score = total_score + weight * (1 - sim_map.mean(dim=[1, 2]))
        return total_score


# ==========================================
# Helper functions
# ==========================================
def get_transforms():
    """Resize to IMG_RESIZE x IMG_RESIZE, convert to a tensor and map [0, 1] to [-1, 1]."""
    return transforms.Compose([
        transforms.Resize((IMG_RESIZE, IMG_RESIZE)),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])


def _crop_to_tensor(img_bgr, x, y, patch_size, preprocess):
    """Crop the square BGR patch at (x, y) and return it as a preprocessed RGB tensor."""
    crop = img_bgr[y:y + patch_size, x:x + patch_size]
    return preprocess(Image.fromarray(cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)))


def _draw_detections(base_img, heatmap_vis, heatmap_avg, patch_size):
    """Return a copy of ``base_img`` with boxes and score labels around regions above THRESHOLD."""
    _, thresh = cv2.threshold(heatmap_vis, int(255 * THRESHOLD), 255, cv2.THRESH_BINARY)
    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    result_img = base_img.copy()
    min_area = (patch_size ** 2) * 0.05  # Ignore specks smaller than 5% of a patch.
    for cnt in contours:
        if cv2.contourArea(cnt) <= min_area:
            continue
        x, y, bw, bh = cv2.boundingRect(cnt)
        # White outline under a red one so the box stands out on any background.
        cv2.rectangle(result_img, (x, y), (x + bw, y + bh), (255, 255, 255), 4)
        cv2.rectangle(result_img, (x, y), (x + bw, y + bh), (0, 0, 255), 2)
        label = f"Diff: {heatmap_avg[y:y + bh, x:x + bw].mean():.2f}"
        # The label normally sits above the box; for boxes near the top edge it
        # is moved inside the box so it is not clipped off the image.
        label_bottom = y if y >= 25 else y + 25
        cv2.rectangle(result_img, (x, label_bottom - 25), (x + 130, label_bottom), (0, 0, 255), -1)
        cv2.putText(result_img, label, (x + 5, label_bottom - 7),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
    return result_img


def scan_and_draw(model, t1_path, t2_path, output_path, patch_size, stride, batch_size):
    """Compare two images patch-by-patch and save an annotated change heatmap.

    The second image defines the working resolution; the first is resized to
    match. Overlapping patch scores are averaged per pixel, colour-mapped,
    blended onto the second image, and regions above THRESHOLD are boxed.

    Args:
        model: object exposing ``compute_batch_distance(b1, b2)`` (e.g.
            DiffSimSemantic); its optional ``device``/``dtype`` attributes are
            used for the input batches.
        t1_path, t2_path: paths of the earlier and later images.
        output_path: result path; relative paths are placed under OUTPUT_DIR.
        patch_size: square patch side in pixels.
        stride: step between patch origins in pixels.
        batch_size: number of patch pairs scored per model call.

    Returns:
        None. Unreadable inputs, images smaller than one patch and write
        failures are reported on stderr.

    Raises:
        ValueError: if patch_size, stride or batch_size is not positive.
    """
    if patch_size <= 0 or stride <= 0 or batch_size <= 0:
        raise ValueError(
            f"patch_size, stride and batch_size must be positive, got {patch_size}, {stride}, {batch_size}"
        )

    img1_cv = cv2.imread(t1_path)
    img2_cv = cv2.imread(t2_path)
    if img1_cv is None or img2_cv is None:
        for path, img in ((t1_path, img1_cv), (t2_path, img2_cv)):
            if img is None:
                print(f"❌ Cannot read image: {path}", file=sys.stderr)
        return

    h, w = img2_cv.shape[:2]
    img1_cv = cv2.resize(img1_cv, (w, h))
    preprocess = get_transforms()
    print(f"🔪 [切片] 开始扫描... 尺寸: {w}x{h}, 忽略纹理细节,专注语义差异")

    # Top-left corners of every full patch, row-major (y outer, x inner).
    coords = [
        (x, y)
        for y in range(0, h - patch_size + 1, stride)
        for x in range(0, w - patch_size + 1, stride)
    ]
    if not coords:
        print(f"❌ Image {w}x{h} is smaller than patch size {patch_size}", file=sys.stderr)
        return

    device = getattr(model, "device", DEVICE)
    dtype = getattr(model, "dtype", torch.float16)

    # Crop and preprocess per batch so peak memory does not grow with image size.
    all_distances = []
    with torch.no_grad():
        for start in tqdm(range(0, len(coords), batch_size), unit="batch"):
            chunk = coords[start:start + batch_size]
            b1 = torch.stack([_crop_to_tensor(img1_cv, x, y, patch_size, preprocess) for x, y in chunk])
            b2 = torch.stack([_crop_to_tensor(img2_cv, x, y, patch_size, preprocess) for x, y in chunk])
            dist = model.compute_batch_distance(b1.to(device, dtype=dtype), b2.to(device, dtype=dtype))
            all_distances.append(dist.cpu())
    distances = torch.cat(all_distances).tolist()

    # Average overlapping patch scores into a per-pixel heatmap.
    heatmap = np.zeros((h, w), dtype=np.float32)
    count_map = np.zeros((h, w), dtype=np.float32)
    for (x, y), val in zip(coords, distances):
        heatmap[y:y + patch_size, x:x + patch_size] += val
        count_map[y:y + patch_size, x:x + patch_size] += 1
    count_map[count_map == 0] = 1  # Pixels covered by no patch keep a score of 0.
    heatmap_avg = heatmap / count_map
    max_score = max(0.0, max(distances))

    # Visualization: normalize by the max patch score, floored at 0.1 so that
    # nearly identical image pairs are not blown up to full intensity.
    norm_factor = max(max_score, 0.1)
    heatmap_vis = (heatmap_avg / norm_factor * 255).clip(0, 255).astype(np.uint8)
    heatmap_color = cv2.applyColorMap(heatmap_vis, cv2.COLORMAP_JET)
    blended_img = cv2.addWeighted(img2_cv, 0.4, heatmap_color, 0.6, 0)
    result_img = _draw_detections(blended_img, heatmap_vis, heatmap_avg, patch_size)

    output_full_path = output_path if os.path.isabs(output_path) else os.path.join(OUTPUT_DIR, output_path)
    os.makedirs(os.path.dirname(output_full_path), exist_ok=True)
    if not cv2.imwrite(output_full_path, result_img):
        print(f"❌ Failed to write result: {output_full_path}", file=sys.stderr)
        return
    print(f"🎯 完成! 最大差异分: {max_score:.4f}, 结果已保存: {output_full_path}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("t1")
    parser.add_argument("t2")
    parser.add_argument("out", nargs="?", default="result.jpg")
    parser.add_argument("-c", "--crop", type=int, default=224)
    parser.add_argument("-s", "--step", type=int, default=0)
    parser.add_argument("-b", "--batch", type=int, default=16)
    args = parser.parse_args()
    # Default stride is half a patch; never let it reach 0 for tiny crops.
    step = args.step if args.step > 0 else max(1, args.crop // 2)
    scan_and_draw(DiffSimSemantic(DEVICE), args.t1, args.t2, args.out, args.crop, step, args.batch)