"""DiffSim Pro — sliding-window change detection between two aligned images,
scored with frozen Stable Diffusion 1.5 UNet features and rendered as a
heatmap overlay with boxes around high-difference regions."""
import os

# 🔥 Force the Hugging Face mirror endpoint.
# Must run before any huggingface/diffusers import reads HF_ENDPOINT.
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import cv2
import numpy as np
import argparse
from PIL import Image
from torchvision import transforms
from diffusers import StableDiffusionPipeline
from tqdm import tqdm
||
# === Configuration ===
# Use SD 1.5: no auth token required, and more sensitive to small-patch textures.
MODEL_ID = "runwayml/stable-diffusion-v1-5"
# Prefer GPU; the CPU fallback works but is very slow.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Fraction of the normalized heat (0-1) above which a region gets boxed.
THRESHOLD = 0.35
# Side length (px) each crop is resized to before entering the VAE.
IMG_RESIZE = 224
# ==========================================
# 🔥 Core: DiffSim Pro model definition (fixed version)
# ==========================================
class DiffSimPro(nn.Module):
    """Perceptual difference model built on a frozen Stable Diffusion 1.5.

    Image batches are VAE-encoded and pushed through the UNet once;
    forward hooks capture decoder activations at three depths
    (up_blocks.1/2/3 -> "feat_high"/"feat_mid"/"feat_low"), which are then
    compared between two batches to produce a per-sample distance score.
    """

    def __init__(self, device):
        """Load the SD pipeline on `device` ("cuda" or "cpu"), freeze it,
        precompute the empty-prompt embedding, and register feature hooks."""
        super().__init__()
        print(f"🚀 [系统] 初始化 DiffSim Pro (基于 {MODEL_ID})...")

        if device == "cuda":
            print(f"✅ [硬件确认] 正在使用显卡: {torch.cuda.get_device_name(0)}")
        else:
            print("❌ [警告] 未检测到显卡,正在使用 CPU 慢速运行!")

        # 1. Load the SD pipeline in half precision on the target device.
        self.pipe = StableDiffusionPipeline.from_pretrained(MODEL_ID, torch_dtype=torch.float16).to(device)
        self.pipe.set_progress_bar_config(disable=True)

        # Freeze all weights — this model only extracts features.
        self.pipe.vae.requires_grad_(False)
        self.pipe.unet.requires_grad_(False)
        self.pipe.text_encoder.requires_grad_(False)

        # 🔥 [Fix]: precompute the embedding of the empty prompt once;
        # the UNet requires an encoder_hidden_states argument to run at all.
        with torch.no_grad():
            prompt = ""
            text_inputs = self.pipe.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.pipe.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids.to(device)
            # Empty-text features, shape [1, 77, 768]
            self.empty_text_embeds = self.pipe.text_encoder(text_input_ids)[0]

        # 2. Feature container, filled as a side effect by the hooks below.
        self.features = {}

        # Register hooks: capture texture (1), structure (2), semantics (3).
        # Each hook attaches to the last resnet of its up_block.
        for name, layer in self.pipe.unet.named_modules():
            if "up_blocks.1" in name and name.endswith("resnets.2"):
                layer.register_forward_hook(self.get_hook("feat_high"))
            elif "up_blocks.2" in name and name.endswith("resnets.2"):
                layer.register_forward_hook(self.get_hook("feat_mid"))
            elif "up_blocks.3" in name and name.endswith("resnets.2"):
                layer.register_forward_hook(self.get_hook("feat_low"))

    def get_hook(self, name):
        """Return a forward hook that stores the layer's output under `name`."""
        def hook(model, input, output):
            self.features[name] = output
        return hook

    def extract_features(self, images):
        """ VAE Encode -> UNet Forward -> Hook Features """
        # 1. VAE-encode to latents, scaled by the SD scaling factor.
        # NOTE(review): latent_dist.sample() is stochastic, so repeated calls on
        # the same input yield slightly different features — confirm the noise
        # is intended (latent_dist.mode() would be deterministic).
        latents = self.pipe.vae.encode(images).latent_dist.sample() * self.pipe.vae.config.scaling_factor

        # 2. Timestep 0 for every sample (no diffusion noise is added).
        batch_size = latents.shape[0]
        t = torch.zeros(batch_size, device=DEVICE, dtype=torch.long)

        # 🔥 [Fix]: broadcast the precomputed empty-text embedding to the
        # current batch size; shape becomes [batch_size, 77, 768].
        encoder_hidden_states = self.empty_text_embeds.expand(batch_size, -1, -1)

        # 3. UNet forward pass (with encoder_hidden_states). The output is
        # discarded — the forward hooks fill self.features as a side effect.
        self.pipe.unet(latents, t, encoder_hidden_states=encoder_hidden_states)

        # Clone so a later forward pass cannot overwrite the returned tensors.
        return {k: v.clone() for k, v in self.features.items()}

    def robust_similarity(self, f1, f2, kernel_size=3):
        """ Parallax-tolerant matching algorithm """
        # Channel-wise L2 normalization turns dot products into cosine similarity.
        f1 = F.normalize(f1, dim=1)
        f2 = F.normalize(f2, dim=1)

        padding = kernel_size // 2
        b, c, h, w = f2.shape

        # Unfold f2 into its kernel_size^2 shifted neighborhoods: [b, c, k*k, h, w].
        f2_unfolded = F.unfold(f2, kernel_size=kernel_size, padding=padding)
        f2_unfolded = f2_unfolded.view(b, c, kernel_size*kernel_size, h, w)

        # Cosine similarity of each f1 location against every position in its
        # k*k neighborhood of f2; keeping the best local match tolerates small
        # spatial misalignment (parallax) between the two images.
        sim_map = (f1.unsqueeze(2) * f2_unfolded).sum(dim=1)
        max_sim, _ = sim_map.max(dim=1)

        return max_sim

    def compute_batch_distance(self, batch_p1, batch_p2):
        """Return a per-sample weighted difference score for two patch batches."""
        feat_a = self.extract_features(batch_p1)
        feat_b = self.extract_features(batch_p2)

        total_score = 0
        # Weights: the structural level (mid) matters most.
        weights = {"feat_high": 0.2, "feat_mid": 0.5, "feat_low": 0.3}

        for name, w in weights.items():
            # fp16 hook outputs -> fp32 for stable normalization/cosine math.
            fa, fb = feat_a[name].float(), feat_b[name].float()

            if name == "feat_high":
                # Deepest features get the parallax-tolerant local matching.
                sim_map = self.robust_similarity(fa, fb, kernel_size=3)
                dist = 1 - sim_map.mean(dim=[1, 2])
            else:
                # Shallower levels: plain global cosine distance.
                dist = 1 - F.cosine_similarity(fa.flatten(1), fb.flatten(1))

            total_score += dist * w

        return total_score
# ==========================================
# 🛠️ Helper functions & scan logic (unchanged)
# ==========================================
def get_transforms():
    """Build the patch preprocessing pipeline: resize to IMG_RESIZE, convert
    to a tensor, then normalize to roughly [-1, 1] as the SD VAE expects."""
    steps = [
        transforms.Resize((IMG_RESIZE, IMG_RESIZE)),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
    return transforms.Compose(steps)
def scan_and_draw(model, t1_path, t2_path, output_path, patch_size, stride, batch_size):
    """Slide a window over two aligned images, score each patch pair with
    DiffSim Pro, and write a heatmap overlay with boxes on high-diff regions.

    Args:
        model: initialized DiffSimPro instance.
        t1_path: baseline image path (anything cv2.imread accepts).
        t2_path: current image path; t1 is resized to match t2.
        output_path: output file name; bare relative names go under /app/data.
        patch_size: square crop side length in pixels.
        stride: sliding-window step in pixels.
        batch_size: number of patch pairs per inference batch.
    """
    # 1. Read with OpenCV (BGR order).
    img1_cv = cv2.imread(t1_path)
    img2_cv = cv2.imread(t2_path)

    if img1_cv is None or img2_cv is None:
        print("❌ 错误: 无法读取图片")
        return

    # Force-resize t1 to t2's size so windows line up pixel-for-pixel.
    h, w = img2_cv.shape[:2]
    img1_cv = cv2.resize(img1_cv, (w, h))

    preprocess = get_transforms()

    # 2. Prepare the sliding windows.
    print(f"🔪 [切片] 开始扫描... 尺寸: {w}x{h}")
    print(f" - 切片大小: {patch_size}, 步长: {stride}, 批次: {batch_size}")

    patches1 = []
    patches2 = []
    coords = []  # top-left (x, y) of each patch, parallel to patches1/2

    # NOTE: right/bottom margins narrower than patch_size are never scanned.
    for y in range(0, h - patch_size + 1, stride):
        for x in range(0, w - patch_size + 1, stride):
            crop1 = img1_cv[y:y+patch_size, x:x+patch_size]
            crop2 = img2_cv[y:y+patch_size, x:x+patch_size]

            # BGR -> RGB, then resize/normalize for the VAE.
            p1 = preprocess(Image.fromarray(cv2.cvtColor(crop1, cv2.COLOR_BGR2RGB)))
            p2 = preprocess(Image.fromarray(cv2.cvtColor(crop2, cv2.COLOR_BGR2RGB)))

            patches1.append(p1)
            patches2.append(p2)
            coords.append((x, y))

    if not patches1:
        print("⚠️ 图片太小,无法切片")
        return

    total_patches = len(patches1)
    print(f"🧠 [推理] 共 {total_patches} 个切片,开始 DiffSim Pro 计算...")

    all_distances = []

    # 3. Batched inference (fp16 to match the pipeline's weights).
    for i in tqdm(range(0, total_patches, batch_size), unit="batch"):
        batch_p1 = torch.stack(patches1[i : i + batch_size]).to(DEVICE, dtype=torch.float16)
        batch_p2 = torch.stack(patches2[i : i + batch_size]).to(DEVICE, dtype=torch.float16)

        with torch.no_grad():
            dist_batch = model.compute_batch_distance(batch_p1, batch_p2)
            all_distances.append(dist_batch.cpu())

    distances = torch.cat(all_distances)

    # 4. Accumulate raw heat data; overlapping windows are averaged via count_map.
    heatmap = np.zeros((h, w), dtype=np.float32)
    count_map = np.zeros((h, w), dtype=np.float32)
    max_score = 0

    for idx, score in enumerate(distances):
        val = score.item()
        x, y = coords[idx]
        if val > max_score: max_score = val

        heatmap[y:y+patch_size, x:x+patch_size] += val
        count_map[y:y+patch_size, x:x+patch_size] += 1

    # Avoid division by zero in the unscanned margins.
    count_map[count_map == 0] = 1
    heatmap_avg = heatmap / count_map

    # 5. Post-processing: normalize by the observed max (floored at 0.1 so a
    # near-identical pair doesn't get blown up to full intensity).
    norm_factor = max(max_score, 0.1)
    heatmap_vis = (heatmap_avg / norm_factor * 255).clip(0, 255).astype(np.uint8)
    heatmap_color = cv2.applyColorMap(heatmap_vis, cv2.COLORMAP_JET)

    # Blend heatmap over the current image (40% image / 60% heat).
    alpha = 0.4
    beta = 1.0 - alpha
    blended_img = cv2.addWeighted(img2_cv, alpha, heatmap_color, beta, 0)

    # Draw boxes around regions above THRESHOLD of the normalized heat.
    _, thresh = cv2.threshold(heatmap_vis, int(255 * THRESHOLD), 255, cv2.THRESH_BINARY)
    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    result_img = blended_img.copy()
    found_issue = False

    for cnt in contours:
        area = cv2.contourArea(cnt)
        # Ignore specks smaller than 5% of one patch's area.
        min_area = (patch_size * patch_size) * 0.05

        if area > min_area:
            found_issue = True
            x, y, bw, bh = cv2.boundingRect(cnt)

            # White outer + red inner rectangle for visibility on any background.
            cv2.rectangle(result_img, (x, y), (x+bw, y+bh), (255, 255, 255), 4)
            cv2.rectangle(result_img, (x, y), (x+bw, y+bh), (0, 0, 255), 2)

            # Mean un-normalized difference inside the box, shown as the label.
            roi_score = heatmap_avg[y:y+bh, x:x+bw].mean()
            label = f"Diff: {roi_score:.2f}"

            # Filled red label background above the box, then white text.
            cv2.rectangle(result_img, (x, y-25), (x+130, y), (0,0,255), -1)
            cv2.putText(result_img, label, (x+5, y-7), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255,255,255), 2)

    # Bare relative names are redirected under /app/data (container convention);
    # absolute paths and "./"-style paths are written where given.
    output_full_path = output_path
    if not os.path.isabs(output_path) and not output_path.startswith("."):
        output_full_path = os.path.join("/app/data", output_path)
        os.makedirs(os.path.dirname(output_full_path) if os.path.dirname(output_full_path) else ".", exist_ok=True)

    cv2.imwrite(output_full_path, result_img)

    print("="*40)
    print(f"🎯 扫描完成! 最大差异分: {max_score:.4f}")
    if found_issue:
        print(f"⚠️ 警告: 检测到潜在违建区域!")
    print(f"🖼️ 热力图结果已保存至: {output_full_path}")
    print("="*40)
if __name__ == "__main__":
    # CLI: t1 (baseline), t2 (current), optional output name and window params.
    parser = argparse.ArgumentParser(description="DiffSim Pro 违建检测 (抗视差版)")
    parser.add_argument("t1", help="基准图路径")
    parser.add_argument("t2", help="现状图路径")
    parser.add_argument("out", nargs="?", default="heatmap_diffsim.jpg", help="输出文件名")
    parser.add_argument("-c", "--crop", type=int, default=224, help="切片大小")
    parser.add_argument("-s", "--step", type=int, default=0, help="滑动步长")
    parser.add_argument("-b", "--batch", type=int, default=16, help="批次大小")

    args = parser.parse_args()
    # Default stride: half the crop size (50% overlap) when --step is 0/unset.
    stride = args.step if args.step > 0 else args.crop // 2

    # Initialize the model (downloads SD 1.5 via the HF mirror on first run).
    diffsim_model = DiffSimPro(DEVICE)

    print(f"📂 启动热力图扫描: {args.t1} vs {args.t2}")
    scan_and_draw(diffsim_model, args.t1, args.t2, args.out, args.crop, stride, args.batch)