puruan/main.py
2026-02-04 09:54:24 +08:00

279 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
# 🔥 强制设置 HF 镜像 (必须放在最前面)
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import cv2
import numpy as np
import argparse
from PIL import Image
from torchvision import transforms
from diffusers import StableDiffusionPipeline
from tqdm import tqdm
# === Configuration ===
# Use SD 1.5: no auth token required, and more sensitive to fine textures in small patches
MODEL_ID = "runwayml/stable-diffusion-v1-5"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
THRESHOLD = 0.35  # fraction of the normalized heat value above which a region gets boxed
IMG_RESIZE = 224  # every patch is resized to this square size before encoding
# 🔥 核心DiffSim Pro 模型定义 (修复版)
# ==========================================
class DiffSimPro(nn.Module):
def __init__(self, device):
super().__init__()
print(f"🚀 [系统] 初始化 DiffSim Pro (基于 {MODEL_ID})...")
if device == "cuda":
print(f"✅ [硬件确认] 正在使用显卡: {torch.cuda.get_device_name(0)}")
else:
print("❌ [警告] 未检测到显卡,正在使用 CPU 慢速运行!")
# 1. 加载 SD 模型
self.pipe = StableDiffusionPipeline.from_pretrained(MODEL_ID, torch_dtype=torch.float16).to(device)
self.pipe.set_progress_bar_config(disable=True)
# 冻结参数
self.pipe.vae.requires_grad_(False)
self.pipe.unet.requires_grad_(False)
self.pipe.text_encoder.requires_grad_(False)
# 🔥【修复逻辑】:预先计算“空文本”的 Embedding
# UNet 必须要有这个 encoder_hidden_states 参数才能运行
with torch.no_grad():
prompt = ""
text_inputs = self.pipe.tokenizer(
prompt,
padding="max_length",
max_length=self.pipe.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids.to(device)
# 获取空文本特征 [1, 77, 768]
self.empty_text_embeds = self.pipe.text_encoder(text_input_ids)[0]
# 2. 定义特征容器和 Hooks
self.features = {}
# 注册 Hooks抓取 纹理(1)、结构(2)、语义(3)
for name, layer in self.pipe.unet.named_modules():
if "up_blocks.1" in name and name.endswith("resnets.2"):
layer.register_forward_hook(self.get_hook("feat_high"))
elif "up_blocks.2" in name and name.endswith("resnets.2"):
layer.register_forward_hook(self.get_hook("feat_mid"))
elif "up_blocks.3" in name and name.endswith("resnets.2"):
layer.register_forward_hook(self.get_hook("feat_low"))
def get_hook(self, name):
def hook(model, input, output):
self.features[name] = output
return hook
def extract_features(self, images):
""" VAE Encode -> UNet Forward -> Hook Features """
# 1. VAE 编码
latents = self.pipe.vae.encode(images).latent_dist.sample() * self.pipe.vae.config.scaling_factor
# 2. 准备参数
batch_size = latents.shape[0]
t = torch.zeros(batch_size, device=DEVICE, dtype=torch.long)
# 🔥【修复逻辑】:将空文本 Embedding 扩展到当前 Batch 大小
# 形状变为 [batch_size, 77, 768]
encoder_hidden_states = self.empty_text_embeds.expand(batch_size, -1, -1)
# 3. UNet 前向传播 (带上 encoder_hidden_states)
self.pipe.unet(latents, t, encoder_hidden_states=encoder_hidden_states)
return {k: v.clone() for k, v in self.features.items()}
def robust_similarity(self, f1, f2, kernel_size=3):
""" 抗视差匹配算法 """
f1 = F.normalize(f1, dim=1)
f2 = F.normalize(f2, dim=1)
padding = kernel_size // 2
b, c, h, w = f2.shape
f2_unfolded = F.unfold(f2, kernel_size=kernel_size, padding=padding)
f2_unfolded = f2_unfolded.view(b, c, kernel_size*kernel_size, h, w)
sim_map = (f1.unsqueeze(2) * f2_unfolded).sum(dim=1)
max_sim, _ = sim_map.max(dim=1)
return max_sim
def compute_batch_distance(self, batch_p1, batch_p2):
feat_a = self.extract_features(batch_p1)
feat_b = self.extract_features(batch_p2)
total_score = 0
# 权重:结构层(mid)最重要
weights = {"feat_high": 0.2, "feat_mid": 0.5, "feat_low": 0.3}
for name, w in weights.items():
fa, fb = feat_a[name].float(), feat_b[name].float()
if name == "feat_high":
sim_map = self.robust_similarity(fa, fb, kernel_size=3)
dist = 1 - sim_map.mean(dim=[1, 2])
else:
dist = 1 - F.cosine_similarity(fa.flatten(1), fb.flatten(1))
total_score += dist * w
return total_score
# ==========================================
# 🛠️ Helper functions & scan logic (unchanged)
# ==========================================
def get_transforms():
    """Build the patch preprocessing pipeline: resize to IMG_RESIZE square,
    convert to tensor, then normalize channels to roughly [-1, 1]."""
    steps = [
        transforms.Resize((IMG_RESIZE, IMG_RESIZE)),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
    return transforms.Compose(steps)
def scan_and_draw(model, t1_path, t2_path, output_path, patch_size, stride, batch_size):
    """Slide a window over two aligned images, score each patch pair with
    *model*, and save a blended heat map with boxes around high-change regions.

    model       -- a DiffSimPro instance (provides compute_batch_distance)
    t1_path     -- path to the baseline (before) image
    t2_path     -- path to the current (after) image
    output_path -- output filename; relative names are redirected (see below)
    patch_size  -- side length in pixels of each square sliding window
    stride      -- step in pixels between windows
    batch_size  -- number of patch pairs per inference batch
    """
    # 1. Read both images with OpenCV (BGR channel order).
    img1_cv = cv2.imread(t1_path)
    img2_cv = cv2.imread(t2_path)
    if img1_cv is None or img2_cv is None:
        print("❌ 错误: 无法读取图片")
        return
    # Force-resize the baseline image to the current image's size so the
    # sliding windows line up pixel-for-pixel.
    h, w = img2_cv.shape[:2]
    img1_cv = cv2.resize(img1_cv, (w, h))
    preprocess = get_transforms()
    # 2. Build the sliding-window patch lists.
    print(f"🔪 [切片] 开始扫描... 尺寸: {w}x{h}")
    print(f" - 切片大小: {patch_size}, 步长: {stride}, 批次: {batch_size}")
    patches1 = []
    patches2 = []
    coords = []
    # NOTE(review): windows that don't fit fully inside the image are skipped,
    # so a right/bottom margin narrower than patch_size is never scanned.
    for y in range(0, h - patch_size + 1, stride):
        for x in range(0, w - patch_size + 1, stride):
            crop1 = img1_cv[y:y+patch_size, x:x+patch_size]
            crop2 = img2_cv[y:y+patch_size, x:x+patch_size]
            # BGR -> RGB, then the torchvision resize/normalize pipeline.
            p1 = preprocess(Image.fromarray(cv2.cvtColor(crop1, cv2.COLOR_BGR2RGB)))
            p2 = preprocess(Image.fromarray(cv2.cvtColor(crop2, cv2.COLOR_BGR2RGB)))
            patches1.append(p1)
            patches2.append(p2)
            coords.append((x, y))
    if not patches1:
        print("⚠️ 图片太小,无法切片")
        return
    total_patches = len(patches1)
    print(f"🧠 [推理] 共 {total_patches} 个切片,开始 DiffSim Pro 计算...")
    all_distances = []
    # 3. Batched inference in fp16 with gradients disabled.
    for i in tqdm(range(0, total_patches, batch_size), unit="batch"):
        batch_p1 = torch.stack(patches1[i : i + batch_size]).to(DEVICE, dtype=torch.float16)
        batch_p2 = torch.stack(patches2[i : i + batch_size]).to(DEVICE, dtype=torch.float16)
        with torch.no_grad():
            dist_batch = model.compute_batch_distance(batch_p1, batch_p2)
        all_distances.append(dist_batch.cpu())
    distances = torch.cat(all_distances)
    # 4. Accumulate per-patch scores into a per-pixel heat map; overlapping
    # windows are averaged through count_map.
    heatmap = np.zeros((h, w), dtype=np.float32)
    count_map = np.zeros((h, w), dtype=np.float32)
    max_score = 0
    for idx, score in enumerate(distances):
        val = score.item()
        x, y = coords[idx]
        if val > max_score: max_score = val
        heatmap[y:y+patch_size, x:x+patch_size] += val
        count_map[y:y+patch_size, x:x+patch_size] += 1
    count_map[count_map == 0] = 1  # avoid division by zero on uncovered margins
    heatmap_avg = heatmap / count_map
    # 5. Normalize by the max score (floored at 0.1 so a near-zero maximum
    # doesn't blow up the scale), colorize, and blend over the current image.
    norm_factor = max(max_score, 0.1)
    heatmap_vis = (heatmap_avg / norm_factor * 255).clip(0, 255).astype(np.uint8)
    heatmap_color = cv2.applyColorMap(heatmap_vis, cv2.COLORMAP_JET)
    alpha = 0.4
    beta = 1.0 - alpha
    blended_img = cv2.addWeighted(img2_cv, alpha, heatmap_color, beta, 0)
    # Threshold the heat map and draw boxes around sufficiently large blobs.
    _, thresh = cv2.threshold(heatmap_vis, int(255 * THRESHOLD), 255, cv2.THRESH_BINARY)
    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    result_img = blended_img.copy()
    found_issue = False
    for cnt in contours:
        area = cv2.contourArea(cnt)
        # Ignore blobs smaller than 5% of one patch area: likely noise.
        min_area = (patch_size * patch_size) * 0.05
        if area > min_area:
            found_issue = True
            x, y, bw, bh = cv2.boundingRect(cnt)
            # White outer + red inner rectangle for contrast on any background.
            cv2.rectangle(result_img, (x, y), (x+bw, y+bh), (255, 255, 255), 4)
            cv2.rectangle(result_img, (x, y), (x+bw, y+bh), (0, 0, 255), 2)
            roi_score = heatmap_avg[y:y+bh, x:x+bw].mean()
            label = f"Diff: {roi_score:.2f}"
            # Filled red label background, then white score text above the box.
            cv2.rectangle(result_img, (x, y-25), (x+130, y), (0,0,255), -1)
            cv2.putText(result_img, label, (x+5, y-7), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255,255,255), 2)
    # Plain relative output names (not absolute, not "./..."-style) are
    # redirected under /app/data — presumably the container's mounted data
    # directory; NOTE(review): verify against the deployment setup.
    output_full_path = output_path
    if not os.path.isabs(output_path) and not output_path.startswith("."):
        output_full_path = os.path.join("/app/data", output_path)
    os.makedirs(os.path.dirname(output_full_path) if os.path.dirname(output_full_path) else ".", exist_ok=True)
    cv2.imwrite(output_full_path, result_img)
    print("="*40)
    print(f"🎯 扫描完成! 最大差异分: {max_score:.4f}")
    if found_issue:
        print(f"⚠️ 警告: 检测到潜在违建区域!")
    print(f"🖼️ 热力图结果已保存至: {output_full_path}")
    print("="*40)
if __name__ == "__main__":
    # Command line: two input images, optional output name and tiling parameters.
    cli = argparse.ArgumentParser(description="DiffSim Pro 违建检测 (抗视差版)")
    cli.add_argument("t1", help="基准图路径")
    cli.add_argument("t2", help="现状图路径")
    cli.add_argument("out", nargs="?", default="heatmap_diffsim.jpg", help="输出文件名")
    cli.add_argument("-c", "--crop", type=int, default=224, help="切片大小")
    cli.add_argument("-s", "--step", type=int, default=0, help="滑动步长")
    cli.add_argument("-b", "--batch", type=int, default=16, help="批次大小")
    opts = cli.parse_args()

    # A step of 0 (the default) means half the crop size: 50% window overlap.
    if opts.step > 0:
        stride = opts.step
    else:
        stride = opts.crop // 2

    # Build the model once, then run the sliding-window scan.
    diffsim_model = DiffSimPro(DEVICE)
    print(f"📂 启动热力图扫描: {opts.t1} vs {opts.t2}")
    scan_and_draw(diffsim_model, opts.t1, opts.t2, opts.out, opts.crop, stride, opts.batch)