294 lines
11 KiB
Python
294 lines
11 KiB
Python
import os
|
|
# 🚀 Force Hugging Face downloads to go through the mainland-China mirror
|
|
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import cv2
|
|
import numpy as np
|
|
import argparse
|
|
from PIL import Image
|
|
from torchvision import transforms
|
|
from diffusers import StableDiffusionPipeline
|
|
from tqdm import tqdm
|
|
|
|
# =========================================================================
# PART 1: Reconstruction of the official DiffSim core logic
# Based on: https://github.com/showlab/DiffSim/blob/main/diffsim/models/diffsim.py
# =========================================================================
|
|
|
|
class DiffSim(nn.Module):
    """Feature-space distance between image batches using a frozen Stable Diffusion UNet.

    Two intermediate UNet activations are captured with forward hooks
    (``up_blocks.1.resnets.1`` labelled "semantic" and ``up_blocks.2.resnets.1``
    labelled "structure"). The distance between two batches is the weighted sum
    of ``1 - similarity`` over those feature maps, where the similarity is a
    cosine similarity that tolerates small spatial shifts.
    """

    def __init__(self, model_id="Manojb/stable-diffusion-2-1-base", device="cuda"):
        """Load the pipeline, freeze it, cache the empty-prompt embedding, register hooks.

        Args:
            model_id: Hugging Face model id passed to ``from_pretrained``.
            device: Device the pipeline is moved to; also used for the timestep tensor.
        """
        super().__init__()
        self.device = device
        print(f"🚀 [Core] Loading Official DiffSim Logic (Backbone: {model_id})...")

        # 1. Load the SD pipeline in fp16. Loading is best-effort: on any failure
        #    the error is printed and the stabilityai checkpoint is tried instead.
        try:
            self.pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to(device)
        except Exception as e:
            print(f"❌ 模型加载失败,尝试加载默认 ID... Error: {e}")
            self.pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base", torch_dtype=torch.float16).to(device)

        self.pipe.set_progress_bar_config(disable=True)

        # 2. Freeze every sub-model; this class is used for inference only.
        self.pipe.vae.requires_grad_(False)
        self.pipe.unet.requires_grad_(False)
        self.pipe.text_encoder.requires_grad_(False)

        # 3. Pre-compute the embedding of the empty prompt once; it is the only
        #    text conditioning ever fed to the UNet.
        with torch.no_grad():
            prompt = ""
            text_input = self.pipe.tokenizer(prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt")
            self.empty_embeds = self.pipe.text_encoder(text_input.input_ids.to(device))[0]

        # Filled by the forward hooks on every UNet call: alias -> activation.
        self.features = {}
        self._register_official_hooks()

    def _register_official_hooks(self):
        """Attach forward hooks to the two UNet layers used as feature sources."""
        self.target_layers = {
            "up_blocks.1.resnets.1": "feat_semantic",  # semantic layer
            "up_blocks.2.resnets.1": "feat_structure",  # structure layer
        }

        print(f"🔧 [Hook] Registered Layers: {list(self.target_layers.values())}")

        for name, layer in self.pipe.unet.named_modules():
            if name in self.target_layers:
                alias = self.target_layers[name]
                layer.register_forward_hook(self._get_hook(alias))

    def _get_hook(self, name):
        """Return a forward hook that stores the layer output in ``self.features[name]``."""
        def hook(module, inputs, output):
            self.features[name] = output
        return hook

    def extract_features(self, images):
        """Run ``images`` through the VAE encoder and UNet and return the hooked activations.

        Args:
            images: Image batch accepted by ``self.pipe.vae.encode``.
                NOTE(review): presumably (B, 3, H, W) in [-1, 1] with the
                pipeline's dtype, as prepared by ``main`` — confirm for other callers.

        Returns:
            dict mapping hook alias ("feat_semantic", "feat_structure") to a
            cloned activation tensor. Aliases whose layer was not found in the
            UNet are absent.
        """
        # Use the posterior mode instead of a random sample so the features (and
        # therefore the distance) are deterministic: the same image always maps
        # to the same latent, and identical inputs yield identical features.
        posterior = self.pipe.vae.encode(images).latent_dist
        latents = posterior.mode() * self.pipe.vae.config.scaling_factor
        batch = latents.shape[0]

        # Single UNet pass at timestep 0 with the empty-prompt conditioning.
        t = torch.zeros(batch, device=self.device, dtype=torch.long)
        encoder_hidden_states = self.empty_embeds.expand(batch, -1, -1)

        self.features = {}  # reset the hook buffer before the pass
        self.pipe.unet(latents, t, encoder_hidden_states=encoder_hidden_states)

        # Clone so later passes cannot alias the returned tensors.
        return {k: v.clone() for k, v in self.features.items()}

    def calculate_robust_similarity(self, feat_a, feat_b, kernel_size=3):
        """Spatially robust cosine similarity between two feature maps.

        For every position p the result is
        ``max_{q in Neighbor(p)} cos(feat_a(p), feat_b(q))`` where the
        neighbourhood is a ``kernel_size`` x ``kernel_size`` window centred on p.

        Args:
            feat_a: Tensor of shape (B, C, H, W).
            feat_b: Tensor of the same shape as ``feat_a``.
            kernel_size: Window size. Values <= 1 give the strict per-pixel
                cosine similarity; larger values must be odd.

        Returns:
            Tensor of shape (B, H, W) with the best similarity per position.

        Raises:
            ValueError: if ``kernel_size`` is even and greater than 1. With
                padding ``kernel_size // 2`` an even window does not produce one
                window per pixel, so no centred neighbourhood exists.
        """
        if kernel_size > 1 and kernel_size % 2 == 0:
            raise ValueError(f"kernel_size must be odd (1, 3, 5, ...), got {kernel_size}")

        # Unit-normalise along channels so a dot product is a cosine similarity.
        feat_a = F.normalize(feat_a, dim=1)
        feat_b = F.normalize(feat_b, dim=1)

        if kernel_size <= 1:
            # Strict alignment: pixel-wise cosine similarity.
            return (feat_a * feat_b).sum(dim=1)

        # Neighbourhood search (sliding-window matching).
        b, c, h, w = feat_b.shape
        padding = kernel_size // 2

        # Gather, for every position, the k*k neighbours of feat_b.
        # NOTE: unfold pads with zeros, so out-of-image neighbours contribute a
        # similarity of exactly 0 to the maximum at border positions.
        neighbours = F.unfold(feat_b, kernel_size=kernel_size, padding=padding)
        neighbours = neighbours.view(b, c, kernel_size * kernel_size, h, w)

        # Cosine similarity between A and every neighbour of B: (B, K*K, H, W).
        sim_map = (feat_a.unsqueeze(2) * neighbours).sum(dim=1)

        # Keep the best match within each window.
        best_sim, _ = sim_map.max(dim=1)
        return best_sim

    def forward(self, batch_t1, batch_t2, w_struct, w_sem, kernel_size):
        """Weighted feature distance between two image batches.

        Args:
            batch_t1: Reference image batch (see ``extract_features``).
            batch_t2: Query image batch, same shape as ``batch_t1``.
            w_struct: Weight of the structure-layer distance; ignored when <= 0.
            w_sem: Weight of the semantic-layer distance; ignored when <= 0.
            kernel_size: Window size forwarded to ``calculate_robust_similarity``.

        Returns:
            Tensor of shape (B,) with one distance per sample. It is all zeros
            when no term is active (both weights <= 0 or no layer was hooked),
            so callers can always treat the result as a tensor.
        """
        f1 = self.extract_features(batch_t1)
        f2 = self.extract_features(batch_t2)

        # Start from a real tensor rather than the int 0 so the return type does
        # not depend on which terms are enabled.
        total_dist = torch.zeros(batch_t1.shape[0], device=batch_t1.device, dtype=batch_t1.dtype)

        # Semantic distance first, then structure distance.
        for alias, weight in (("feat_semantic", w_sem), ("feat_structure", w_struct)):
            if weight > 0 and alias in f1:
                sim = self.calculate_robust_similarity(f1[alias], f2[alias], kernel_size)
                dist = 1.0 - sim
                total_dist = total_dist + dist.mean(dim=[1, 2]) * weight

        return total_dist
|
|
|
|
# =========================================================================
# PART 2: Enhanced post-processing logic (Post-Processing)
# This part is not in the official DiffSim library; it is a denoising stage
# added for practical engineering deployment.
# =========================================================================
|
|
|
|
def engineering_post_process(heatmap_full, img_bg, args, patch_size):
    """Denoise a distance heatmap, find changed regions and draw them on ``img_bg``.

    Steps: dynamic-range normalisation, Gaussian blur, hard threshold,
    morphological closing, contour extraction with an area filter, drawing.

    Args:
        heatmap_full: 2-D float array of per-pixel distances.
        img_bg: Background image blended with the colour heatmap; it is passed
            to ``cv2.addWeighted`` together with the colour-mapped heatmap.
            NOTE(review): presumably a BGR uint8 image with the same height and
            width as ``heatmap_full`` — confirm against the caller.
        args: Object with a ``thresh`` attribute (fraction of 255 used as the
            binarisation threshold).
        patch_size: Side length of an analysis patch; regions whose contour area
            is not above 3% of ``patch_size ** 2`` are discarded.

    Returns:
        ``(result_img, box_count)``: the annotated image and the number of boxes drawn.

    Side effects:
        Writes the normalised heatmap to ``debug_raw_heatmap.png`` in the
        current working directory.
    """
    label_h, label_w = 22, 55  # size of the filled score-label background

    # 1. Dynamic-range normalisation. The divisor is floored at 0.25 so that a
    #    nearly unchanged scene (tiny maximum) is not stretched into visible noise.
    local_max = heatmap_full.max()
    safe_max = max(local_max, 0.25)
    heatmap_norm = (heatmap_full / safe_max * 255).clip(0, 255).astype(np.uint8)

    # Keep the raw normalised map for debugging.
    cv2.imwrite("debug_raw_heatmap.png", heatmap_norm)

    # 2. Gaussian blur to remove speckle.
    heatmap_blur = cv2.GaussianBlur(heatmap_norm, (5, 5), 0)

    # 3. Hard threshold.
    _, binary = cv2.threshold(heatmap_blur, int(255 * args.thresh), 255, cv2.THRESH_BINARY)

    # 4. Morphological closing merges fragmented neighbouring regions.
    kernel_morph = cv2.getStructuringElement(cv2.MORPH_RECT, (7, 7))
    binary_closed = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, kernel_morph)

    # 5. Visualisation: blend the colour heatmap over the background.
    heatmap_color = cv2.applyColorMap(heatmap_norm, cv2.COLORMAP_JET)
    result_img = cv2.addWeighted(img_bg, 0.4, heatmap_color, 0.6, 0)

    # findContours returns (contours, hierarchy) on OpenCV 4 and
    # (image, contours, hierarchy) on OpenCV 3; [-2] is the contour list in both.
    contours = cv2.findContours(binary_closed, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]

    box_count = 0
    # Area filter: ignore blobs smaller than 3% of one patch.
    min_area = (patch_size ** 2) * 0.03

    for cnt in contours:
        if cv2.contourArea(cnt) <= min_area:
            continue

        box_count += 1
        x, y, bw, bh = cv2.boundingRect(cnt)

        # White outline under a red box so it stays visible on any background.
        cv2.rectangle(result_img, (x, y), (x + bw, y + bh), (255, 255, 255), 4)
        cv2.rectangle(result_img, (x, y), (x + bw, y + bh), (0, 0, 255), 2)

        # Score label: mean raw distance inside the bounding box.
        score_val = heatmap_full[y:y + bh, x:x + bw].mean()
        label = f"{score_val:.2f}"

        # The label normally sits just above the box. When the box is too close
        # to the top edge that position is off-image, so draw it inside the box.
        label_top = y - label_h if y >= label_h else y
        cv2.rectangle(result_img, (x, label_top), (x + label_w, label_top + label_h), (0, 0, 255), -1)
        cv2.putText(result_img, label, (x + 5, label_top + 16), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)

    return result_img, box_count
|
|
|
|
# =========================================================================
# PART 3: Execution script
# =========================================================================
|
|
|
|
def _tile_origins(length, crop, stride):
    """Return window start offsets so crop-sized windows cover all of [0, length)."""
    last = length - crop
    if last < 0:
        return []
    origins = list(range(0, last + 1, stride))
    # The regular grid stops short when (length - crop) is not a multiple of the
    # stride; add one final window flush with the edge so no strip is skipped.
    if origins[-1] != last:
        origins.append(last)
    return origins


def main():
    """Command-line entry point: compare two images patch-wise and save an annotated result.

    The query image (``t2``) defines the working resolution; the reference
    (``t1``) is resized to match. Both are cut into overlapping square patches,
    each pair is scored with ``DiffSim``, the scores are averaged back into a
    full-resolution heatmap and ``engineering_post_process`` draws the result.
    Errors are reported with a printed message followed by an early return.
    """
    parser = argparse.ArgumentParser(description="DiffSim Local Implementation")
    parser.add_argument("t1", help="Reference Image")
    parser.add_argument("t2", help="Query Image")
    # nargs="?" makes the positional optional so the declared default is reachable.
    parser.add_argument("out", nargs="?", default="result.jpg")

    # DiffSim parameters (described upstream as the officially recommended values).
    parser.add_argument("--w_struct", type=float, default=0.4)
    parser.add_argument("--w_sem", type=float, default=0.6)
    parser.add_argument("--kernel", type=int, default=3, help="Robust Kernel Size (1, 3, 5)")

    # Engineering parameters.
    parser.add_argument("--gamma", type=float, default=1.0)
    parser.add_argument("--thresh", type=float, default=0.3)
    parser.add_argument("-c", "--crop", type=int, default=224)
    parser.add_argument("-b", "--batch", type=int, default=16)
    parser.add_argument("--model", default="Manojb/stable-diffusion-2-1-base")
    parser.add_argument("--device", default="cuda", help="Torch device used for the model and the patches")

    # Accepted for command-line compatibility; not used below.
    parser.add_argument("--step", type=int, default=0)
    parser.add_argument("--w_tex", type=float, default=0.0)

    args = parser.parse_args()
    if args.crop <= 0:
        parser.error("--crop must be a positive integer")
    if args.batch <= 0:
        parser.error("--batch must be a positive integer")

    # 1. Image IO
    t1 = cv2.imread(args.t1)
    t2 = cv2.imread(args.t2)
    if t1 is None or t2 is None:
        print("❌ Error reading images.")
        return

    # Resize the reference to match the query image.
    h, w = t2.shape[:2]
    t1 = cv2.resize(t1, (w, h))

    # 2. Preprocessing. Every patch is resized to 224x224 regardless of --crop
    #    and normalised with mean 0.5 / std 0.5.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])

    stride = max(1, args.crop // 2)  # 50% overlap

    print(f"🔪 Slicing images ({w}x{h}) with stride {stride}...")
    xs = _tile_origins(w, args.crop, stride)
    coords = [(x, y) for y in _tile_origins(h, args.crop, stride) for x in xs]
    if not coords:
        print(f"❌ Image ({w}x{h}) is smaller than the crop size {args.crop}; nothing to compare.")
        return

    def to_patch(img_bgr, x, y):
        # OpenCV loads BGR; PIL / torchvision expect RGB.
        tile = img_bgr[y:y + args.crop, x:x + args.crop]
        return transform(Image.fromarray(cv2.cvtColor(tile, cv2.COLOR_BGR2RGB)))

    patches1 = [to_patch(t1, x, y) for x, y in coords]
    patches2 = [to_patch(t2, x, y) for x, y in coords]

    # 3. Model inference
    model = DiffSim(args.model, device=args.device)
    scores = []

    print(f"🧠 Running DiffSim Inference on {len(patches1)} patches...")
    with torch.no_grad():
        for i in tqdm(range(0, len(patches1), args.batch)):
            b1 = torch.stack(patches1[i:i + args.batch]).to(args.device, dtype=torch.float16)
            b2 = torch.stack(patches2[i:i + args.batch]).to(args.device, dtype=torch.float16)

            batch_dist = model(b1, b2, args.w_struct, args.w_sem, args.kernel)
            scores.append(batch_dist.cpu())

    all_scores = torch.cat(scores).float().numpy()

    # 4. Reconstruct the heatmap by averaging the scores of overlapping patches.
    heatmap_full = np.zeros((h, w), dtype=np.float32)
    count_map = np.zeros((h, w), dtype=np.float32) + 1e-6  # avoids division by zero

    # Apply gamma *before* merging. Half-precision rounding can leave a distance
    # marginally below zero, and a fractional power of a negative number is NaN,
    # so clamp at zero first.
    if args.gamma != 1.0:
        all_scores = np.power(np.clip(all_scores, 0.0, None), args.gamma)

    for (x, y), score in zip(coords, all_scores):
        heatmap_full[y:y + args.crop, x:x + args.crop] += score
        count_map[y:y + args.crop, x:x + args.crop] += 1

    heatmap_avg = heatmap_full / count_map

    # 5. Post-processing and drawing
    print("🎨 Post-processing results...")
    final_img, count = engineering_post_process(heatmap_avg, t2, args, args.crop)

    # cv2.imwrite reports failure through its return value instead of raising.
    if not cv2.imwrite(args.out, final_img):
        print(f"❌ Failed to write output image: {args.out}")
        return
    print(f"✅ Done! Found {count} regions. Saved to {args.out}")
|
|
# Script entry point: run the CLI only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|