puruan/main_plus.py
2026-02-04 09:54:24 +08:00

205 lines
8.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
# 🔥 强制设置 HF 镜像
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import cv2
import numpy as np
import argparse
from PIL import Image
from torchvision import transforms
from diffusers import StableDiffusionPipeline
from tqdm import tqdm
# === Configuration ===
# Hugging Face model id for the frozen Stable Diffusion 2.1 base pipeline.
MODEL_ID = "Manojb/stable-diffusion-2-1-base"
# Prefer GPU when available; everything below runs in float16 on this device.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
THRESHOLD = 0.40 # slightly raised threshold to further filter false positives
IMG_RESIZE = 224  # side length each patch is resized to before the pipeline
class DiffSimSemantic(nn.Module):
    """Semantic change scorer built on a frozen Stable Diffusion pipeline.

    Image patches are encoded by the VAE and pushed through the UNet once
    (timestep 0, empty prompt). Forward hooks capture two decoder feature
    maps — a "structure" layer and a "semantic" layer — and a
    parallax-tolerant cosine distance between two patches is computed on
    those features.
    """

    def __init__(self, device):
        super().__init__()
        print(f"🚀 [系统] 初始化 DiffSim (语义增强版)...")
        # Remember the target device so later tensor construction uses the
        # device this instance was built for, not the module-level DEVICE.
        self.device = device
        self.pipe = StableDiffusionPipeline.from_pretrained(MODEL_ID, torch_dtype=torch.float16).to(device)
        self.pipe.set_progress_bar_config(disable=True)
        # Inference only: freeze every sub-model of the pipeline.
        self.pipe.vae.requires_grad_(False)
        self.pipe.unet.requires_grad_(False)
        self.pipe.text_encoder.requires_grad_(False)
        # Pre-compute the empty-prompt text embedding once; it is expanded
        # over the batch for every UNet forward pass.
        with torch.no_grad():
            text_inputs = self.pipe.tokenizer(
                "",
                padding="max_length",
                max_length=self.pipe.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids.to(device)
            self.empty_text_embeds = self.pipe.text_encoder(text_input_ids)[0]
        self.features = {}
        # Hook only the deeper UNet decoder blocks; shallow texture features
        # are deliberately ignored:
        #   up_blocks.1 (texture)   -> removed: too sensitive to lighting/texture,
        #                              causes false positives
        #   up_blocks.2 (structure) -> kept: captures shape changes
        #   up_blocks.3 (semantic)  -> core: captures object-level identity
        for name, layer in self.pipe.unet.named_modules():
            if "up_blocks.2" in name and name.endswith("resnets.2"):
                layer.register_forward_hook(self.get_hook("feat_structure"))
            elif "up_blocks.3" in name and name.endswith("resnets.2"):
                layer.register_forward_hook(self.get_hook("feat_semantic"))

    def get_hook(self, name):
        """Return a forward hook that stores the layer output under *name*."""
        def hook(model, input, output):
            self.features[name] = output
        return hook

    def extract_features(self, images):
        """Run one VAE+UNet pass and return the hooked feature maps.

        *images* is a (B, 3, H, W) half-precision tensor already normalized
        to [-1, 1]. Returns a dict of cloned feature tensors keyed by hook
        name.
        """
        # Drop features from the previous call so a hook that fails to fire
        # cannot silently leak a stale tensor into this batch's result.
        self.features.clear()
        # NOTE(review): latent_dist.sample() is stochastic; .mode() would make
        # the metric deterministic — confirm sampling is intentional.
        latents = self.pipe.vae.encode(images).latent_dist.sample() * self.pipe.vae.config.scaling_factor
        batch_size = latents.shape[0]
        # Timestep 0: we only want clean-image features, no denoising.
        t = torch.zeros(batch_size, device=self.device, dtype=torch.long)
        encoder_hidden_states = self.empty_text_embeds.expand(batch_size, -1, -1)
        self.pipe.unet(latents, t, encoder_hidden_states=encoder_hidden_states)
        return {k: v.clone() for k, v in self.features.items()}

    def robust_similarity(self, f1, f2, kernel_size=5):
        """Parallax-tolerant cosine similarity between two feature maps.

        For each spatial position of *f1*, the best cosine match within a
        kernel_size x kernel_size neighbourhood of *f2* is taken, so small
        geometric misalignments (e.g. parallax at ~30 m altitude) do not
        register as change. kernel_size defaults to 5 to tolerate larger
        displacements. Returns a (B, H, W) map of best similarities.
        """
        f1 = F.normalize(f1, dim=1)
        f2 = F.normalize(f2, dim=1)
        padding = kernel_size // 2
        b, c, h, w = f2.shape
        # Unfold gathers, per spatial position, all k*k neighbours of f2.
        f2_unfolded = F.unfold(f2, kernel_size=kernel_size, padding=padding)
        f2_unfolded = f2_unfolded.view(b, c, kernel_size * kernel_size, h, w)
        # Cosine similarity against every neighbour, then keep the best match.
        sim_map = (f1.unsqueeze(2) * f2_unfolded).sum(dim=1)
        max_sim, _ = sim_map.max(dim=1)
        return max_sim

    def compute_batch_distance(self, batch_p1, batch_p2):
        """Weighted structure+semantic distance for a batch of patch pairs.

        Returns a (B,) tensor; larger values indicate more change.
        """
        feat_a = self.extract_features(batch_p1)
        feat_b = self.extract_features(batch_p2)
        total_score = 0
        # Weighting strategy: texture is ignored entirely (0.0 — drops
        # color/shadow noise); structure 0.4 tracks shape changes; semantics
        # 0.6 tracks object presence (the most DreamSim-like component).
        weights = {"feat_structure": 0.4, "feat_semantic": 0.6}
        for name, w in weights.items():
            fa, fb = feat_a[name].float(), feat_b[name].float()
            # Parallax-tolerant matching on every layer used; kernel_size=5
            # tolerates larger pixel displacements.
            sim_map = self.robust_similarity(fa, fb, kernel_size=5)
            dist = 1 - sim_map.mean(dim=[1, 2])
            total_score += dist * w
        return total_score
# ==========================================
# Helper functions (unchanged)
# ==========================================
def get_transforms():
    """Build the patch preprocessing pipeline.

    Resizes to IMG_RESIZE x IMG_RESIZE, converts to a tensor, and rescales
    pixel values from [0, 1] to [-1, 1] (the range the VAE expects).
    """
    steps = [
        transforms.Resize((IMG_RESIZE, IMG_RESIZE)),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
    return transforms.Compose(steps)
def scan_and_draw(model, t1_path, t2_path, output_path, patch_size, stride, batch_size):
    """Slide a patch window over two aligned images, score each patch pair
    with *model*, and write a heat-map overlay with change regions boxed.

    t2 defines the reference geometry: t1 is resized to t2's dimensions
    before slicing. *output_path* is taken as-is when absolute, otherwise it
    is placed under /app/data.
    """
    img1_cv = cv2.imread(t1_path)
    img2_cv = cv2.imread(t2_path)
    if img1_cv is None or img2_cv is None:
        # Fail loudly instead of silently producing no output file.
        missing = t1_path if img1_cv is None else t2_path
        print(f"❌ [错误] 无法读取图像: {missing}")
        return
    h, w = img2_cv.shape[:2]
    # Align t1 onto t2's pixel grid so each patch pair covers the same area.
    img1_cv = cv2.resize(img1_cv, (w, h))
    preprocess = get_transforms()
    print(f"🔪 [切片] 开始扫描... 尺寸: {w}x{h}, 忽略纹理细节,专注语义差异")
    # Collect every patch pair and its top-left coordinate.
    patches1, patches2, coords = [], [], []
    for y in range(0, h - patch_size + 1, stride):
        for x in range(0, w - patch_size + 1, stride):
            crop1 = img1_cv[y:y+patch_size, x:x+patch_size]
            crop2 = img2_cv[y:y+patch_size, x:x+patch_size]
            p1 = preprocess(Image.fromarray(cv2.cvtColor(crop1, cv2.COLOR_BGR2RGB)))
            p2 = preprocess(Image.fromarray(cv2.cvtColor(crop2, cv2.COLOR_BGR2RGB)))
            patches1.append(p1)
            patches2.append(p2)
            coords.append((x, y))
    if not patches1:
        # Image smaller than one patch: nothing to scan — say so explicitly.
        print(f"⚠️ [警告] 图像尺寸 {w}x{h} 小于 patch_size={patch_size},未生成任何切片")
        return
    # Score the patch pairs in batches (half precision, no gradients).
    all_distances = []
    for i in tqdm(range(0, len(patches1), batch_size), unit="batch"):
        b1 = torch.stack(patches1[i:i+batch_size]).to(DEVICE, dtype=torch.float16)
        b2 = torch.stack(patches2[i:i+batch_size]).to(DEVICE, dtype=torch.float16)
        with torch.no_grad():
            all_distances.append(model.compute_batch_distance(b1, b2).cpu())
    distances = torch.cat(all_distances)
    # Accumulate per-patch scores into an averaged full-resolution heat map;
    # overlapping patches are averaged via count_map.
    heatmap = np.zeros((h, w), dtype=np.float32)
    count_map = np.zeros((h, w), dtype=np.float32)
    max_score = 0.0
    for idx, score in enumerate(distances):
        val = score.item()
        x, y = coords[idx]
        max_score = max(max_score, val)
        heatmap[y:y+patch_size, x:x+patch_size] += val
        count_map[y:y+patch_size, x:x+patch_size] += 1
    count_map[count_map == 0] = 1  # avoid division by zero in uncovered border
    heatmap_avg = heatmap / count_map
    # Visualization: normalize by the observed maximum, floored at 0.1 so an
    # all-quiet scene does not get amplified to full intensity.
    norm_factor = max(max_score, 0.1)
    heatmap_vis = (heatmap_avg / norm_factor * 255).clip(0, 255).astype(np.uint8)
    heatmap_color = cv2.applyColorMap(heatmap_vis, cv2.COLORMAP_JET)
    blended_img = cv2.addWeighted(img2_cv, 0.4, heatmap_color, 0.6, 0)
    # Threshold + contours -> bounding boxes around confident change regions.
    _, thresh = cv2.threshold(heatmap_vis, int(255 * THRESHOLD), 255, cv2.THRESH_BINARY)
    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    result_img = blended_img.copy()
    for cnt in contours:
        # Drop specks smaller than 5% of one patch's area.
        if cv2.contourArea(cnt) > (patch_size ** 2) * 0.05:
            x, y, bw, bh = cv2.boundingRect(cnt)
            # White outer + red inner rectangle keeps the box visible on any background.
            cv2.rectangle(result_img, (x, y), (x+bw, y+bh), (255, 255, 255), 4)
            cv2.rectangle(result_img, (x, y), (x+bw, y+bh), (0, 0, 255), 2)
            label = f"Diff: {heatmap_avg[y:y+bh, x:x+bw].mean():.2f}"
            cv2.rectangle(result_img, (x, y-25), (x+130, y), (0, 0, 255), -1)
            cv2.putText(result_img, label, (x+5, y-7), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
    # Relative outputs land in the container's data directory.
    output_full_path = output_path if os.path.isabs(output_path) else os.path.join("/app/data", output_path)
    os.makedirs(os.path.dirname(output_full_path), exist_ok=True)
    cv2.imwrite(output_full_path, result_img)
    print(f"🎯 完成! 最大差异分: {max_score:.4f}, 结果已保存: {output_full_path}")
if __name__ == "__main__":
    # CLI: t1 t2 [out] plus patch-size / stride / batch options.
    parser = argparse.ArgumentParser()
    parser.add_argument("t1")
    parser.add_argument("t2")
    parser.add_argument("out", nargs="?", default="result.jpg")
    parser.add_argument("-c", "--crop", type=int, default=224)
    parser.add_argument("-s", "--step", type=int, default=0)
    parser.add_argument("-b", "--batch", type=int, default=16)
    args = parser.parse_args()
    # step <= 0 means "auto": half-overlapping windows (stride = crop / 2).
    stride = args.step if args.step > 0 else args.crop // 2
    scan_and_draw(DiffSimSemantic(DEVICE), args.t1, args.t2, args.out, args.crop, stride, args.batch)