Upload some test files

This commit is contained in:
fuyn101 2025-05-23 19:36:44 +08:00
parent cee44523c6
commit 44782a3117
10 changed files with 912 additions and 689 deletions

BIN
depot_test/output/matches_knn_only.png (Stored with Git LFS) Normal file

Binary file not shown.

BIN
depot_test/stitched_image_multi.png (Stored with Git LFS) Normal file

Binary file not shown.

View file

@@ -0,0 +1,324 @@
import glob
import os
import re  # used to parse labels from file names
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms
from torchvision.datasets import ImageFolder
from tqdm import tqdm  # progress bars

# --- Configuration ---
TRAIN_DIR = "训练集"
VAL_TEST_DIR = "测试集"  # per the setup, validation and test share the same directory and file format
IMAGE_SIZE = 224  # standard MobileNetV3 input size
BATCH_SIZE = 32
NUM_EPOCHS = 20  # adjust as needed
LEARNING_RATE = 0.001  # initial learning rate
SAVE_MODEL_PATH = "mobilenetv3_small_finetuned.pth"
# Pick the device automatically
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# --- Data preprocessing and augmentation ---
# ImageNet mean and standard deviation
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
# Training-set augmentation and preprocessing
train_transforms = transforms.Compose(
    [
        transforms.RandomResizedCrop(IMAGE_SIZE),  # random crop and resize
        transforms.RandomHorizontalFlip(),  # random horizontal flip
        transforms.ToTensor(),  # convert to tensor
        transforms.Normalize(mean, std),  # normalize
    ]
)
# Validation/test preprocessing (no augmentation; just resize, center crop, normalize)
val_test_transforms = transforms.Compose(
    [
        transforms.Resize(int(IMAGE_SIZE * 256 / 224)),  # resize to a slightly larger size
        transforms.CenterCrop(IMAGE_SIZE),  # center-crop to the target size
        transforms.ToTensor(),  # convert to tensor
        transforms.Normalize(mean, std),  # normalize
    ]
)
# --- Custom validation/test Dataset ---
# A custom Dataset is needed for the "测试集/{idx}_{label}.png" file-naming scheme
class CustomValTestDataset(Dataset):
    def __init__(self, root_dir, class_to_idx, transform=None):
        """
        Args:
            root_dir (string): dataset root directory (e.g., "测试集").
            class_to_idx (dict): class-name-to-index mapping, consistent with the training set.
            transform (callable, optional): transform applied to each image.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.class_to_idx = class_to_idx
        self.idx_to_class = {v: k for k, v in class_to_idx.items()}
        self.image_files = []
        self.labels = []
        # Collect every PNG file in the directory
        filepaths = glob.glob(os.path.join(root_dir, "*.png"))
        # Parse the label out of each file name
        pattern = re.compile(r"^\d+_([^_]+)\.png$")  # matches index_label.png
        for filepath in filepaths:
            filename = os.path.basename(filepath)
            match = pattern.match(filename)
            if match:
                label_name = match.group(1)
                if label_name in self.class_to_idx:
                    self.image_files.append(filepath)
                    self.labels.append(self.class_to_idx[label_name])
                else:
                    print(
                        f"Warning: label '{label_name}' in file '{filename}' is not a training-set class; skipping."
                    )
        print(f"Loaded {len(self.image_files)} validation/test images.")

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        img_path = self.image_files[idx]
        label = self.labels[idx]
        # Open the image and force RGB (handles possible grayscale inputs)
        img = Image.open(img_path).convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img, label
# --- Load the data ---
# ImageFolder loads the training set and derives classes from directory names
if not os.path.exists(TRAIN_DIR):
    print(
        f"Error: training directory '{TRAIN_DIR}' does not exist. Create it and add class subdirectories of images."
    )
    exit()
if not os.path.exists(VAL_TEST_DIR):
    print(f"Error: validation/test directory '{VAL_TEST_DIR}' does not exist. Create it and add images.")
    exit()
train_dataset = ImageFolder(TRAIN_DIR, transform=train_transforms)
num_classes = len(train_dataset.classes)
class_to_idx = train_dataset.class_to_idx  # class-to-index mapping
print(f"Detected {num_classes} classes in the training set: {train_dataset.classes}")
# Load the validation/test set with the custom Dataset
val_test_dataset = CustomValTestDataset(
    VAL_TEST_DIR, class_to_idx, transform=val_test_transforms
)
# Create the DataLoaders
train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True
)  # tune num_workers to your machine
val_test_loader = DataLoader(val_test_dataset, batch_size=BATCH_SIZE, shuffle=False)
# --- Load a pretrained MobileNetV3-Small ---
# The weights argument selects the pretrained weights;
# MobileNet_V3_Small_Weights.IMAGENET1K_V1 was trained on ImageNet
try:
    model = models.mobilenet_v3_small(
        weights=models.MobileNet_V3_Small_Weights.IMAGENET1K_V1
    )
    print("Loaded pretrained MobileNetV3-Small (ImageNet weights).")
except Exception as e:
    print(f"Failed to load pretrained weights: {e}")
    print("Falling back to an uninitialized model...")
    model = models.mobilenet_v3_small(weights=None)
# --- Replace the final fully connected layer to match the class count ---
# MobileNetV3's classifier is model.classifier; the last linear layer is classifier[-1]
num_ftrs = model.classifier[-1].in_features
# Swap in a new fully connected layer
model.classifier[-1] = nn.Linear(num_ftrs, num_classes)
model = model.to(device)
# --- Loss function and optimizer ---
criterion = nn.CrossEntropyLoss()  # cross-entropy suits classification
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)  # Adam optimizer
# Optional: a learning-rate scheduler
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)  # decay the LR every 7 epochs
# --- Training and evaluation functions ---
def train_epoch(model, train_loader, criterion, optimizer, device):
    model.train()  # training mode
    running_loss = 0.0
    correct_predictions = 0
    total_samples = 0
    # tqdm shows a progress bar
    for inputs, labels in tqdm(train_loader, desc="Training"):
        inputs, labels = inputs.to(device), labels.to(device)
        # Zero the gradients
        optimizer.zero_grad()
        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        # Backward pass and optimization
        loss.backward()
        optimizer.step()
        # Statistics
        running_loss += loss.item() * inputs.size(0)  # accumulate batch loss * batch size
        _, predicted = torch.max(outputs, 1)  # predictions
        correct_predictions += (predicted == labels).sum().item()
        total_samples += labels.size(0)
    epoch_loss = running_loss / total_samples
    epoch_accuracy = correct_predictions / total_samples
    return epoch_loss, epoch_accuracy


def evaluate(model, data_loader, criterion, device, desc="Evaluating"):
    model.eval()  # evaluation mode
    running_loss = 0.0
    correct_predictions = 0
    total_samples = 0
    # No gradients during evaluation
    with torch.no_grad():
        # tqdm shows a progress bar
        for inputs, labels in tqdm(data_loader, desc=desc):
            inputs, labels = inputs.to(device), labels.to(device)
            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            # Statistics
            running_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs, 1)
            correct_predictions += (predicted == labels).sum().item()
            total_samples += labels.size(0)
    epoch_loss = running_loss / total_samples
    epoch_accuracy = correct_predictions / total_samples
    return epoch_loss, epoch_accuracy
# --- Training loop ---
best_val_accuracy = 0.0
print("\nStarting training...")
for epoch in range(NUM_EPOCHS):
    print(f"\n--- Epoch {epoch + 1}/{NUM_EPOCHS} ---")
    # Training phase
    train_loss, train_accuracy = train_epoch(
        model, train_loader, criterion, optimizer, device
    )
    print(
        f"Epoch {epoch + 1} train loss: {train_loss:.4f}, accuracy: {train_accuracy:.4f}"
    )
    # Optional: learning-rate scheduling
    # if scheduler is not None:
    #     scheduler.step()
    # Validation/test phase
    val_loss, val_accuracy = evaluate(
        model, val_test_loader, criterion, device, desc="Validating/testing"
    )
    print(
        f"Epoch {epoch + 1} val/test loss: {val_loss:.4f}, accuracy: {val_accuracy:.4f}"
    )
    # Save the best model so far
    if val_accuracy > best_val_accuracy:
        best_val_accuracy = val_accuracy
        torch.save(model.state_dict(), SAVE_MODEL_PATH)
        print(
            f"Saved the best model on the validation/test set (accuracy: {best_val_accuracy:.4f}) to {SAVE_MODEL_PATH}"
        )
print("\nTraining finished!")
print(f"Best accuracy on the validation/test set: {best_val_accuracy:.4f}")
# --- Optional: load and test the best model ---
# Reload the saved best model for a final evaluation (since validation and test share one set, this is the final result)
print(f"\nLoading the best model '{SAVE_MODEL_PATH}' for final evaluation...")
loaded_model = models.mobilenet_v3_small(weights=None)  # build an empty model structure first
loaded_model.classifier[-1] = nn.Linear(
    loaded_model.classifier[-1].in_features, num_classes
)  # adjust the classifier
loaded_model.load_state_dict(
    torch.load(SAVE_MODEL_PATH, map_location=device)
)  # load the weights
loaded_model = loaded_model.to(device)
final_test_loss, final_test_accuracy = evaluate(
    loaded_model, val_test_loader, criterion, device, desc="Final testing"
)
print(
    f"\nFinal test loss: {final_test_loss:.4f}, final test accuracy: {final_test_accuracy:.4f}"
)
# --- Optional: predict a single image ---
# For example, an image named '测试集/some_image_X_label.png'
def predict_single_image(image_path, model, class_to_idx, device, transform):
    model.eval()
    idx_to_class = {v: k for k, v in class_to_idx.items()}
    try:
        img = Image.open(image_path).convert("RGB")
        img = transform(img).unsqueeze(0).to(device)  # add a batch dimension and move to the device
        with torch.no_grad():
            outputs = model(img)
            probabilities = torch.softmax(outputs, dim=1)[0]  # probability distribution
            _, predicted_idx = torch.max(probabilities, 0)  # index of the highest probability
            predicted_label = idx_to_class[predicted_idx.item()]
            confidence = probabilities[predicted_idx].item()
            print(f"\nPredicted image: {image_path}")
            print(f"Predicted class: {predicted_label}, confidence: {confidence:.4f}")
            return predicted_label, confidence
    except FileNotFoundError:
        print(f"Error: image file '{image_path}' not found.")
        return None, None
    except Exception as e:
        print(f"Error while predicting the image: {e}")
        return None, None


# Example prediction (replace with an actual file from your test set)
example_image_path = r"测试集\27_基础作战记录.png"
predict_single_image(
    example_image_path, loaded_model, class_to_idx, device, val_test_transforms
)
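
This script fine-tunes every parameter; a common variant freezes the pretrained backbone so only the new classifier head trains. A minimal sketch, reusing the `model` and `LEARNING_RATE` defined above (an assumption about setup, not part of the script):

# Sketch: freeze the pretrained feature extractor and optimize only the new head.
for param in model.features.parameters():
    param.requires_grad = False  # backbone weights stay fixed
optimizer = optim.Adam(
    (p for p in model.parameters() if p.requires_grad), lr=LEARNING_RATE
)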

View file

@@ -0,0 +1,378 @@
import json
import lzma
import os
import pickle
import cv2
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image, ImageDraw, ImageFont, ImageOps
from sklearn.neighbors import KNeighborsClassifier
from torchvision import models

CROP_SIZE = 130
BORDER = 26
size = CROP_SIZE * 2 - BORDER * 2
# Build the feature extractor: MobileNetV3-Small up to, but excluding, the final classification layer
model = models.mobilenet_v3_small(weights="DEFAULT")
features_part = model.features
avgpool = torch.nn.AdaptiveAvgPool2d(1)
classifier_part_excluding_last = torch.nn.Sequential(
    *list(model.classifier.children())[:-1]
)
feature_extractor = torch.nn.Sequential(
    features_part,
    avgpool,
    torch.nn.Flatten(start_dim=1),
    classifier_part_excluding_last,
)
feature_extractor.eval()  # switch to evaluation mode
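# Note: for MobileNetV3-Small this pipeline maps each image to a
# 1024-dimensional vector (the output of the penultimate classifier layer),
# which is the feature representation the KNN below consumes.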
def 提取特征点(模板):
    """Extract features with MobileNetV3 (PyTorch version); expects a 3-channel BGR image array."""
    # Convert the input image from BGR to RGB
    img_rgb = cv2.cvtColor(模板, cv2.COLOR_BGR2RGB)
    # Image preprocessing pipeline
    preprocess = transforms.Compose(
        [
            transforms.ToPILImage(),  # convert to a PIL image
            transforms.Resize(250),  # resize the shorter side to 250
            transforms.ToTensor(),  # convert to tensor
            transforms.Normalize(  # normalize
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
            ),
        ]
    )
    # Preprocess the image
    img_tensor = preprocess(img_rgb)
    img_tensor = img_tensor.unsqueeze(0)  # add a batch dimension
    # Extract features
    with torch.no_grad():
        features = feature_extractor(img_tensor)
    # Flatten to a 1-D vector
    features = features.flatten().numpy()
    return features
class DepotMatcher:
    def __init__(
        self,
        ref_table_json="./ArknightsGameData/zh_CN/gamedata/excel/item_table.json",
        icon_dir="./ArknightsResource/items/",
        ref_dir="depot_test/output/test/origin",
        roi_dir="depot_test/output/test/result",
        img_path=r"depot_test\stitched_image_multi.png",
    ):
        # Path configuration
        self.REF_DIR = ref_dir
        self.ROI_DIR = roi_dir
        self.IMG_PATH = img_path
        self.REF_TABLE_JSON = ref_table_json
        self.ICON_DIR = icon_dir
        # Algorithm parameters
        self.HOUGH_PARAMS = dict(
            dp=5, minDist=230, param1=50, param2=30, minRadius=90, maxRadius=100
        )
        self.CROP_SIZE = 130
        self.BORDER = 26
        # Runtime data
        self.refs = None
        self.rois = []
        self.knn_results = []
        self.knn_model = None
    def load_references(self):
        """Load the item-icon reference images (kept in color)."""
        data = json.load(open(self.REF_TABLE_JSON, encoding="utf-8"))
        self.refs = {}
        size = self.CROP_SIZE * 2 - self.BORDER * 2
        # First collect every item that has a sortId
        items_with_sort = []
        for item in data.get("items", {}).values():
            if item.get("classifyType") not in {"NORMAL", "CONSUME", "MATERIAL"}:
                continue
            path = os.path.join(self.ICON_DIR, f"{item['iconId']}.png")
            if not os.path.exists(path):
                continue
            # Keep the color image
            im = Image.open(path).resize((size, size))
            items_with_sort.append(
                {
                    "name": item["name"],
                    "array": np.array(im),
                    "sortId": item.get("sortId", 0),
                }
            )
        # Sort by sortId
        items_with_sort.sort(key=lambda x: x["sortId"])
        # Build the final refs dict
        for item in items_with_sort:
            self.refs[item["name"]] = item["array"]
        print(f"Loaded {len(self.refs)} reference images (sorted by sortId)")
        # Save the training-set images
        os.makedirs("训练集", exist_ok=True)
        for name, array in self.refs.items():
            os.makedirs(f"训练集/{name}", exist_ok=True)
            path = os.path.join(f"训练集/{name}", f"{name}.png")
            im = Image.fromarray(array)
            cropped_im = im.crop((50, 30, 160, 140))
            cropped_im.save(path)
        return self
    def _process_circle(self, idx, circle, img):
        """Process one detected circle; returns the color image data."""
        x, y, r = circle
        # Crop a larger region containing the circle
        crop = img[
            max(0, y - self.CROP_SIZE) : min(img.shape[0], y + self.CROP_SIZE),
            max(0, x - self.CROP_SIZE) : min(img.shape[1], x + self.CROP_SIZE),
        ]
        # Extract the core color ROI
        color_roi = crop[self.BORDER : -self.BORDER, self.BORDER : -self.BORDER]
        # Color region used for matching
        color_sec = color_roi
        return idx, color_sec, color_roi

    def detect_and_crop(self):
        """Detect circular regions in the screenshot and crop them."""
        img = cv2.imread(self.IMG_PATH)
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        circles = cv2.HoughCircles(gray, cv2.HOUGH_GRADIENT, **self.HOUGH_PARAMS)
        if circles is None:
            print("No circular regions detected")
            return self
        # Process the detected circles
        circles = np.round(circles[0]).astype(int)
        circles = sorted(circles, key=lambda c: (c[0], c[1]))  # sort by x, then y
        self.rois = []
        for idx, circle in enumerate(circles):
            result = self._process_circle(idx, circle, img)
            self.rois.append(result)
        return self
    def 训练并保存knn模型(self, images, labels, filename):
        """Train a KNN model and save it."""
        knn_classifier = KNeighborsClassifier(
            weights="distance", n_neighbors=1, n_jobs=1
        )
        knn_classifier.fit(images, labels)
        with lzma.open(filename, "wb") as f:
            pickle.dump(knn_classifier, f)
        return knn_classifier

    def 训练knn模型(self, 模型保存路径="depot_knn_model.xz"):
        """Train and save the KNN model (using the color images)."""
        # Prepare the training data
        images = []
        labels = []
        for name, img_array in self.refs.items():
            features = 提取特征点(img_array)
            images.append(features)
            labels.append(name)
        # Train the model
        self.knn_model = self.训练并保存knn模型(images, labels, 模型保存路径)
        print(f"KNN model trained and saved to: {模型保存路径}")
        return self

    def 使用knn预测(self, 测试图像):
        features = 提取特征点(测试图像)
        # Predict
        预测结果 = self.knn_model.predict([features])
        return 预测结果[0]
    def match_items_knn_only(
        self,
        knn_model_path="depot_knn_model.xz",
    ):
        """Match items using KNN only."""
        self.knn_results = []
        newstart = None
        with lzma.open(knn_model_path, "rb") as f:
            self.knn_model = pickle.load(f)
        os.makedirs("测试集", exist_ok=True)
        for idx, color_sec_np, _ in self.rois:
            # KNN prediction (the feature extractor expects a 3-channel BGR image,
            # so the color crop is passed directly rather than a grayscale conversion)
            knn_name = self.使用knn预测(color_sec_np)
            self.knn_results.append((idx, knn_name))
            os.makedirs(f"测试集/{knn_name}", exist_ok=True)
            Image.fromarray(cv2.cvtColor(color_sec_np, cv2.COLOR_BGR2RGB)).crop(
                (50, 30, 160, 140)
            ).save(os.path.join(f"测试集/{knn_name}", f"{idx}_{knn_name}.png"))
            # Track the most recent prediction
            newstart = knn_name
            print(f"ROI {idx}: Hog+Knn={knn_name}, newstart={newstart}")
        return self
    def display_results(self):
        """Visualize the matching results."""
        ROW_LIMIT = 9
        # Use one reference image's size as the basis for the blank placeholder
        blank_ref_np = next(iter(self.refs.values()))
        blank_img_pil = Image.new(
            "RGB", (blank_ref_np.shape[1], blank_ref_np.shape[0]), (200, 200, 200)
        )
        combined_images = []
        current_row_images = []
        current_row_width = 0
        max_row_height = 0
        for idx, color_sec_np, color_roi_data in self.rois:
            color_roi_data = Image.fromarray(
                cv2.cvtColor(color_roi_data, cv2.COLOR_BGR2RGB)
            )
            color_sec_np = Image.fromarray(
                cv2.cvtColor(color_sec_np, cv2.COLOR_BGR2RGB)
            )
            # Fetch the KNN match result
            k_res_details = next(
                (d for d in getattr(self, "knn_results", []) if d[0] == idx), None
            )
            k_res_name = k_res_details[1] if k_res_details else None
            k_ref_img = (
                Image.fromarray(self.refs[k_res_name]).convert("RGB")
                if k_res_name and k_res_name in self.refs
                else blank_img_pil.copy()
            )
            # Compute the combined size
            combined_width = color_roi_data.width + color_sec_np.width + k_ref_img.width
            combined_height = max(
                color_roi_data.height,
                color_sec_np.height,
                k_ref_img.height,
            )
            # Create the combined image
            combined = Image.new(
                "RGB", (combined_width, combined_height), (255, 255, 255)
            )
            x_offset = 0
            # Paste each part
            combined.paste(color_roi_data, (x_offset, 0))
            x_offset += color_roi_data.width
            combined.paste(color_sec_np, (x_offset, 0))
            x_offset += color_sec_np.width
            combined.paste(k_ref_img, (x_offset, 0))
            x_offset += k_ref_img.width
            # Add the label
            draw = ImageDraw.Draw(combined)
            font = ImageFont.truetype("msyh.ttc", 16)
            label = f"ROI {idx}\nHog+Knn: {k_res_name or 'None'}"
            text_color = (0, 0, 0)
            draw.text(
                (color_roi_data.width, color_sec_np.height),
                label,
                fill=text_color,
                font=font,
            )
            # Add a border
            combined_bordered = ImageOps.expand(combined, border=2, fill=(0, 0, 0))
            current_row_images.append(combined_bordered)
            current_row_width += combined_bordered.width
            max_row_height = max(max_row_height, combined_bordered.height)
            # Start a new row when the limit is reached
            if len(current_row_images) == ROW_LIMIT:
                row_img = Image.new(
                    "RGB", (current_row_width, max_row_height), (255, 255, 255)
                )
                x = 0
                for img in current_row_images:
                    row_img.paste(img, (x, 0))
                    x += img.width
                combined_images.append(row_img)
                current_row_images = []
                current_row_width = 0
                max_row_height = 0
        # Handle the last row
        if current_row_images:
            row_img = Image.new(
                "RGB", (current_row_width, max_row_height), (255, 255, 255)
            )
            x = 0
            for img in current_row_images:
                row_img.paste(img, (x, 0))
                x += img.width
            combined_images.append(row_img)
        # Assemble the final image
        if combined_images:
            total_height = sum(img.height for img in combined_images)
            max_width = max(img.width for img in combined_images)
            final_img = Image.new("RGB", (max_width, total_height), (255, 255, 255))
            y = 0
            for img in combined_images:
                final_img.paste(img, (0, y))
                y += img.height
            output_path = "depot_test/output/matches_knn_only.png"
            os.makedirs(os.path.dirname(output_path), exist_ok=True)
            final_img.save(output_path)
            print(f"Result image saved to: {output_path}")
        return self
if __name__ == "__main__":
    # Usage example
    matcher = DepotMatcher()
    matcher.load_references()
    matcher.训练knn模型()
    from datetime import datetime

    now_time = datetime.now()
    matcher.detect_and_crop()
    matcher.match_items_knn_only()
    print(datetime.now() - now_time)
    matcher.display_results()
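
Once depot_knn_model.xz exists, the classifier can be reused without re-running detection. A minimal sketch, assuming this file is importable as `depot_matcher` (the module name is a hypothetical, since the diff does not show the file name) and that a cropped BGR item image is available:

import lzma
import pickle
import cv2
from depot_matcher import 提取特征点  # hypothetical module name

with lzma.open("depot_knn_model.xz", "rb") as f:  # path used by 训练knn模型 above
    knn = pickle.load(f)
crop = cv2.imread("item_crop.png")  # hypothetical cropped item image (BGR)
print(knn.predict([提取特征点(crop)])[0])  # predicted item name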

View file

@@ -1,188 +0,0 @@
import json
import os
from datetime import datetime
from multiprocessing import Pool
import cv2
import numpy as np
from PIL import Image, ImageDraw, ImageFont, ImageOps
from skimage.metrics import structural_similarity

now = datetime.now()
# Path configuration
REF_DIR = r"depot_test\output/test/origin"
ROI_DIR = r"depot_test\output/test/result"
IMG_PATH = r"depot_test\result_refined.png"
REF_TABLE_JSON = "./ArknightsGameData/zh_CN/gamedata/excel/item_table.json"
ICON_DIR = "./ArknightsResource/items/"
# SSIM threshold
SSIM_THRESHOLD = 0.01
# Circle-detection parameters
HOUGH_PARAMS = dict(
    dp=5, minDist=230, param1=50, param2=30, minRadius=90, maxRadius=100
)
CROP_SIZE = 130
BORDER = 26
SECONDARY_SLICE = (slice(30, 140), slice(50, 160))
def load_references(table_json, icon_dir):
    data = json.load(open(table_json, encoding="utf-8"))
    refs = {}
    size = CROP_SIZE * 2 - BORDER * 2
    for item in data.get("items", {}).values():
        t = item.get("classifyType")
        if t not in {"NORMAL", "CONSUME", "MATERIAL"}:
            continue
        path = os.path.join(icon_dir, f"{item['iconId']}.png")
        if not os.path.exists(path):
            continue
        im = Image.open(path).resize((size, size)).crop((50, 30, 160, 140)).convert("L")
        refs[item["name"]] = np.array(im)
    print(f"Loaded {len(refs)} reference images, stored under {REF_DIR}")
    return refs
def process_circle(idx, circle, img, rois, size, dr):
    x, y, r = circle
    crop = img[
        max(0, y - CROP_SIZE) : min(img.shape[0], y + CROP_SIZE),
        max(0, x - CROP_SIZE) : min(img.shape[1], x + CROP_SIZE),
    ]
    c = crop[BORDER:-BORDER, BORDER:-BORDER]
    sec = c[SECONDARY_SLICE[0], SECONDARY_SLICE[1]]
    gray_sec = cv2.cvtColor(sec, cv2.COLOR_BGR2GRAY)
    # Save the ROI image
    os.makedirs(ROI_DIR, exist_ok=True)
    roi_path = os.path.join(ROI_DIR, f"roi_{idx}.png")
    cv2.imwrite(roi_path, gray_sec)
    rois.append(gray_sec)  # note: has no effect in the parent process when run via Pool
    return idx, gray_sec, roi_path


def detect_and_crop(image_path):
    img = cv2.imread(image_path)
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    circles = cv2.HoughCircles(gray, cv2.HOUGH_GRADIENT, **HOUGH_PARAMS)
    if circles is None:
        print("No circular regions detected")
        return []
    circles = np.round(circles[0]).astype(int)
    rois = []
    size = CROP_SIZE * 2 - BORDER * 2
    dr = img.max() - img.min()
    # Use multiprocessing to speed things up
    with Pool() as pool:
        results = pool.starmap(
            process_circle,
            [(idx, circle, img, rois, size, dr) for idx, circle in enumerate(circles)],
        )
    return [roi for idx, roi, path in results]
def process_match(idx, roi, refs, thresh):
    best, score = "Unknown", -1
    dr = roi.max() - roi.min()
    for name, ref in refs.items():
        if roi.shape != ref.shape:
            continue
        s = structural_similarity(roi, ref, data_range=dr)
        if s > score:
            best, score = name, s
    if score >= thresh:
        return idx, best, score
    return idx, None, score


def match_ssim(rois, refs, thresh=SSIM_THRESHOLD):
    from multiprocessing import Pool

    args = [(idx, roi, refs, thresh) for idx, roi in enumerate(rois)]
    with Pool(processes=5) as pool:
        results = pool.starmap(process_match, args)
    stats = {}
    match_idx = {}
    for idx, name, score in results:
        if name:
            stats[name] = stats.get(name, 0) + 1
            match_idx[idx] = name
        print(f"ROI {idx} match: {name if name else 'Unknown'} (SSIM={score:.3f})")
    return stats, match_idx
def display_matches(rois, match_idx, refs):
    ROW_LIMIT = 10  # at most 10 per row
    blank_ref = next(iter(refs.values()))  # borrow one reference image's size
    blank_img = Image.new(
        "RGB", (blank_ref.shape[1], blank_ref.shape[0]), (200, 200, 200)
    )  # gray placeholder
    combined_images = []
    row_images = []
    row_width = 0
    max_height = 0
    for idx in range(len(rois)):
        roi_img = Image.fromarray(rois[idx]).convert("RGB")
        ref_name = match_idx.get(idx)
        if ref_name:
            ref_img = Image.fromarray(refs[ref_name]).convert("RGB")
        else:
            ref_img = blank_img.copy()
        combined_width = roi_img.width + ref_img.width
        combined_height = max(roi_img.height, ref_img.height)
        combined = Image.new("RGB", (combined_width, combined_height), (255, 255, 255))
        combined.paste(roi_img, (0, 0))
        combined.paste(ref_img, (roi_img.width, 0))
        draw = ImageDraw.Draw(combined)
        font = ImageFont.truetype("msyh.ttc", 20)
        label = f"ROI {idx}: {ref_name if ref_name else 'Unknown'}"
        draw.text((5, 5), label, fill=(255, 0, 0), font=font)
        combined = ImageOps.expand(combined, border=2, fill=(0, 0, 0))
        row_images.append(combined)
        row_width += combined_width
        max_height = max(max_height, combined_height)
        if len(row_images) == ROW_LIMIT or idx == len(rois) - 1:
            row_img = Image.new("RGB", (row_width, max_height), (255, 255, 255))
            x_offset = 0
            for img in row_images:
                row_img.paste(img, (x_offset, 0))
                x_offset += img.width
            combined_images.append(row_img)
            row_images = []
            row_width = 0
            max_height = 0
    total_width = max(img.width for img in combined_images)
    total_height = sum(img.height for img in combined_images)
    final_img = Image.new("RGB", (total_width, total_height), (255, 255, 255))
    y_offset = 0
    for img in combined_images:
        final_img.paste(img, (0, y_offset))
        y_offset += img.height
    final_img.save("all_matches.png")


if __name__ == "__main__":
    refs = load_references(REF_TABLE_JSON, ICON_DIR)
    rois = detect_and_crop(IMG_PATH)
    res, match_idx = match_ssim(rois, refs)
    print("Final recognition results:", res)
    display_matches(rois, match_idx, refs)
    print(datetime.now() - now)

View file

@@ -1,309 +0,0 @@
import json
import os
from datetime import datetime
import cv2
import numpy as np
from PIL import Image, ImageDraw, ImageFont, ImageOps
from skimage.metrics import structural_similarity as ssim

# Path configuration
REF_DIR = r"depot_test\output/test/origin"
ROI_DIR = r"depot_test\output/test/result"
IMG_PATH = r"depot_test\output\result_template.png"
REF_TABLE_JSON = "./ArknightsGameData/zh_CN/gamedata/excel/item_table.json"
ICON_DIR = "./ArknightsResource/items/"
# Parameters
HOUGH_PARAMS = dict(
    dp=5, minDist=230, param1=50, param2=30, minRadius=90, maxRadius=100
)
CROP_SIZE = 130
BORDER = 26
SECONDARY_SLICE = (slice(30, 140), slice(50, 160))
def load_references(table_json, icon_dir):
    data = json.load(open(table_json, encoding="utf-8"))
    refs = {}
    size = CROP_SIZE * 2 - BORDER * 2
    for item in data.get("items", {}).values():
        if item.get("classifyType") not in {"NORMAL", "CONSUME", "MATERIAL"}:
            continue
        path = os.path.join(icon_dir, f"{item['iconId']}.png")
        if not os.path.exists(path):
            continue
        im = Image.open(path).resize((size, size)).crop((50, 30, 160, 140)).convert("L")
        refs[item["name"]] = np.array(im)
    print(f"Loaded {len(refs)} reference images")
    return refs
def process_circle(idx, circle, img):
    x, y, r = circle
    crop = img[
        max(0, y - CROP_SIZE) : min(img.shape[0], y + CROP_SIZE),
        max(0, x - CROP_SIZE) : min(img.shape[1], x + CROP_SIZE),
    ]
    # Save the color ROI
    os.makedirs(ROI_DIR, exist_ok=True)
    color_roi_path = os.path.join(ROI_DIR, f"color_roi_{idx}.png")
    cv2.imwrite(color_roi_path, crop[BORDER:-BORDER, BORDER:-BORDER])
    c = crop[BORDER:-BORDER, BORDER:-BORDER]
    sec = c[SECONDARY_SLICE[0], SECONDARY_SLICE[1]]
    gray_sec = cv2.cvtColor(sec, cv2.COLOR_BGR2GRAY)
    roi_path = os.path.join(ROI_DIR, f"roi_{idx}.png")
    cv2.imwrite(roi_path, gray_sec)
    return idx, gray_sec, roi_path, color_roi_path


def detect_and_crop(image_path):
    img = cv2.imread(image_path)
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    circles = cv2.HoughCircles(gray, cv2.HOUGH_GRADIENT, **HOUGH_PARAMS)
    if circles is None:
        print("No circular regions detected")
        return []
    # Round the circle coordinates and sort by x, then y
    circles = np.round(circles[0]).astype(int)
    circles = sorted(circles, key=lambda c: (c[0], c[1]))
    results = []
    for idx, circle in enumerate(circles):
        result = process_circle(idx, circle, img)
        results.append(result)
    return results
def match_template(rois, refs, thresh):
    results = []
    for idx, roi, _, _ in rois:
        best, max_val = "Unknown", -1.0
        roi_f = roi.astype(np.float32) / 255.0
        for name, ref in refs.items():
            ref_f = ref.astype(np.float32) / 255.0
            if roi_f.shape != ref_f.shape:
                continue
            res = cv2.matchTemplate(roi_f, ref_f, cv2.TM_CCOEFF_NORMED)
            val = float(res.max())
            if val > max_val:
                best, max_val = name, val
        if max_val >= thresh:
            results.append((idx, best, max_val))
        else:
            results.append((idx, None, max_val))
    return results
def match_ssim(rois, refs, thresh):
    results = []
    for idx, roi, _, _ in rois:
        best_match = "Unknown"
        max_combined_score = -1.0
        best_ssim = 0
        best_hist = 0
        best_edge = 0
        # Preprocess the ROI
        roi_gray = roi if len(roi.shape) == 2 else cv2.cvtColor(roi, cv2.COLOR_BGR2GRAY)
        roi_edges = cv2.Canny(roi_gray, 50, 150)
        roi_hist = cv2.calcHist([roi_gray], [0], None, [256], [0, 256])
        cv2.normalize(roi_hist, roi_hist, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX)
        for name, ref in refs.items():
            if roi_gray.shape != ref.shape:
                continue
            # 1. SSIM similarity
            ssim_score, _ = ssim(roi_gray, ref, full=True)
            # 2. Histogram similarity
            ref_hist = cv2.calcHist([ref], [0], None, [256], [0, 256])
            cv2.normalize(
                ref_hist, ref_hist, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX
            )
            hist_score = cv2.compareHist(roi_hist, ref_hist, cv2.HISTCMP_CORREL)
            # 3. Edge similarity
            ref_edges = cv2.Canny(ref, 50, 150)
            edge_intersection = np.sum(roi_edges * ref_edges)
            edge_union = np.sum(roi_edges + ref_edges)
            edge_score = edge_intersection / edge_union if edge_union > 0 else 0
            # Weighted combined score (weights are tunable)
            combined_score = 0.6 * ssim_score + 0.2 * hist_score + 0.2 * edge_score
            if combined_score > max_combined_score:
                best_match = name
                max_combined_score = combined_score
                best_ssim = ssim_score
                best_hist = hist_score
                best_edge = edge_score
        # Dynamic threshold adjustment (based on image complexity)
        roi_complexity = np.std(roi_gray) / 255.0
        dynamic_thresh = thresh * (1 + 0.3 * roi_complexity)
        if max_combined_score >= dynamic_thresh:
            results.append(
                (idx, best_match, max_combined_score, best_ssim, best_hist, best_edge)
            )
        else:
            results.append(
                (idx, None, max_combined_score, best_ssim, best_hist, best_edge)
            )
    return results
def display_matches(rois, template_results, ssim_results, refs):
    ROW_LIMIT = 9
    blank_ref = next(iter(refs.values()))  # borrow one reference image's size
    blank_img = Image.new(
        "RGB", (blank_ref.shape[1], blank_ref.shape[0]), (200, 200, 200)
    )  # gray placeholder
    combined_images = []
    row_images = []
    row_width = 0
    max_height = 0
    for item in rois:
        idx, _, _, color_path = item
        # Load the color ROI image
        roi_img = Image.open(color_path).convert("RGB")
        # Template-matching result
        t_res = next((name for i, name, val in template_results if i == idx), None)
        t_val = next((val for i, name, val in template_results if i == idx), 0)
        if t_res is not None:
            t_ref_img = Image.fromarray(refs[t_res]).convert("RGB")
        else:
            t_ref_img = blank_img.copy()
        # SSIM result (the tuples carry six fields, so the extras are ignored here)
        s_res = next((name for i, name, val, *_ in ssim_results if i == idx), None)
        s_val = next((val for i, name, val, *_ in ssim_results if i == idx), 0)
        if s_res is not None:
            s_ref_img = Image.fromarray(refs[s_res]).convert("RGB")
        else:
            s_ref_img = blank_img.copy()
        # Combine the ROI (left), template-matching reference (middle), and SSIM reference (right)
        combined_width = roi_img.width + t_ref_img.width + s_ref_img.width
        combined_height = max(roi_img.height, t_ref_img.height, s_ref_img.height)
        combined = Image.new("RGB", (combined_width, combined_height), (255, 255, 255))
        combined.paste(roi_img, (0, 0))
        combined.paste(t_ref_img, (roi_img.width, 0))
        combined.paste(s_ref_img, (roi_img.width + t_ref_img.width, 0))
        draw = ImageDraw.Draw(combined)
        font = ImageFont.truetype("msyh.ttc", 20)
        # Pull the detailed scores out of the SSIM results
        ssim_details = next(
            (d for d in ssim_results if d[0] == idx),
            (idx, None, 0, 0, 0, 0),  # defaults if not found
        )
        best_ssim = ssim_details[3] if len(ssim_details) > 3 else 0
        best_hist = ssim_details[4] if len(ssim_details) > 4 else 0
        best_edge = ssim_details[5] if len(ssim_details) > 5 else 0
        label = (
            f"ROI {idx} {best_ssim:.3f} {best_hist:.3f} {best_edge:.3f}\n"
            f"T({t_res if t_res else 'None'}, {t_val:.3f})\n"
            f"S({s_res if s_res else 'None'}, {s_val:.3f})"
        )
        if t_res == s_res:
            draw.text(
                (roi_img.width, t_ref_img.height), label, fill=(255, 0, 0), font=font
            )
        else:
            draw.text(
                (roi_img.width, t_ref_img.height), label, fill=(255, 0, 255), font=font
            )
        combined = ImageOps.expand(combined, border=2, fill=(0, 0, 0))
        row_images.append(combined)
        row_width += combined_width
        max_height = max(max_height, combined_height)
        if len(row_images) == ROW_LIMIT or idx == len(rois) - 1:
            row_img = Image.new("RGB", (row_width, max_height), (255, 255, 255))
            x_offset = 0
            for img in row_images:
                row_img.paste(img, (x_offset, 0))
                x_offset += img.width
            combined_images.append(row_img)
            row_images = []
            row_width = 0
            max_height = 0
    total_width = max(img.width for img in combined_images)
    total_height = sum(img.height for img in combined_images)
    final_img = Image.new("RGB", (total_width, total_height), (255, 255, 255))
    y_offset = 0
    for img in combined_images:
        final_img.paste(img, (0, y_offset))
        y_offset += img.height
    final_img.save("depot_test/output/matches_all.png")
if __name__ == "__main__":
    refs = load_references(REF_TABLE_JSON, ICON_DIR)
    rois = detect_and_crop(IMG_PATH)
    now = datetime.now()
    Template_MATCH_THRESHOLD = 0.2
    print("\n=== Template matching ===")
    results_template = match_template(rois, refs, Template_MATCH_THRESHOLD)
    print("\nTemplate matching took:", datetime.now() - now)
    SSIM_MATCH_THRESHOLD = 0.05
    now = datetime.now()
    print("\n=== SSIM ===")
    results_ssim = match_ssim(rois, refs, SSIM_MATCH_THRESHOLD)
    print("\nSSIM matching took:", datetime.now() - now)
    print("\n=== Comparison of results ===")
    for idx, _, _, _ in rois:
        # Template-matching results
        t_res = next((name for i, name, val in results_template if i == idx), None)
        t_val = next((val for i, name, val in results_template if i == idx), 0)
        # SSIM results, including the detailed metrics
        s_res = next(
            (name for i, name, val, ssim, hist, edge in results_ssim if i == idx), None
        )
        s_val = next(
            (val for i, name, val, ssim, hist, edge in results_ssim if i == idx), 0
        )
        s_ssim = next(
            (ssim for i, name, val, ssim, hist, edge in results_ssim if i == idx), 0
        )
        s_hist = next(
            (hist for i, name, val, ssim, hist, edge in results_ssim if i == idx), 0
        )
        s_edge = next(
            (edge for i, name, val, ssim, hist, edge in results_ssim if i == idx), 0
        )
        print(
            f"ROI {idx}:\n"
            f"  Template=({t_res if t_res else 'None'}, {t_val:.3f})\n"
            f"  SSIM=({s_res if s_res else 'None'}, {s_val:.3f})\n"
            f"  Details: SSIM={s_ssim:.3f}, Hist={s_hist:.3f}, Edge={s_edge:.3f}"
        )
    # Display the matches side by side
    display_matches(rois, results_template, results_ssim, refs)

View file

@@ -1,192 +0,0 @@
import json
import os
from datetime import datetime
from multiprocessing import Pool
import cv2
import numpy as np
from PIL import Image, ImageDraw, ImageFont, ImageOps

now = datetime.now()
# Path configuration
REF_DIR = r"depot_test\output/test/origin"
ROI_DIR = r"depot_test\output/test/result"
IMG_PATH = r"depot_test\stitched_image_multi.png"
REF_TABLE_JSON = "./ArknightsGameData/zh_CN/gamedata/excel/item_table.json"
ICON_DIR = "./ArknightsResource/items/"
# Match threshold
MATCH_THRESHOLD = 0.01
# Circle-detection parameters
HOUGH_PARAMS = dict(
    dp=5, minDist=230, param1=50, param2=30, minRadius=90, maxRadius=100
)
CROP_SIZE = 130
BORDER = 26
SECONDARY_SLICE = (slice(30, 140), slice(50, 160))
def load_references(table_json, icon_dir):
    data = json.load(open(table_json, encoding="utf-8"))
    refs = {}
    size = CROP_SIZE * 2 - BORDER * 2
    for item in data.get("items", {}).values():
        t = item.get("classifyType")
        if t not in {"NORMAL", "CONSUME", "MATERIAL"}:
            continue
        path = os.path.join(icon_dir, f"{item['iconId']}.png")
        if not os.path.exists(path):
            continue
        im = Image.open(path).resize((size, size)).crop((50, 30, 160, 140)).convert("L")
        refs[item["name"]] = np.array(im)
    print(f"Loaded {len(refs)} reference images, stored under {REF_DIR}")
    return refs
def process_circle(idx, circle, img, rois, size, dr):
    x, y, r = circle
    crop = img[
        max(0, y - CROP_SIZE) : min(img.shape[0], y + CROP_SIZE),
        max(0, x - CROP_SIZE) : min(img.shape[1], x + CROP_SIZE),
    ]
    c = crop[BORDER:-BORDER, BORDER:-BORDER]
    sec = c[SECONDARY_SLICE[0], SECONDARY_SLICE[1]]
    gray_sec = cv2.cvtColor(sec, cv2.COLOR_BGR2GRAY)
    # Save the ROI image
    os.makedirs(ROI_DIR, exist_ok=True)
    roi_path = os.path.join(ROI_DIR, f"roi_{idx}.png")
    cv2.imwrite(roi_path, gray_sec)
    rois.append(gray_sec)  # note: has no effect in the parent process when run via Pool
    return idx, gray_sec, roi_path


def detect_and_crop(image_path):
    img = cv2.imread(image_path)
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    circles = cv2.HoughCircles(gray, cv2.HOUGH_GRADIENT, **HOUGH_PARAMS)
    if circles is None:
        print("No circular regions detected")
        return []
    circles = np.round(circles[0]).astype(int)
    rois = []
    size = CROP_SIZE * 2 - BORDER * 2
    dr = img.max() - img.min()
    # Use multiprocessing to speed things up
    with Pool() as pool:
        results = pool.starmap(
            process_circle,
            [(idx, circle, img, rois, size, dr) for idx, circle in enumerate(circles)],
        )
    return [roi for idx, roi, path in results]
def process_match(idx, roi, refs, thresh):
    best, max_val = "Unknown", -1.0
    # Template matching needs float images
    roi_f = roi.astype(np.float32) / 255.0
    for name, ref in refs.items():
        # ref is already a grayscale NumPy array
        ref_f = ref.astype(np.float32) / 255.0
        if roi_f.shape != ref_f.shape:
            continue
        # TM_CCOEFF_NORMED: the closer the result is to 1, the better the match
        res = cv2.matchTemplate(roi_f, ref_f, cv2.TM_CCOEFF_NORMED)
        val = float(res.max())
        if val > max_val:
            best, max_val = name, val
    if max_val >= thresh:
        return idx, best, max_val
    return idx, None, max_val


def match_ssim(rois, refs, thresh=MATCH_THRESHOLD):
    from multiprocessing import Pool

    args = [(idx, roi, refs, thresh) for idx, roi in enumerate(rois)]
    with Pool(processes=5) as pool:
        results = pool.starmap(process_match, args)
    stats = {}
    match_idx = {}
    for idx, name, score in results:
        if name:
            stats[name] = stats.get(name, 0) + 1
            match_idx[idx] = name
        print(f"ROI {idx} match: {name if name else 'Unknown'} (score={score:.3f})")
    return stats, match_idx
def display_matches(rois, match_idx, refs):
    ROW_LIMIT = 10  # at most 10 per row
    blank_ref = next(iter(refs.values()))  # borrow one reference image's size
    blank_img = Image.new(
        "RGB", (blank_ref.shape[1], blank_ref.shape[0]), (200, 200, 200)
    )  # gray placeholder
    combined_images = []
    row_images = []
    row_width = 0
    max_height = 0
    for idx in range(len(rois)):
        roi_img = Image.fromarray(rois[idx]).convert("RGB")
        ref_name = match_idx.get(idx)
        if ref_name:
            ref_img = Image.fromarray(refs[ref_name]).convert("RGB")
        else:
            ref_img = blank_img.copy()
        combined_width = roi_img.width + ref_img.width
        combined_height = max(roi_img.height, ref_img.height)
        combined = Image.new("RGB", (combined_width, combined_height), (255, 255, 255))
        combined.paste(roi_img, (0, 0))
        combined.paste(ref_img, (roi_img.width, 0))
        draw = ImageDraw.Draw(combined)
        font = ImageFont.truetype("msyh.ttc", 20)
        label = f"ROI {idx}: {ref_name if ref_name else 'Unknown'}"
        draw.text((5, 5), label, fill=(255, 0, 0), font=font)
        combined = ImageOps.expand(combined, border=2, fill=(0, 0, 0))
        row_images.append(combined)
        row_width += combined_width
        max_height = max(max_height, combined_height)
        if len(row_images) == ROW_LIMIT or idx == len(rois) - 1:
            row_img = Image.new("RGB", (row_width, max_height), (255, 255, 255))
            x_offset = 0
            for img in row_images:
                row_img.paste(img, (x_offset, 0))
                x_offset += img.width
            combined_images.append(row_img)
            row_images = []
            row_width = 0
            max_height = 0
    total_width = max(img.width for img in combined_images)
    total_height = sum(img.height for img in combined_images)
    final_img = Image.new("RGB", (total_width, total_height), (255, 255, 255))
    y_offset = 0
    for img in combined_images:
        final_img.paste(img, (0, y_offset))
        y_offset += img.height
    final_img.save("all_matches.png")


if __name__ == "__main__":
    refs = load_references(REF_TABLE_JSON, ICON_DIR)
    rois = detect_and_crop(IMG_PATH)
    res, match_idx = match_ssim(rois, refs)
    print("Final recognition results:", res)
    display_matches(rois, match_idx, refs)
    print(datetime.now() - now)

depot_test/训练.py Normal file
View file

@@ -0,0 +1,130 @@
import os
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm


# 1. Network definition
class SiameseNetwork(nn.Module):
    def __init__(self):
        super(SiameseNetwork, self).__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=5),  # (3,110,110) -> (32,106,106)
            nn.ReLU(),
            nn.MaxPool2d(2),  # -> (32,53,53)
            nn.Conv2d(32, 64, kernel_size=5),  # -> (64,49,49)
            nn.ReLU(),
            nn.MaxPool2d(2),  # -> (64,24,24)
        )
        self.fc = nn.Sequential(
            nn.Linear(64 * 24 * 24, 512), nn.ReLU(), nn.Linear(512, 128)
        )

    def forward_once(self, x):
        x = self.cnn(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)

    def forward(self, input1, input2):
        output1 = self.forward_once(input1)
        output2 = self.forward_once(input2)
        return output1, output2
# 2. Contrastive Loss
class ContrastiveLoss(nn.Module):
    def __init__(self, margin=2.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, out1, out2, label):
        dist = F.pairwise_distance(out1, out2)
        loss = label * torch.pow(dist, 2) + (1 - label) * torch.pow(
            torch.clamp(self.margin - dist, min=0.0), 2
        )
        return loss.mean()
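# The forward pass above is the classic contrastive objective: with y = 1 for
# same-class pairs and d the pairwise Euclidean distance,
#     L(y, d) = y * d^2 + (1 - y) * max(margin - d, 0)^2
# so similar pairs are pulled together and dissimilar pairs are pushed apart
# until they sit at least `margin` away from each other.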
# 3. Dataset builder
class SiameseDataset(Dataset):
    def __init__(self, folder_path, transform=None):
        self.folder_path = folder_path
        self.classes = os.listdir(folder_path)
        self.transform = transform or transforms.ToTensor()

    def __getitem__(self, index):
        # Pick a positive (same-class) pair half the time, otherwise a negative pair
        class1 = random.choice(self.classes)
        class2 = (
            class1
            if random.random() < 0.5
            else random.choice([c for c in self.classes if c != class1])
        )
        img1_path = os.path.join(
            self.folder_path,
            class1,
            random.choice(os.listdir(os.path.join(self.folder_path, class1))),
        )
        img2_path = os.path.join(
            self.folder_path,
            class2,
            random.choice(os.listdir(os.path.join(self.folder_path, class2))),
        )
        img1 = Image.open(img1_path).convert("RGB").resize((110, 110))
        img2 = Image.open(img2_path).convert("RGB").resize((110, 110))
        if self.transform:
            img1 = self.transform(img1)
            img2 = self.transform(img2)
        label = 1.0 if class1 == class2 else 0.0
        return img1, img2, torch.tensor([label], dtype=torch.float32)

    def __len__(self):
        return 5000  # number of random pairs sampled per epoch
# 4. Training entry point
def train():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
        ]
    )
    dataset = SiameseDataset("dataset/train", transform=transform)
    dataloader = DataLoader(dataset, batch_size=128, shuffle=True)
    model = SiameseNetwork().to(device)
    criterion = ContrastiveLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    for epoch in range(100):
        total_loss = 0
        for img1, img2, label in tqdm(dataloader, desc=f"Epoch {epoch + 1}"):
            img1, img2, label = img1.to(device), img2.to(device), label.to(device)
            out1, out2 = model(img1, img2)
            loss = criterion(out1, out2, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch + 1}, Loss: {total_loss / len(dataloader):.4f}")
    torch.save(model.state_dict(), "siamese_model.pth")
    print("Model saved as siamese_model.pth")


if __name__ == "__main__":
    train()

depot_test/预测.py Normal file
View file

@@ -0,0 +1,74 @@
import os
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
from 训练 import SiameseNetwork  # must live in the same directory as the training script


# Load the model
def load_model(model_path, device):
    model = SiameseNetwork().to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    return model
# Load the images of each training-set class (used as the support set)
def load_train_class_images(train_dir, transform, device):
    class_images = {}
    for class_name in os.listdir(train_dir):
        class_path = os.path.join(train_dir, class_name)
        if not os.path.isdir(class_path):
            continue
        images = []
        for img_name in os.listdir(class_path):
            img_path = os.path.join(class_path, img_name)
            image = Image.open(img_path).convert("RGB").resize((110, 110))
            image = transform(image).unsqueeze(0).to(device)  # shape: (1,3,110,110)
            images.append(image)
        if images:
            class_images[class_name] = images
    return class_images
# Inference helper: returns the most similar class name
def predict(model, test_img, class_images):
    min_dist = float("inf")
    predicted_class = None
    with torch.no_grad():
        for class_name, ref_images in class_images.items():
            for ref_img in ref_images:
                out1, out2 = model(test_img, ref_img)
                dist = F.pairwise_distance(out1, out2)
                if dist.item() < min_dist:
                    min_dist = dist.item()
                    predicted_class = class_name
    return predicted_class, min_dist
# Main inference flow
def infer():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = load_model("siamese_model.pth", device)
    transform = transforms.Compose([
        transforms.ToTensor(),
    ])
    train_dir = "dataset/train"
    test_dir = "dataset/test"
    class_images = load_train_class_images(train_dir, transform, device)
    print("Starting testing...")
    for class_name in os.listdir(test_dir):
        class_path = os.path.join(test_dir, class_name)
        for img_name in os.listdir(class_path):
            img_path = os.path.join(class_path, img_name)
            img = Image.open(img_path).convert("RGB").resize((110, 110))
            img_tensor = transform(img).unsqueeze(0).to(device)
            predicted_class, dist = predict(model, img_tensor, class_images)
            print(
                f"Test Image: {img_name} | True: {class_name} | "
                f"Predicted: {predicted_class} | Distance: {dist:.4f}"
            )


if __name__ == "__main__":
    infer()
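
The same helpers also support one-off checks on a single file; a minimal sketch (the image path is a hypothetical example):

# Sketch: classify one image with the helpers above.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = load_model("siamese_model.pth", device)
transform = transforms.Compose([transforms.ToTensor()])
class_images = load_train_class_images("dataset/train", transform, device)
img = Image.open("dataset/test/some_class/0.png").convert("RGB").resize((110, 110))  # hypothetical path
label, dist = predict(model, transform(img).unsqueeze(0).to(device), class_images)
print(label, dist)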