上传一些测试文件
This commit is contained in:
parent
cee44523c6
commit
44782a3117
10 changed files with 912 additions and 689 deletions
BIN
depot_test/output/matches_knn_only.png
(Stored with Git LFS)
Normal file
BIN
depot_test/output/matches_knn_only.png
(Stored with Git LFS)
Normal file
Binary file not shown.
BIN
depot_test/stitched_image_multi.png
(Stored with Git LFS)
Normal file
BIN
depot_test/stitched_image_multi.png
(Stored with Git LFS)
Normal file
Binary file not shown.
324
depot_test/仓库识别MobileNetV3 KNN copy.py
Normal file
324
depot_test/仓库识别MobileNetV3 KNN copy.py
Normal file
|
@ -0,0 +1,324 @@
|
||||||
|
import glob
|
||||||
|
import os
|
||||||
|
import re # 用于解析文件名中的标签
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.optim as optim
|
||||||
|
from PIL import Image
|
||||||
|
from torch.utils.data import DataLoader, Dataset
|
||||||
|
from torchvision import models, transforms
|
||||||
|
from torchvision.datasets import ImageFolder
|
||||||
|
from tqdm import tqdm # 用于显示进度条
|
||||||
|
|
||||||
|
# --- 配置参数 ---
|
||||||
|
TRAIN_DIR = "训练集"
|
||||||
|
VAL_TEST_DIR = "测试集" # 根据你的描述,验证集和测试集是同一个目录,文件格式相同
|
||||||
|
IMAGE_SIZE = 224 # MobileNetV3 的标准输入大小
|
||||||
|
BATCH_SIZE = 32
|
||||||
|
NUM_EPOCHS = 20 # 可以根据需要调整
|
||||||
|
LEARNING_RATE = 0.001 # 初始学习率
|
||||||
|
SAVE_MODEL_PATH = "mobilenetv3_small_finetuned.pth"
|
||||||
|
|
||||||
|
# 自动检测设备
|
||||||
|
device = torch.device("cuda")
|
||||||
|
print(f"使用设备: {device}")
|
||||||
|
|
||||||
|
# --- 数据预处理和增强 ---
|
||||||
|
# ImageNet 标准均值和标准差
|
||||||
|
mean = [0.485, 0.456, 0.406]
|
||||||
|
std = [0.229, 0.224, 0.225]
|
||||||
|
|
||||||
|
# 训练集数据增强和预处理
|
||||||
|
train_transforms = transforms.Compose(
|
||||||
|
[
|
||||||
|
transforms.RandomResizedCrop(IMAGE_SIZE), # 随机裁剪并缩放
|
||||||
|
transforms.RandomHorizontalFlip(), # 随机水平翻转
|
||||||
|
transforms.ToTensor(), # 转换为 Tensor
|
||||||
|
transforms.Normalize(mean, std), # 标准化
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# 验证/测试集数据预处理 (不需要数据增强,只需要中心裁剪和标准化)
|
||||||
|
val_test_transforms = transforms.Compose(
|
||||||
|
[
|
||||||
|
transforms.Resize(int(IMAGE_SIZE * 256 / 224)), # 缩放到较大尺寸
|
||||||
|
transforms.CenterCrop(IMAGE_SIZE), # 中心裁剪到目标尺寸
|
||||||
|
transforms.ToTensor(), # 转换为 Tensor
|
||||||
|
transforms.Normalize(mean, std), # 标准化
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# --- 自定义测试集 Dataset ---
|
||||||
|
# 需要一个自定义 Dataset 来处理 "测试集/{idx}_{label}.png" 这种文件命名格式
|
||||||
|
class CustomValTestDataset(Dataset):
|
||||||
|
def __init__(self, root_dir, class_to_idx, transform=None):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
root_dir (string): 数据集根目录 (e.g., "测试集").
|
||||||
|
class_to_idx (dict): 从类别名称到索引的映射,与训练集一致。
|
||||||
|
transform (callable, optional): 应用于图像的转换.
|
||||||
|
"""
|
||||||
|
self.root_dir = root_dir
|
||||||
|
self.transform = transform
|
||||||
|
self.class_to_idx = class_to_idx
|
||||||
|
self.idx_to_class = {v: k for k, v in class_to_idx.items()}
|
||||||
|
self.image_files = []
|
||||||
|
self.labels = []
|
||||||
|
|
||||||
|
# 遍历目录下的所有png文件
|
||||||
|
filepaths = glob.glob(os.path.join(root_dir, "*.png"))
|
||||||
|
|
||||||
|
# 解析文件名,提取标签
|
||||||
|
pattern = re.compile(r"^\d+_([^_]+)\.png$") # 匹配 数字_标签.png
|
||||||
|
|
||||||
|
for filepath in filepaths:
|
||||||
|
filename = os.path.basename(filepath)
|
||||||
|
match = pattern.match(filename)
|
||||||
|
if match:
|
||||||
|
label_name = match.group(1)
|
||||||
|
if label_name in self.class_to_idx:
|
||||||
|
self.image_files.append(filepath)
|
||||||
|
self.labels.append(self.class_to_idx[label_name])
|
||||||
|
else:
|
||||||
|
print(
|
||||||
|
f"警告: 文件 '{filename}' 中的标签 '{label_name}' 不在训练集的类别中,将跳过。"
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"加载了 {len(self.image_files)} 张验证/测试图片。")
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.image_files)
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
if torch.is_tensor(idx):
|
||||||
|
idx = idx.tolist()
|
||||||
|
|
||||||
|
img_path = self.image_files[idx]
|
||||||
|
label = self.labels[idx]
|
||||||
|
|
||||||
|
# 打开图像,确保是 RGB (处理可能的灰度图)
|
||||||
|
img = Image.open(img_path).convert("RGB")
|
||||||
|
|
||||||
|
if self.transform:
|
||||||
|
img = self.transform(img)
|
||||||
|
|
||||||
|
return img, label
|
||||||
|
|
||||||
|
|
||||||
|
# --- 加载数据 ---
|
||||||
|
# 使用 ImageFolder 加载训练集,它会自动从目录名解析类别
|
||||||
|
if not os.path.exists(TRAIN_DIR):
|
||||||
|
print(
|
||||||
|
f"错误: 训练集目录 '{TRAIN_DIR}' 不存在。请创建该目录并放入分类好的图片子目录。"
|
||||||
|
)
|
||||||
|
exit()
|
||||||
|
|
||||||
|
if not os.path.exists(VAL_TEST_DIR):
|
||||||
|
print(f"错误: 验证/测试集目录 '{VAL_TEST_DIR}' 不存在。请创建该目录并放入图片。")
|
||||||
|
exit()
|
||||||
|
|
||||||
|
|
||||||
|
train_dataset = ImageFolder(TRAIN_DIR, transform=train_transforms)
|
||||||
|
num_classes = len(train_dataset.classes)
|
||||||
|
class_to_idx = train_dataset.class_to_idx # 获取类别到索引的映射
|
||||||
|
|
||||||
|
print(f"从训练集检测到 {num_classes} 个类别: {train_dataset.classes}")
|
||||||
|
|
||||||
|
# 使用自定义 Dataset 加载验证/测试集
|
||||||
|
val_test_dataset = CustomValTestDataset(
|
||||||
|
VAL_TEST_DIR, class_to_idx, transform=val_test_transforms
|
||||||
|
)
|
||||||
|
|
||||||
|
# 创建 DataLoader
|
||||||
|
train_loader = DataLoader(
|
||||||
|
train_dataset, batch_size=BATCH_SIZE, shuffle=True
|
||||||
|
) # num_workers 根据你的机器性能调整
|
||||||
|
val_test_loader = DataLoader(val_test_dataset, batch_size=BATCH_SIZE, shuffle=False)
|
||||||
|
|
||||||
|
# --- 加载预训练的 MobileNetV3-Small 模型 ---
|
||||||
|
# 使用 weights 参数来指定预训练权重
|
||||||
|
# MobileNetV3_Small_Weights.IMAGENET1K_V1 是在 ImageNet 上预训练的权重
|
||||||
|
try:
|
||||||
|
model = models.mobilenet_v3_small(
|
||||||
|
weights=models.MobileNet_V3_Small_Weights.IMAGENET1K_V1
|
||||||
|
)
|
||||||
|
print("成功加载预训练的 MobileNetV3-Small 模型 (ImageNet weights)。")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"加载预训练模型失败: {e}")
|
||||||
|
print("尝试加载不带权重的模型...")
|
||||||
|
model = models.mobilenet_v3_small(weights=None)
|
||||||
|
|
||||||
|
|
||||||
|
# --- 修改全连接层以匹配新的类别数量 ---
|
||||||
|
# MobileNetV3 的分类器是 model.classifier
|
||||||
|
# 最后一个线性层是 classifier[-1]
|
||||||
|
num_ftrs = model.classifier[-1].in_features
|
||||||
|
# 替换掉原来的全连接层
|
||||||
|
model.classifier[-1] = nn.Linear(num_ftrs, num_classes)
|
||||||
|
|
||||||
|
model = model.to(device)
|
||||||
|
|
||||||
|
# --- 定义损失函数和优化器 ---
|
||||||
|
criterion = nn.CrossEntropyLoss() # 交叉熵损失适用于分类问题
|
||||||
|
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE) # Adam 优化器
|
||||||
|
|
||||||
|
# 可选:学习率调度器,帮助调整学习率
|
||||||
|
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1) # 每7个epoch降低学习率
|
||||||
|
|
||||||
|
# --- 训练和评估函数 ---
|
||||||
|
|
||||||
|
|
||||||
|
def train_epoch(model, train_loader, criterion, optimizer, device):
|
||||||
|
model.train() # 设置模型为训练模式
|
||||||
|
running_loss = 0.0
|
||||||
|
correct_predictions = 0
|
||||||
|
total_samples = 0
|
||||||
|
|
||||||
|
# 使用 tqdm 显示进度条
|
||||||
|
for inputs, labels in tqdm(train_loader, desc="训练中"):
|
||||||
|
inputs, labels = inputs.to(device), labels.to(device)
|
||||||
|
|
||||||
|
# 梯度清零
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
# 前向传播
|
||||||
|
outputs = model(inputs)
|
||||||
|
loss = criterion(outputs, labels)
|
||||||
|
|
||||||
|
# 反向传播和优化
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
# 统计
|
||||||
|
running_loss += loss.item() * inputs.size(0) # 累加 batch loss * batch size
|
||||||
|
_, predicted = torch.max(outputs, 1) # 获取预测结果
|
||||||
|
correct_predictions += (predicted == labels).sum().item()
|
||||||
|
total_samples += labels.size(0)
|
||||||
|
|
||||||
|
epoch_loss = running_loss / total_samples
|
||||||
|
epoch_accuracy = correct_predictions / total_samples
|
||||||
|
return epoch_loss, epoch_accuracy
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate(model, data_loader, criterion, device, desc="评估中"):
|
||||||
|
model.eval() # 设置模型为评估模式
|
||||||
|
running_loss = 0.0
|
||||||
|
correct_predictions = 0
|
||||||
|
total_samples = 0
|
||||||
|
|
||||||
|
# 在评估阶段不计算梯度
|
||||||
|
with torch.no_grad():
|
||||||
|
# 使用 tqdm 显示进度条
|
||||||
|
for inputs, labels in tqdm(data_loader, desc=desc):
|
||||||
|
inputs, labels = inputs.to(device), labels.to(device)
|
||||||
|
|
||||||
|
# 前向传播
|
||||||
|
outputs = model(inputs)
|
||||||
|
loss = criterion(outputs, labels)
|
||||||
|
|
||||||
|
# 统计
|
||||||
|
running_loss += loss.item() * inputs.size(0)
|
||||||
|
_, predicted = torch.max(outputs, 1)
|
||||||
|
correct_predictions += (predicted == labels).sum().item()
|
||||||
|
total_samples += labels.size(0)
|
||||||
|
|
||||||
|
epoch_loss = running_loss / total_samples
|
||||||
|
epoch_accuracy = correct_predictions / total_samples
|
||||||
|
return epoch_loss, epoch_accuracy
|
||||||
|
|
||||||
|
|
||||||
|
# --- 训练循环 ---
|
||||||
|
best_val_accuracy = 0.0
|
||||||
|
|
||||||
|
print("\n开始训练...")
|
||||||
|
for epoch in range(NUM_EPOCHS):
|
||||||
|
print(f"\n--- Epoch {epoch + 1}/{NUM_EPOCHS} ---")
|
||||||
|
|
||||||
|
# 训练阶段
|
||||||
|
train_loss, train_accuracy = train_epoch(
|
||||||
|
model, train_loader, criterion, optimizer, device
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"Epoch {epoch + 1} 训练 Loss: {train_loss:.4f}, 准确率: {train_accuracy:.4f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 可选:学习率调度
|
||||||
|
# if scheduler is not None:
|
||||||
|
# scheduler.step()
|
||||||
|
|
||||||
|
# 验证/测试阶段
|
||||||
|
val_loss, val_accuracy = evaluate(
|
||||||
|
model, val_test_loader, criterion, device, desc="验证/测试中"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"Epoch {epoch + 1} 验证/测试 Loss: {val_loss:.4f}, 准确率: {val_accuracy:.4f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 保存最优模型
|
||||||
|
if val_accuracy > best_val_accuracy:
|
||||||
|
best_val_accuracy = val_accuracy
|
||||||
|
torch.save(model.state_dict(), SAVE_MODEL_PATH)
|
||||||
|
print(
|
||||||
|
f"保存了验证/测试集上最优的模型 (准确率: {best_val_accuracy:.4f}) 到 {SAVE_MODEL_PATH}"
|
||||||
|
)
|
||||||
|
|
||||||
|
print("\n训练完成!")
|
||||||
|
print(f"在验证/测试集上的最高准确率: {best_val_accuracy:.4f}")
|
||||||
|
|
||||||
|
# --- 可选: 加载并测试最优模型 ---
|
||||||
|
# 加载保存的最优模型进行最终测试 (如果验证/测试集是同一个,这就是最终结果)
|
||||||
|
print(f"\n加载最优模型 '{SAVE_MODEL_PATH}' 进行最终评估...")
|
||||||
|
loaded_model = models.mobilenet_v3_small(weights=None) # 先加载一个空的模型结构
|
||||||
|
loaded_model.classifier[-1] = nn.Linear(
|
||||||
|
loaded_model.classifier[-1].in_features, num_classes
|
||||||
|
) # 修改分类器
|
||||||
|
loaded_model.load_state_dict(
|
||||||
|
torch.load(SAVE_MODEL_PATH, map_location=device)
|
||||||
|
) # 加载权重
|
||||||
|
loaded_model = loaded_model.to(device)
|
||||||
|
|
||||||
|
final_test_loss, final_test_accuracy = evaluate(
|
||||||
|
loaded_model, val_test_loader, criterion, device, desc="最终测试中"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"\n最终测试 Loss: {final_test_loss:.4f}, 最终测试准确率: {final_test_accuracy:.4f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# --- 可选: 预测单张图片 ---
|
||||||
|
|
||||||
|
# 假设你想预测一张名为 '测试集/some_image_X_label.png' 的图片
|
||||||
|
def predict_single_image(image_path, model, class_to_idx, device, transform):
|
||||||
|
model.eval()
|
||||||
|
idx_to_class = {v: k for k, v in class_to_idx.items()}
|
||||||
|
|
||||||
|
try:
|
||||||
|
img = Image.open(image_path).convert("RGB")
|
||||||
|
img = transform(img).unsqueeze(0).to(device) # 添加 batch 维度并移动到设备
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = model(img)
|
||||||
|
probabilities = torch.softmax(outputs, dim=1)[0] # 获取概率分布
|
||||||
|
_, predicted_idx = torch.max(probabilities, 0) # 获取最高概率的索引
|
||||||
|
predicted_label = idx_to_class[predicted_idx.item()]
|
||||||
|
confidence = probabilities[predicted_idx].item()
|
||||||
|
|
||||||
|
print(f"\n预测图片: {image_path}")
|
||||||
|
print(f"预测类别: {predicted_label}, 置信度: {confidence:.4f}")
|
||||||
|
return predicted_label, confidence
|
||||||
|
|
||||||
|
except FileNotFoundError:
|
||||||
|
print(f"错误: 图片文件 '{image_path}' 未找到。")
|
||||||
|
return None, None
|
||||||
|
except Exception as e:
|
||||||
|
print(f"预测图片时发生错误: {e}")
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
|
||||||
|
# # 示例预测 (取消注释以使用)
|
||||||
|
example_image_path = r"测试集\27_基础作战记录.png" # 替换为你测试集中的实际文件路径
|
||||||
|
predict_single_image(
|
||||||
|
example_image_path, loaded_model, class_to_idx, device, val_test_transforms
|
||||||
|
)
|
378
depot_test/仓库识别MobileNetV3 KNN.py
Normal file
378
depot_test/仓库识别MobileNetV3 KNN.py
Normal file
|
@ -0,0 +1,378 @@
|
||||||
|
import json
|
||||||
|
import lzma
|
||||||
|
import os
|
||||||
|
import pickle
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torchvision.transforms as transforms
|
||||||
|
from PIL import Image, ImageDraw, ImageFont, ImageOps
|
||||||
|
from sklearn.neighbors import KNeighborsClassifier
|
||||||
|
from torchvision import models
|
||||||
|
|
||||||
|
CROP_SIZE = 130
|
||||||
|
BORDER = 26
|
||||||
|
size = CROP_SIZE * 2 - BORDER * 2
|
||||||
|
|
||||||
|
|
||||||
|
# 定义特征提取器
|
||||||
|
model = models.mobilenet_v3_small(weights="DEFAULT")
|
||||||
|
|
||||||
|
features_part = model.features
|
||||||
|
avgpool = torch.nn.AdaptiveAvgPool2d(1)
|
||||||
|
classifier_part_excluding_last = torch.nn.Sequential(
|
||||||
|
*list(model.classifier.children())[:-1]
|
||||||
|
)
|
||||||
|
|
||||||
|
feature_extractor = torch.nn.Sequential(
|
||||||
|
features_part,
|
||||||
|
avgpool,
|
||||||
|
torch.nn.Flatten(start_dim=1),
|
||||||
|
classifier_part_excluding_last,
|
||||||
|
)
|
||||||
|
feature_extractor.eval() # 切换到评估模式
|
||||||
|
|
||||||
|
|
||||||
|
def 提取特征点(模板):
|
||||||
|
"""使用MobileNetV3提取特征 (PyTorch版)"""
|
||||||
|
# 将输入图像从BGR转换为RGB
|
||||||
|
img_rgb = cv2.cvtColor(模板, cv2.COLOR_BGR2RGB)
|
||||||
|
|
||||||
|
# 定义图像预处理流程
|
||||||
|
preprocess = transforms.Compose(
|
||||||
|
[
|
||||||
|
transforms.ToPILImage(), # 转换为PIL图像
|
||||||
|
transforms.Resize(250), # 调整大小为224x224
|
||||||
|
transforms.ToTensor(), # 转换为Tensor
|
||||||
|
transforms.Normalize( # 归一化
|
||||||
|
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# 预处理图像
|
||||||
|
img_tensor = preprocess(img_rgb)
|
||||||
|
img_tensor = img_tensor.unsqueeze(0) # 增加batch维度
|
||||||
|
|
||||||
|
# 提取特征
|
||||||
|
with torch.no_grad():
|
||||||
|
features = feature_extractor(img_tensor)
|
||||||
|
|
||||||
|
# 将特征展平为一维
|
||||||
|
features = features.flatten().numpy()
|
||||||
|
|
||||||
|
return features
|
||||||
|
|
||||||
|
|
||||||
|
class DepotMatcher:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
ref_table_json="./ArknightsGameData/zh_CN/gamedata/excel/item_table.json",
|
||||||
|
icon_dir="./ArknightsResource/items/",
|
||||||
|
ref_dir="depot_test/output/test/origin",
|
||||||
|
roi_dir="depot_test/output/test/result",
|
||||||
|
img_path=r"depot_test\stitched_image_multi.png",
|
||||||
|
):
|
||||||
|
# 初始化路径配置
|
||||||
|
self.REF_DIR = ref_dir
|
||||||
|
self.ROI_DIR = roi_dir
|
||||||
|
self.IMG_PATH = img_path
|
||||||
|
self.REF_TABLE_JSON = ref_table_json
|
||||||
|
self.ICON_DIR = icon_dir
|
||||||
|
|
||||||
|
# 初始化算法参数
|
||||||
|
self.HOUGH_PARAMS = dict(
|
||||||
|
dp=5, minDist=230, param1=50, param2=30, minRadius=90, maxRadius=100
|
||||||
|
)
|
||||||
|
self.CROP_SIZE = 130
|
||||||
|
self.BORDER = 26
|
||||||
|
|
||||||
|
# 运行时数据存储
|
||||||
|
self.refs = None
|
||||||
|
self.rois = []
|
||||||
|
self.knn_results = []
|
||||||
|
self.knn_model = None
|
||||||
|
|
||||||
|
def load_references(self):
|
||||||
|
"""加载物品图标参考图(保留彩色)"""
|
||||||
|
data = json.load(open(self.REF_TABLE_JSON, encoding="utf-8"))
|
||||||
|
self.refs = {}
|
||||||
|
size = self.CROP_SIZE * 2 - self.BORDER * 2
|
||||||
|
|
||||||
|
# 首先收集所有带有sortId的物品
|
||||||
|
items_with_sort = []
|
||||||
|
for item in data.get("items", {}).values():
|
||||||
|
if item.get("classifyType") not in {"NORMAL", "CONSUME", "MATERIAL"}:
|
||||||
|
continue
|
||||||
|
|
||||||
|
path = os.path.join(self.ICON_DIR, f"{item['iconId']}.png")
|
||||||
|
if not os.path.exists(path):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 保留彩色图像
|
||||||
|
im = Image.open(path).resize((size, size))
|
||||||
|
items_with_sort.append(
|
||||||
|
{
|
||||||
|
"name": item["name"],
|
||||||
|
"array": np.array(im),
|
||||||
|
"sortId": item.get("sortId", 0),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 按sortId排序
|
||||||
|
items_with_sort.sort(key=lambda x: x["sortId"])
|
||||||
|
|
||||||
|
# 创建最终的refs字典
|
||||||
|
for item in items_with_sort:
|
||||||
|
self.refs[item["name"]] = item["array"]
|
||||||
|
|
||||||
|
print(f"已加载 {len(self.refs)} 个参考图 (按sortId排序)")
|
||||||
|
# 保存训练集图像
|
||||||
|
os.makedirs("训练集", exist_ok=True)
|
||||||
|
for name, array in self.refs.items():
|
||||||
|
os.makedirs(f"训练集/{name}", exist_ok=True)
|
||||||
|
path = os.path.join(f"训练集/{name}", f"{name}.png")
|
||||||
|
im = Image.fromarray(array)
|
||||||
|
cropped_im = im.crop((50, 30, 160, 140))
|
||||||
|
cropped_im.save(path)
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
def _process_circle(self, idx, circle, img):
|
||||||
|
"""处理单个圆形区域,返回彩色图像数据"""
|
||||||
|
x, y, r = circle
|
||||||
|
# 裁剪包含圆形的更大区域
|
||||||
|
crop = img[
|
||||||
|
max(0, y - self.CROP_SIZE) : min(img.shape[0], y + self.CROP_SIZE),
|
||||||
|
max(0, x - self.CROP_SIZE) : min(img.shape[1], x + self.CROP_SIZE),
|
||||||
|
]
|
||||||
|
|
||||||
|
# 提取核心的彩色ROI区域
|
||||||
|
color_roi = crop[self.BORDER : -self.BORDER, self.BORDER : -self.BORDER]
|
||||||
|
|
||||||
|
# 提取用于匹配的彩色区域
|
||||||
|
color_sec = color_roi
|
||||||
|
|
||||||
|
return idx, color_sec, color_roi
|
||||||
|
|
||||||
|
def detect_and_crop(self):
|
||||||
|
"""检测并裁剪截图区域"""
|
||||||
|
img = cv2.imread(self.IMG_PATH)
|
||||||
|
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
||||||
|
circles = cv2.HoughCircles(gray, cv2.HOUGH_GRADIENT, **self.HOUGH_PARAMS)
|
||||||
|
|
||||||
|
# 处理检测到的圆形
|
||||||
|
circles = np.round(circles[0]).astype(int)
|
||||||
|
circles = sorted(circles, key=lambda c: (c[0], c[1])) # 按坐标排序
|
||||||
|
|
||||||
|
self.rois = []
|
||||||
|
for idx, circle in enumerate(circles):
|
||||||
|
result = self._process_circle(idx, circle, img)
|
||||||
|
self.rois.append(result)
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
def 训练并保存knn模型(self, images, labels, filename):
|
||||||
|
"""训练并保存KNN模型"""
|
||||||
|
knn_classifier = KNeighborsClassifier(
|
||||||
|
weights="distance", n_neighbors=1, n_jobs=1
|
||||||
|
)
|
||||||
|
knn_classifier.fit(images, labels)
|
||||||
|
|
||||||
|
with lzma.open(filename, "wb") as f:
|
||||||
|
pickle.dump(knn_classifier, f)
|
||||||
|
|
||||||
|
return knn_classifier
|
||||||
|
|
||||||
|
def 训练knn模型(self, 模型保存路径="depot_knn_model.xz"):
|
||||||
|
"""训练并保存KNN模型(使用彩色图像)"""
|
||||||
|
|
||||||
|
# 准备训练数据
|
||||||
|
images = []
|
||||||
|
labels = []
|
||||||
|
|
||||||
|
for name, img_array in self.refs.items():
|
||||||
|
features = 提取特征点(img_array)
|
||||||
|
images.append(features)
|
||||||
|
labels.append(name)
|
||||||
|
|
||||||
|
# 训练模型
|
||||||
|
self.knn_model = self.训练并保存knn模型(images, labels, 模型保存路径)
|
||||||
|
print(f"KNN模型训练完成,已保存到: {模型保存路径}")
|
||||||
|
return self
|
||||||
|
|
||||||
|
def 使用knn预测(self, 测试图像):
|
||||||
|
features = 提取特征点(测试图像)
|
||||||
|
|
||||||
|
# 预测
|
||||||
|
预测结果 = self.knn_model.predict([features])
|
||||||
|
|
||||||
|
return 预测结果[0]
|
||||||
|
|
||||||
|
def match_items_knn_only(
|
||||||
|
self,
|
||||||
|
knn_model_path="depot_knn_model.xz",
|
||||||
|
):
|
||||||
|
"""仅使用KNN方法进行匹配"""
|
||||||
|
self.knn_results = []
|
||||||
|
newstart = None
|
||||||
|
with lzma.open(knn_model_path, "rb") as f:
|
||||||
|
self.knn_model = pickle.load(f)
|
||||||
|
|
||||||
|
os.makedirs("测试集", exist_ok=True)
|
||||||
|
|
||||||
|
for idx, color_sec_np, _ in self.rois:
|
||||||
|
# KNN预测
|
||||||
|
roi_gray = cv2.cvtColor(color_sec_np, cv2.COLOR_RGB2GRAY)
|
||||||
|
knn_name = self.使用knn预测(roi_gray)
|
||||||
|
|
||||||
|
self.knn_results.append((idx, knn_name))
|
||||||
|
os.makedirs(f"测试集/{knn_name}", exist_ok=True)
|
||||||
|
Image.fromarray(cv2.cvtColor(color_sec_np, cv2.COLOR_BGR2RGB)).crop(
|
||||||
|
(50, 30, 160, 140)
|
||||||
|
).save(os.path.join(f"测试集/{knn_name}", f"{idx}_{knn_name}.png"))
|
||||||
|
# 更新newstart逻辑
|
||||||
|
newstart = knn_name
|
||||||
|
|
||||||
|
print(f"ROI {idx}: Hog+Knn={knn_name}, newstart={newstart}")
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
def display_results(self):
|
||||||
|
"""可视化匹配结果"""
|
||||||
|
ROW_LIMIT = 9
|
||||||
|
|
||||||
|
# 获取一个参考图像的尺寸作为空白图像的基础
|
||||||
|
blank_ref_np = next(iter(self.refs.values()))
|
||||||
|
blank_img_pil = Image.new(
|
||||||
|
"RGB", (blank_ref_np.shape[1], blank_ref_np.shape[0]), (200, 200, 200)
|
||||||
|
)
|
||||||
|
|
||||||
|
combined_images = []
|
||||||
|
current_row_images = []
|
||||||
|
current_row_width = 0
|
||||||
|
max_row_height = 0
|
||||||
|
|
||||||
|
for idx, color_sec_np, color_roi_data in self.rois:
|
||||||
|
color_roi_data = Image.fromarray(
|
||||||
|
cv2.cvtColor(color_roi_data, cv2.COLOR_BGR2RGB)
|
||||||
|
)
|
||||||
|
color_sec_np = Image.fromarray(
|
||||||
|
cv2.cvtColor(color_sec_np, cv2.COLOR_BGR2RGB)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 获取KNN匹配结果
|
||||||
|
k_res_details = next(
|
||||||
|
(d for d in getattr(self, "knn_results", []) if d[0] == idx), None
|
||||||
|
)
|
||||||
|
k_res_name = k_res_details[1] if k_res_details else None
|
||||||
|
k_ref_img = (
|
||||||
|
Image.fromarray(self.refs[k_res_name]).convert("RGB")
|
||||||
|
if k_res_name and k_res_name in self.refs
|
||||||
|
else blank_img_pil.copy()
|
||||||
|
)
|
||||||
|
|
||||||
|
# 计算组合尺寸
|
||||||
|
combined_width = color_roi_data.width + color_sec_np.width + k_ref_img.width
|
||||||
|
|
||||||
|
combined_height = max(
|
||||||
|
color_roi_data.height,
|
||||||
|
color_sec_np.height,
|
||||||
|
k_ref_img.height,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 创建组合图像
|
||||||
|
combined = Image.new(
|
||||||
|
"RGB", (combined_width, combined_height), (255, 255, 255)
|
||||||
|
)
|
||||||
|
x_offset = 0
|
||||||
|
|
||||||
|
# 粘贴各个部分
|
||||||
|
combined.paste(color_roi_data, (x_offset, 0))
|
||||||
|
x_offset += color_roi_data.width
|
||||||
|
|
||||||
|
combined.paste(color_sec_np, (x_offset, 0))
|
||||||
|
x_offset += color_sec_np.width
|
||||||
|
|
||||||
|
combined.paste(k_ref_img, (x_offset, 0))
|
||||||
|
x_offset += k_ref_img.width
|
||||||
|
|
||||||
|
# 添加标注
|
||||||
|
draw = ImageDraw.Draw(combined)
|
||||||
|
font = ImageFont.truetype("msyh.ttc", 16)
|
||||||
|
|
||||||
|
label = f"ROI {idx}\nHog+Knn: {k_res_name or 'None'}"
|
||||||
|
|
||||||
|
text_color = (0, 0, 0)
|
||||||
|
|
||||||
|
draw.text(
|
||||||
|
(color_roi_data.width, color_sec_np.height),
|
||||||
|
label,
|
||||||
|
fill=text_color,
|
||||||
|
font=font,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 添加边框
|
||||||
|
combined_bordered = ImageOps.expand(combined, border=2, fill=(0, 0, 0))
|
||||||
|
current_row_images.append(combined_bordered)
|
||||||
|
current_row_width += combined_bordered.width
|
||||||
|
max_row_height = max(max_row_height, combined_bordered.height)
|
||||||
|
|
||||||
|
# 检查是否需要换行
|
||||||
|
if len(current_row_images) == ROW_LIMIT:
|
||||||
|
row_img = Image.new(
|
||||||
|
"RGB", (current_row_width, max_row_height), (255, 255, 255)
|
||||||
|
)
|
||||||
|
x = 0
|
||||||
|
for img in current_row_images:
|
||||||
|
row_img.paste(img, (x, 0))
|
||||||
|
x += img.width
|
||||||
|
combined_images.append(row_img)
|
||||||
|
current_row_images = []
|
||||||
|
current_row_width = 0
|
||||||
|
max_row_height = 0
|
||||||
|
|
||||||
|
# 处理最后一行
|
||||||
|
if current_row_images:
|
||||||
|
row_img = Image.new(
|
||||||
|
"RGB", (current_row_width, max_row_height), (255, 255, 255)
|
||||||
|
)
|
||||||
|
x = 0
|
||||||
|
for img in current_row_images:
|
||||||
|
row_img.paste(img, (x, 0))
|
||||||
|
x += img.width
|
||||||
|
combined_images.append(row_img)
|
||||||
|
|
||||||
|
# 生成最终图像
|
||||||
|
if combined_images:
|
||||||
|
total_height = sum(img.height for img in combined_images)
|
||||||
|
max_width = max(img.width for img in combined_images)
|
||||||
|
final_img = Image.new("RGB", (max_width, total_height), (255, 255, 255))
|
||||||
|
|
||||||
|
y = 0
|
||||||
|
for img in combined_images:
|
||||||
|
final_img.paste(img, (0, y))
|
||||||
|
y += img.height
|
||||||
|
|
||||||
|
output_path = "depot_test/output/matches_knn_only.png"
|
||||||
|
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
||||||
|
final_img.save(output_path)
|
||||||
|
print(f"结果图像已保存至: {output_path}")
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# 使用示例
|
||||||
|
matcher = DepotMatcher()
|
||||||
|
matcher.load_references()
|
||||||
|
matcher.训练knn模型()
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
now_time = datetime.now()
|
||||||
|
|
||||||
|
matcher.detect_and_crop()
|
||||||
|
matcher.match_items_knn_only()
|
||||||
|
print(datetime.now() - now_time)
|
||||||
|
matcher.display_results()
|
|
@ -1,188 +0,0 @@
|
||||||
import json
|
|
||||||
import os
|
|
||||||
from datetime import datetime
|
|
||||||
from multiprocessing import Pool
|
|
||||||
|
|
||||||
import cv2
|
|
||||||
import numpy as np
|
|
||||||
from PIL import Image, ImageDraw, ImageFont, ImageOps
|
|
||||||
from skimage.metrics import structural_similarity
|
|
||||||
|
|
||||||
now = datetime.now()
|
|
||||||
# 配置路径
|
|
||||||
REF_DIR = r"depot_test\output/test/origin"
|
|
||||||
ROI_DIR = r"depot_test\output/test/result"
|
|
||||||
IMG_PATH = r"depot_test\result_refined.png"
|
|
||||||
REF_TABLE_JSON = "./ArknightsGameData/zh_CN/gamedata/excel/item_table.json"
|
|
||||||
ICON_DIR = "./ArknightsResource/items/"
|
|
||||||
|
|
||||||
# SSIM阈值
|
|
||||||
SSIM_THRESHOLD = 0.01
|
|
||||||
|
|
||||||
# 圆检测参数
|
|
||||||
HOUGH_PARAMS = dict(
|
|
||||||
dp=5, minDist=230, param1=50, param2=30, minRadius=90, maxRadius=100
|
|
||||||
)
|
|
||||||
CROP_SIZE = 130
|
|
||||||
BORDER = 26
|
|
||||||
SECONDARY_SLICE = (slice(30, 140), slice(50, 160))
|
|
||||||
|
|
||||||
|
|
||||||
def load_references(table_json, icon_dir):
|
|
||||||
data = json.load(open(table_json, encoding="utf-8"))
|
|
||||||
refs = {}
|
|
||||||
size = CROP_SIZE * 2 - BORDER * 2
|
|
||||||
for item in data.get("items", {}).values():
|
|
||||||
t = item.get("classifyType")
|
|
||||||
if t not in {"NORMAL", "CONSUME", "MATERIAL"}:
|
|
||||||
continue
|
|
||||||
path = os.path.join(icon_dir, f"{item['iconId']}.png")
|
|
||||||
if not os.path.exists(path):
|
|
||||||
continue
|
|
||||||
im = Image.open(path).resize((size, size)).crop((50, 30, 160, 140)).convert("L")
|
|
||||||
refs[item["name"]] = np.array(im)
|
|
||||||
print(f"已加载 {len(refs)} 个参考图,保存于 {REF_DIR}")
|
|
||||||
return refs
|
|
||||||
|
|
||||||
|
|
||||||
def process_circle(idx, circle, img, rois, size, dr):
|
|
||||||
x, y, r = circle
|
|
||||||
crop = img[
|
|
||||||
max(0, y - CROP_SIZE) : min(img.shape[0], y + CROP_SIZE),
|
|
||||||
max(0, x - CROP_SIZE) : min(img.shape[1], x + CROP_SIZE),
|
|
||||||
]
|
|
||||||
c = crop[BORDER:-BORDER, BORDER:-BORDER]
|
|
||||||
sec = c[SECONDARY_SLICE[0], SECONDARY_SLICE[1]]
|
|
||||||
gray_sec = cv2.cvtColor(sec, cv2.COLOR_BGR2GRAY)
|
|
||||||
|
|
||||||
# 保存 ROI 图像
|
|
||||||
os.makedirs(ROI_DIR, exist_ok=True)
|
|
||||||
roi_path = os.path.join(ROI_DIR, f"roi_{idx}.png")
|
|
||||||
cv2.imwrite(roi_path, gray_sec)
|
|
||||||
rois.append(gray_sec)
|
|
||||||
return idx, gray_sec, roi_path
|
|
||||||
|
|
||||||
|
|
||||||
def detect_and_crop(image_path):
|
|
||||||
img = cv2.imread(image_path)
|
|
||||||
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
|
||||||
circles = cv2.HoughCircles(gray, cv2.HOUGH_GRADIENT, **HOUGH_PARAMS)
|
|
||||||
if circles is None:
|
|
||||||
print("未检测到圆形区域")
|
|
||||||
return []
|
|
||||||
circles = np.round(circles[0]).astype(int)
|
|
||||||
rois = []
|
|
||||||
size = CROP_SIZE * 2 - BORDER * 2
|
|
||||||
dr = img.max() - img.min()
|
|
||||||
|
|
||||||
# 使用多进程加速
|
|
||||||
with Pool() as pool:
|
|
||||||
results = pool.starmap(
|
|
||||||
process_circle,
|
|
||||||
[(idx, circle, img, rois, size, dr) for idx, circle in enumerate(circles)],
|
|
||||||
)
|
|
||||||
|
|
||||||
return [roi for idx, roi, path in results]
|
|
||||||
|
|
||||||
|
|
||||||
def process_match(idx, roi, refs, thresh):
|
|
||||||
best, score = "Unknown", -1
|
|
||||||
dr = roi.max() - roi.min()
|
|
||||||
for name, ref in refs.items():
|
|
||||||
if roi.shape != ref.shape:
|
|
||||||
continue
|
|
||||||
s = structural_similarity(roi, ref, data_range=dr)
|
|
||||||
if s > score:
|
|
||||||
best, score = name, s
|
|
||||||
if score >= thresh:
|
|
||||||
return idx, best, score
|
|
||||||
return idx, None, score
|
|
||||||
|
|
||||||
|
|
||||||
def match_ssim(rois, refs, thresh=SSIM_THRESHOLD):
|
|
||||||
from multiprocessing import Pool
|
|
||||||
|
|
||||||
args = [(idx, roi, refs, thresh) for idx, roi in enumerate(rois)]
|
|
||||||
with Pool(processes=5) as pool:
|
|
||||||
results = pool.starmap(process_match, args)
|
|
||||||
|
|
||||||
stats = {}
|
|
||||||
match_idx = {}
|
|
||||||
for idx, name, score in results:
|
|
||||||
if name:
|
|
||||||
stats[name] = stats.get(name, 0) + 1
|
|
||||||
match_idx[idx] = name
|
|
||||||
print(f"ROI {idx} 匹配结果: {name if name else 'Unknown'} (SSIM={score:.3f})")
|
|
||||||
return stats, match_idx
|
|
||||||
|
|
||||||
|
|
||||||
def display_matches(rois, match_idx, refs):
|
|
||||||
ROW_LIMIT = 10 # 每行最多10个
|
|
||||||
blank_ref = next(iter(refs.values())) # 取一个参考图的尺寸
|
|
||||||
blank_img = Image.new(
|
|
||||||
"RGB", (blank_ref.shape[1], blank_ref.shape[0]), (200, 200, 200)
|
|
||||||
) # 灰色占位图
|
|
||||||
|
|
||||||
combined_images = []
|
|
||||||
row_images = []
|
|
||||||
|
|
||||||
row_width = 0
|
|
||||||
max_height = 0
|
|
||||||
|
|
||||||
for idx in range(len(rois)):
|
|
||||||
roi_img = Image.fromarray(rois[idx]).convert("RGB")
|
|
||||||
ref_name = match_idx.get(idx)
|
|
||||||
if ref_name:
|
|
||||||
ref_img = Image.fromarray(refs[ref_name]).convert("RGB")
|
|
||||||
else:
|
|
||||||
ref_img = blank_img.copy()
|
|
||||||
|
|
||||||
combined_width = roi_img.width + ref_img.width
|
|
||||||
combined_height = max(roi_img.height, ref_img.height)
|
|
||||||
|
|
||||||
combined = Image.new("RGB", (combined_width, combined_height), (255, 255, 255))
|
|
||||||
combined.paste(roi_img, (0, 0))
|
|
||||||
combined.paste(ref_img, (roi_img.width, 0))
|
|
||||||
|
|
||||||
draw = ImageDraw.Draw(combined)
|
|
||||||
|
|
||||||
font = ImageFont.truetype("msyh.ttc", 20)
|
|
||||||
|
|
||||||
label = f"ROI {idx}: {ref_name if ref_name else 'Unknown'}"
|
|
||||||
draw.text((5, 5), label, fill=(255, 0, 0), font=font)
|
|
||||||
combined = ImageOps.expand(combined, border=2, fill=(0, 0, 0))
|
|
||||||
row_images.append(combined)
|
|
||||||
row_width += combined_width
|
|
||||||
max_height = max(max_height, combined_height)
|
|
||||||
|
|
||||||
if len(row_images) == ROW_LIMIT or idx == len(rois) - 1:
|
|
||||||
row_img = Image.new("RGB", (row_width, max_height), (255, 255, 255))
|
|
||||||
x_offset = 0
|
|
||||||
for img in row_images:
|
|
||||||
row_img.paste(img, (x_offset, 0))
|
|
||||||
x_offset += img.width
|
|
||||||
combined_images.append(row_img)
|
|
||||||
row_images = []
|
|
||||||
row_width = 0
|
|
||||||
max_height = 0
|
|
||||||
|
|
||||||
total_width = max(img.width for img in combined_images)
|
|
||||||
total_height = sum(img.height for img in combined_images)
|
|
||||||
|
|
||||||
final_img = Image.new("RGB", (total_width, total_height), (255, 255, 255))
|
|
||||||
y_offset = 0
|
|
||||||
for img in combined_images:
|
|
||||||
final_img.paste(img, (0, y_offset))
|
|
||||||
y_offset += img.height
|
|
||||||
|
|
||||||
final_img.save("all_matches.png") # 或 final_img.save("all_matches.png")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
refs = load_references(REF_TABLE_JSON, ICON_DIR)
|
|
||||||
rois = detect_and_crop(IMG_PATH)
|
|
||||||
res, match_idx = match_ssim(rois, refs)
|
|
||||||
|
|
||||||
print("最终识别结果:", res)
|
|
||||||
display_matches(rois, match_idx, refs)
|
|
||||||
print(datetime.now() - now)
|
|
|
@ -1,309 +0,0 @@
|
||||||
import json
|
|
||||||
import os
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
import cv2
|
|
||||||
import numpy as np
|
|
||||||
from PIL import Image, ImageDraw, ImageFont, ImageOps
|
|
||||||
from skimage.metrics import structural_similarity as ssim
|
|
||||||
|
|
||||||
# 配置路径
|
|
||||||
REF_DIR = r"depot_test\output/test/origin"
|
|
||||||
ROI_DIR = r"depot_test\output/test/result"
|
|
||||||
IMG_PATH = r"depot_test\output\result_template.png"
|
|
||||||
REF_TABLE_JSON = "./ArknightsGameData/zh_CN/gamedata/excel/item_table.json"
|
|
||||||
ICON_DIR = "./ArknightsResource/items/"
|
|
||||||
|
|
||||||
# 参数
|
|
||||||
HOUGH_PARAMS = dict(
|
|
||||||
dp=5, minDist=230, param1=50, param2=30, minRadius=90, maxRadius=100
|
|
||||||
)
|
|
||||||
CROP_SIZE = 130
|
|
||||||
BORDER = 26
|
|
||||||
SECONDARY_SLICE = (slice(30, 140), slice(50, 160))
|
|
||||||
|
|
||||||
|
|
||||||
def load_references(table_json, icon_dir):
|
|
||||||
data = json.load(open(table_json, encoding="utf-8"))
|
|
||||||
refs = {}
|
|
||||||
size = CROP_SIZE * 2 - BORDER * 2
|
|
||||||
for item in data.get("items", {}).values():
|
|
||||||
if item.get("classifyType") not in {"NORMAL", "CONSUME", "MATERIAL"}:
|
|
||||||
continue
|
|
||||||
path = os.path.join(icon_dir, f"{item['iconId']}.png")
|
|
||||||
if not os.path.exists(path):
|
|
||||||
continue
|
|
||||||
im = Image.open(path).resize((size, size)).crop((50, 30, 160, 140)).convert("L")
|
|
||||||
refs[item["name"]] = np.array(im)
|
|
||||||
print(f"已加载 {len(refs)} 个参考图")
|
|
||||||
return refs
|
|
||||||
|
|
||||||
|
|
||||||
def process_circle(idx, circle, img):
|
|
||||||
x, y, r = circle
|
|
||||||
crop = img[
|
|
||||||
max(0, y - CROP_SIZE) : min(img.shape[0], y + CROP_SIZE),
|
|
||||||
max(0, x - CROP_SIZE) : min(img.shape[1], x + CROP_SIZE),
|
|
||||||
]
|
|
||||||
# Save color ROI
|
|
||||||
os.makedirs(ROI_DIR, exist_ok=True)
|
|
||||||
color_roi_path = os.path.join(ROI_DIR, f"color_roi_{idx}.png")
|
|
||||||
cv2.imwrite(color_roi_path, crop[BORDER:-BORDER, BORDER:-BORDER])
|
|
||||||
|
|
||||||
c = crop[BORDER:-BORDER, BORDER:-BORDER]
|
|
||||||
sec = c[SECONDARY_SLICE[0], SECONDARY_SLICE[1]]
|
|
||||||
gray_sec = cv2.cvtColor(sec, cv2.COLOR_BGR2GRAY)
|
|
||||||
roi_path = os.path.join(ROI_DIR, f"roi_{idx}.png")
|
|
||||||
cv2.imwrite(roi_path, gray_sec)
|
|
||||||
return idx, gray_sec, roi_path, color_roi_path
|
|
||||||
|
|
||||||
|
|
||||||
def detect_and_crop(image_path):
|
|
||||||
img = cv2.imread(image_path)
|
|
||||||
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
|
||||||
circles = cv2.HoughCircles(gray, cv2.HOUGH_GRADIENT, **HOUGH_PARAMS)
|
|
||||||
if circles is None:
|
|
||||||
print("未检测到圆形区域")
|
|
||||||
return []
|
|
||||||
|
|
||||||
# Convert and sort circles by y then x coordinate
|
|
||||||
circles = np.round(circles[0]).astype(int)
|
|
||||||
circles = sorted(circles, key=lambda c: (c[0], c[1])) # Sort by y then x
|
|
||||||
|
|
||||||
results = []
|
|
||||||
for idx, circle in enumerate(circles):
|
|
||||||
result = process_circle(idx, circle, img)
|
|
||||||
results.append(result)
|
|
||||||
|
|
||||||
return results
|
|
||||||
|
|
||||||
|
|
||||||
def match_template(rois, refs, thresh):
|
|
||||||
results = []
|
|
||||||
for idx, roi, _, _ in rois:
|
|
||||||
best, max_val = "Unknown", -1.0
|
|
||||||
roi_f = roi.astype(np.float32) / 255.0
|
|
||||||
for name, ref in refs.items():
|
|
||||||
ref_f = ref.astype(np.float32) / 255.0
|
|
||||||
if roi_f.shape != ref_f.shape:
|
|
||||||
continue
|
|
||||||
res = cv2.matchTemplate(roi_f, ref_f, cv2.TM_CCOEFF_NORMED)
|
|
||||||
val = float(res.max())
|
|
||||||
if val > max_val:
|
|
||||||
best, max_val = name, val
|
|
||||||
if max_val >= thresh:
|
|
||||||
results.append((idx, best, max_val))
|
|
||||||
else:
|
|
||||||
results.append((idx, None, max_val))
|
|
||||||
|
|
||||||
return results
|
|
||||||
|
|
||||||
|
|
||||||
def match_ssim(rois, refs, thresh):
|
|
||||||
results = []
|
|
||||||
|
|
||||||
for idx, roi, _, _ in rois:
|
|
||||||
best_match = "Unknown"
|
|
||||||
max_combined_score = -1.0
|
|
||||||
best_ssim = 0
|
|
||||||
best_hist = 0
|
|
||||||
best_edge = 0
|
|
||||||
|
|
||||||
# 预处理ROI
|
|
||||||
roi_gray = roi if len(roi.shape) == 2 else cv2.cvtColor(roi, cv2.COLOR_BGR2GRAY)
|
|
||||||
roi_edges = cv2.Canny(roi_gray, 50, 150)
|
|
||||||
roi_hist = cv2.calcHist([roi_gray], [0], None, [256], [0, 256])
|
|
||||||
cv2.normalize(roi_hist, roi_hist, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX)
|
|
||||||
|
|
||||||
for name, ref in refs.items():
|
|
||||||
if roi_gray.shape != ref.shape:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 1. 计算SSIM相似度
|
|
||||||
ssim_score, _ = ssim(roi_gray, ref, full=True)
|
|
||||||
|
|
||||||
# 2. 计算直方图相似度
|
|
||||||
ref_hist = cv2.calcHist([ref], [0], None, [256], [0, 256])
|
|
||||||
cv2.normalize(
|
|
||||||
ref_hist, ref_hist, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX
|
|
||||||
)
|
|
||||||
hist_score = cv2.compareHist(roi_hist, ref_hist, cv2.HISTCMP_CORREL)
|
|
||||||
|
|
||||||
# 3. 计算边缘相似度
|
|
||||||
ref_edges = cv2.Canny(ref, 50, 150)
|
|
||||||
edge_intersection = np.sum(roi_edges * ref_edges)
|
|
||||||
edge_union = np.sum(roi_edges + ref_edges)
|
|
||||||
edge_score = edge_intersection / edge_union if edge_union > 0 else 0
|
|
||||||
|
|
||||||
# 加权综合评分 (可调整权重)
|
|
||||||
combined_score = 0.6 * ssim_score + 0.2 * hist_score + 0.2 * edge_score
|
|
||||||
|
|
||||||
if combined_score > max_combined_score:
|
|
||||||
best_match = name
|
|
||||||
max_combined_score = combined_score
|
|
||||||
best_ssim = ssim_score
|
|
||||||
best_hist = hist_score
|
|
||||||
best_edge = edge_score
|
|
||||||
|
|
||||||
# 动态阈值调整 (基于图像复杂度)
|
|
||||||
roi_complexity = np.std(roi_gray) / 255.0
|
|
||||||
dynamic_thresh = thresh * (1 + 0.3 * roi_complexity)
|
|
||||||
|
|
||||||
if max_combined_score >= dynamic_thresh:
|
|
||||||
results.append(
|
|
||||||
(idx, best_match, max_combined_score, best_ssim, best_hist, best_edge)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
results.append(
|
|
||||||
(idx, None, max_combined_score, best_ssim, best_hist, best_edge)
|
|
||||||
)
|
|
||||||
|
|
||||||
return results
|
|
||||||
|
|
||||||
|
|
||||||
def display_matches(rois, template_results, ssim_results, refs):
|
|
||||||
ROW_LIMIT = 9
|
|
||||||
blank_ref = next(iter(refs.values())) # 取一个参考图的尺寸
|
|
||||||
blank_img = Image.new(
|
|
||||||
"RGB", (blank_ref.shape[1], blank_ref.shape[0]), (200, 200, 200)
|
|
||||||
) # 灰色占位图
|
|
||||||
|
|
||||||
combined_images = []
|
|
||||||
row_images = []
|
|
||||||
|
|
||||||
row_width = 0
|
|
||||||
max_height = 0
|
|
||||||
|
|
||||||
for item in rois:
|
|
||||||
idx, _, _, color_path = item
|
|
||||||
# Load color ROI image
|
|
||||||
roi_img = Image.open(color_path).convert("RGB")
|
|
||||||
|
|
||||||
# Template matching result
|
|
||||||
t_res = next((name for i, name, val in template_results if i == idx), None)
|
|
||||||
t_val = next((val for i, name, val in template_results if i == idx), 0)
|
|
||||||
if t_res is not None:
|
|
||||||
t_ref_img = Image.fromarray(refs[t_res]).convert("RGB")
|
|
||||||
else:
|
|
||||||
t_ref_img = blank_img.copy()
|
|
||||||
|
|
||||||
# SSIM result - we need to get the detailed scores
|
|
||||||
s_res = next((name for i, name, val in ssim_results if i == idx), None)
|
|
||||||
s_val = next((val for i, name, val in ssim_results if i == idx), 0)
|
|
||||||
if s_res is not None:
|
|
||||||
s_ref_img = Image.fromarray(refs[s_res]).convert("RGB")
|
|
||||||
else:
|
|
||||||
s_ref_img = blank_img.copy()
|
|
||||||
|
|
||||||
# Combine Template Matching result (left) and SSIM result (right)
|
|
||||||
combined_width = roi_img.width + t_ref_img.width + s_ref_img.width
|
|
||||||
combined_height = max(roi_img.height, t_ref_img.height, s_ref_img.height)
|
|
||||||
|
|
||||||
combined = Image.new("RGB", (combined_width, combined_height), (255, 255, 255))
|
|
||||||
combined.paste(roi_img, (0, 0))
|
|
||||||
combined.paste(t_ref_img, (roi_img.width, 0))
|
|
||||||
combined.paste(s_ref_img, (roi_img.width + t_ref_img.width, 0))
|
|
||||||
|
|
||||||
draw = ImageDraw.Draw(combined)
|
|
||||||
font = ImageFont.truetype("msyh.ttc", 20)
|
|
||||||
|
|
||||||
# Get the detailed scores from the SSIM matching results
|
|
||||||
ssim_details = next(
|
|
||||||
(details for i, details in enumerate(ssim_results) if i == idx),
|
|
||||||
(idx, None, 0, 0, 0, 0), # Default values if not found
|
|
||||||
)
|
|
||||||
best_ssim = ssim_details[3] if len(ssim_details) > 3 else 0
|
|
||||||
best_hist = ssim_details[4] if len(ssim_details) > 4 else 0
|
|
||||||
best_edge = ssim_details[5] if len(ssim_details) > 5 else 0
|
|
||||||
|
|
||||||
label = (
|
|
||||||
f"ROI {idx} {best_ssim:.3f}{best_hist:.3f} {best_edge:.3f}\n"
|
|
||||||
f"T({t_res if t_res else 'None'}, {t_val:.3f})\n"
|
|
||||||
f"S({s_res if s_res else 'None'}, {s_val:.3f})"
|
|
||||||
)
|
|
||||||
|
|
||||||
if t_res == s_res:
|
|
||||||
draw.text(
|
|
||||||
(roi_img.width, t_ref_img.height), label, fill=(255, 0, 0), font=font
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
draw.text(
|
|
||||||
(roi_img.width, t_ref_img.height), label, fill=(255, 0, 255), font=font
|
|
||||||
)
|
|
||||||
combined = ImageOps.expand(combined, border=2, fill=(0, 0, 0))
|
|
||||||
row_images.append(combined)
|
|
||||||
row_width += combined_width
|
|
||||||
max_height = max(max_height, combined_height)
|
|
||||||
|
|
||||||
if len(row_images) == ROW_LIMIT or idx == len(rois) - 1:
|
|
||||||
row_img = Image.new("RGB", (row_width, max_height), (255, 255, 255))
|
|
||||||
x_offset = 0
|
|
||||||
for img in row_images:
|
|
||||||
row_img.paste(img, (x_offset, 0))
|
|
||||||
x_offset += img.width
|
|
||||||
combined_images.append(row_img)
|
|
||||||
row_images = []
|
|
||||||
row_width = 0
|
|
||||||
max_height = 0
|
|
||||||
|
|
||||||
total_width = max(img.width for img in combined_images)
|
|
||||||
total_height = sum(img.height for img in combined_images)
|
|
||||||
|
|
||||||
final_img = Image.new("RGB", (total_width, total_height), (255, 255, 255))
|
|
||||||
y_offset = 0
|
|
||||||
for img in combined_images:
|
|
||||||
final_img.paste(img, (0, y_offset))
|
|
||||||
y_offset += img.height
|
|
||||||
|
|
||||||
final_img.save("depot_test/output/matches_all.png")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
refs = load_references(REF_TABLE_JSON, ICON_DIR)
|
|
||||||
rois = detect_and_crop(IMG_PATH)
|
|
||||||
|
|
||||||
now = datetime.now()
|
|
||||||
|
|
||||||
Template_MATCH_THRESHOLD = 0.2
|
|
||||||
print("\n=== 使用 Template Matching ===")
|
|
||||||
results_template = match_template(rois, refs, Template_MATCH_THRESHOLD)
|
|
||||||
print("\n模版匹配耗时:", datetime.now() - now)
|
|
||||||
SSIM_MATCH_THRESHOLD = 0.05
|
|
||||||
|
|
||||||
now = datetime.now()
|
|
||||||
print("\n=== 使用 SSIM ===")
|
|
||||||
results_ssim = match_ssim(rois, refs, SSIM_MATCH_THRESHOLD)
|
|
||||||
print("\nSSIM匹配耗时:", datetime.now() - now)
|
|
||||||
|
|
||||||
print("\n=== 结果对比 ===")
|
|
||||||
for idx, _, _, _ in rois:
|
|
||||||
# Template matching results
|
|
||||||
t_res = next((name for i, name, val in results_template if i == idx), None)
|
|
||||||
t_val = next((val for i, name, val in results_template if i == idx), 0)
|
|
||||||
|
|
||||||
# SSIM results - now includes detailed metrics
|
|
||||||
s_res = next(
|
|
||||||
(name for i, name, val, ssim, hist, edge in results_ssim if i == idx), None
|
|
||||||
)
|
|
||||||
s_val = next(
|
|
||||||
(val for i, name, val, ssim, hist, edge in results_ssim if i == idx), 0
|
|
||||||
)
|
|
||||||
s_ssim = next(
|
|
||||||
(ssim for i, name, val, ssim, hist, edge in results_ssim if i == idx), 0
|
|
||||||
)
|
|
||||||
s_hist = next(
|
|
||||||
(hist for i, name, val, ssim, hist, edge in results_ssim if i == idx), 0
|
|
||||||
)
|
|
||||||
s_edge = next(
|
|
||||||
(edge for i, name, val, ssim, hist, edge in results_ssim if i == idx), 0
|
|
||||||
)
|
|
||||||
|
|
||||||
print(
|
|
||||||
f"ROI {idx}:\n"
|
|
||||||
f" Template=({t_res if t_res else 'None'}, {t_val:.3f})\n"
|
|
||||||
f" SSIM=({s_res if s_res else 'None'}, {s_val:.3f})\n"
|
|
||||||
f" Details: SSIM={s_ssim:.3f}, Hist={s_hist:.3f}, Edge={s_edge:.3f}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Displaying matches side by side using the updated function
|
|
||||||
display_matches(rois, results_template, results_ssim, refs)
|
|
|
@ -1,192 +0,0 @@
|
||||||
import json
|
|
||||||
import os
|
|
||||||
from datetime import datetime
|
|
||||||
from multiprocessing import Pool
|
|
||||||
|
|
||||||
import cv2
|
|
||||||
import numpy as np
|
|
||||||
from PIL import Image, ImageDraw, ImageFont, ImageOps
|
|
||||||
|
|
||||||
now = datetime.now()
|
|
||||||
# 配置路径
|
|
||||||
REF_DIR = r"depot_test\output/test/origin"
|
|
||||||
ROI_DIR = r"depot_test\output/test/result"
|
|
||||||
IMG_PATH = r"depot_test\stitched_image_multi.png"
|
|
||||||
REF_TABLE_JSON = "./ArknightsGameData/zh_CN/gamedata/excel/item_table.json"
|
|
||||||
ICON_DIR = "./ArknightsResource/items/"
|
|
||||||
|
|
||||||
# SSIM阈值
|
|
||||||
MATCH_THRESHOLD = 0.01
|
|
||||||
|
|
||||||
# 圆检测参数
|
|
||||||
HOUGH_PARAMS = dict(
|
|
||||||
dp=5, minDist=230, param1=50, param2=30, minRadius=90, maxRadius=100
|
|
||||||
)
|
|
||||||
CROP_SIZE = 130
|
|
||||||
BORDER = 26
|
|
||||||
SECONDARY_SLICE = (slice(30, 140), slice(50, 160))
|
|
||||||
|
|
||||||
|
|
||||||
def load_references(table_json, icon_dir):
|
|
||||||
data = json.load(open(table_json, encoding="utf-8"))
|
|
||||||
refs = {}
|
|
||||||
size = CROP_SIZE * 2 - BORDER * 2
|
|
||||||
for item in data.get("items", {}).values():
|
|
||||||
t = item.get("classifyType")
|
|
||||||
if t not in {"NORMAL", "CONSUME", "MATERIAL"}:
|
|
||||||
continue
|
|
||||||
path = os.path.join(icon_dir, f"{item['iconId']}.png")
|
|
||||||
if not os.path.exists(path):
|
|
||||||
continue
|
|
||||||
im = Image.open(path).resize((size, size)).crop((50, 30, 160, 140)).convert("L")
|
|
||||||
refs[item["name"]] = np.array(im)
|
|
||||||
print(f"已加载 {len(refs)} 个参考图,保存于 {REF_DIR}")
|
|
||||||
return refs
|
|
||||||
|
|
||||||
|
|
||||||
def process_circle(idx, circle, img, rois, size, dr):
|
|
||||||
x, y, r = circle
|
|
||||||
crop = img[
|
|
||||||
max(0, y - CROP_SIZE) : min(img.shape[0], y + CROP_SIZE),
|
|
||||||
max(0, x - CROP_SIZE) : min(img.shape[1], x + CROP_SIZE),
|
|
||||||
]
|
|
||||||
c = crop[BORDER:-BORDER, BORDER:-BORDER]
|
|
||||||
sec = c[SECONDARY_SLICE[0], SECONDARY_SLICE[1]]
|
|
||||||
gray_sec = cv2.cvtColor(sec, cv2.COLOR_BGR2GRAY)
|
|
||||||
|
|
||||||
# 保存 ROI 图像
|
|
||||||
os.makedirs(ROI_DIR, exist_ok=True)
|
|
||||||
roi_path = os.path.join(ROI_DIR, f"roi_{idx}.png")
|
|
||||||
cv2.imwrite(roi_path, gray_sec)
|
|
||||||
rois.append(gray_sec)
|
|
||||||
return idx, gray_sec, roi_path
|
|
||||||
|
|
||||||
|
|
||||||
def detect_and_crop(image_path):
|
|
||||||
img = cv2.imread(image_path)
|
|
||||||
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
|
||||||
circles = cv2.HoughCircles(gray, cv2.HOUGH_GRADIENT, **HOUGH_PARAMS)
|
|
||||||
if circles is None:
|
|
||||||
print("未检测到圆形区域")
|
|
||||||
return []
|
|
||||||
circles = np.round(circles[0]).astype(int)
|
|
||||||
rois = []
|
|
||||||
size = CROP_SIZE * 2 - BORDER * 2
|
|
||||||
dr = img.max() - img.min()
|
|
||||||
|
|
||||||
# 使用多进程加速
|
|
||||||
with Pool() as pool:
|
|
||||||
results = pool.starmap(
|
|
||||||
process_circle,
|
|
||||||
[(idx, circle, img, rois, size, dr) for idx, circle in enumerate(circles)],
|
|
||||||
)
|
|
||||||
|
|
||||||
return [roi for idx, roi, path in results]
|
|
||||||
|
|
||||||
|
|
||||||
def process_match(idx, roi, refs, thresh):
|
|
||||||
best, max_val = "Unknown", -1.0
|
|
||||||
# 模板匹配需要浮点图
|
|
||||||
roi_f = roi.astype(np.float32) / 255.0
|
|
||||||
for name, ref in refs.items():
|
|
||||||
# ref 已经是灰度 NumPy 数组
|
|
||||||
ref_f = ref.astype(np.float32) / 255.0
|
|
||||||
if roi_f.shape != ref_f.shape:
|
|
||||||
continue
|
|
||||||
# 使用 TM_CCOEFF_NORMED,结果越接近 1 越匹配
|
|
||||||
res = cv2.matchTemplate(roi_f, ref_f, cv2.TM_CCOEFF_NORMED)
|
|
||||||
val = float(res.max())
|
|
||||||
if val > max_val:
|
|
||||||
best, max_val = name, val
|
|
||||||
if max_val >= thresh:
|
|
||||||
return idx, best, max_val
|
|
||||||
return idx, None, max_val
|
|
||||||
|
|
||||||
|
|
||||||
def match_ssim(rois, refs, thresh=MATCH_THRESHOLD):
|
|
||||||
from multiprocessing import Pool
|
|
||||||
|
|
||||||
args = [(idx, roi, refs, thresh) for idx, roi in enumerate(rois)]
|
|
||||||
with Pool(processes=5) as pool:
|
|
||||||
results = pool.starmap(process_match, args)
|
|
||||||
|
|
||||||
stats = {}
|
|
||||||
match_idx = {}
|
|
||||||
for idx, name, score in results:
|
|
||||||
if name:
|
|
||||||
stats[name] = stats.get(name, 0) + 1
|
|
||||||
match_idx[idx] = name
|
|
||||||
print(f"ROI {idx} 匹配结果: {name if name else 'Unknown'} (SSIM={score:.3f})")
|
|
||||||
return stats, match_idx
|
|
||||||
|
|
||||||
|
|
||||||
def display_matches(rois, match_idx, refs):
|
|
||||||
ROW_LIMIT = 10 # 每行最多10个
|
|
||||||
blank_ref = next(iter(refs.values())) # 取一个参考图的尺寸
|
|
||||||
blank_img = Image.new(
|
|
||||||
"RGB", (blank_ref.shape[1], blank_ref.shape[0]), (200, 200, 200)
|
|
||||||
) # 灰色占位图
|
|
||||||
|
|
||||||
combined_images = []
|
|
||||||
row_images = []
|
|
||||||
|
|
||||||
row_width = 0
|
|
||||||
max_height = 0
|
|
||||||
|
|
||||||
for idx in range(len(rois)):
|
|
||||||
roi_img = Image.fromarray(rois[idx]).convert("RGB")
|
|
||||||
ref_name = match_idx.get(idx)
|
|
||||||
if ref_name:
|
|
||||||
ref_img = Image.fromarray(refs[ref_name]).convert("RGB")
|
|
||||||
else:
|
|
||||||
ref_img = blank_img.copy()
|
|
||||||
|
|
||||||
combined_width = roi_img.width + ref_img.width
|
|
||||||
combined_height = max(roi_img.height, ref_img.height)
|
|
||||||
|
|
||||||
combined = Image.new("RGB", (combined_width, combined_height), (255, 255, 255))
|
|
||||||
combined.paste(roi_img, (0, 0))
|
|
||||||
combined.paste(ref_img, (roi_img.width, 0))
|
|
||||||
|
|
||||||
draw = ImageDraw.Draw(combined)
|
|
||||||
|
|
||||||
font = ImageFont.truetype("msyh.ttc", 20)
|
|
||||||
|
|
||||||
label = f"ROI {idx}: {ref_name if ref_name else 'Unknown'}"
|
|
||||||
draw.text((5, 5), label, fill=(255, 0, 0), font=font)
|
|
||||||
combined = ImageOps.expand(combined, border=2, fill=(0, 0, 0))
|
|
||||||
row_images.append(combined)
|
|
||||||
row_width += combined_width
|
|
||||||
max_height = max(max_height, combined_height)
|
|
||||||
|
|
||||||
if len(row_images) == ROW_LIMIT or idx == len(rois) - 1:
|
|
||||||
row_img = Image.new("RGB", (row_width, max_height), (255, 255, 255))
|
|
||||||
x_offset = 0
|
|
||||||
for img in row_images:
|
|
||||||
row_img.paste(img, (x_offset, 0))
|
|
||||||
x_offset += img.width
|
|
||||||
combined_images.append(row_img)
|
|
||||||
row_images = []
|
|
||||||
row_width = 0
|
|
||||||
max_height = 0
|
|
||||||
|
|
||||||
total_width = max(img.width for img in combined_images)
|
|
||||||
total_height = sum(img.height for img in combined_images)
|
|
||||||
|
|
||||||
final_img = Image.new("RGB", (total_width, total_height), (255, 255, 255))
|
|
||||||
y_offset = 0
|
|
||||||
for img in combined_images:
|
|
||||||
final_img.paste(img, (0, y_offset))
|
|
||||||
y_offset += img.height
|
|
||||||
|
|
||||||
final_img.save("all_matches.png") # 或 final_img.save("all_matches.png")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
refs = load_references(REF_TABLE_JSON, ICON_DIR)
|
|
||||||
rois = detect_and_crop(IMG_PATH)
|
|
||||||
res, match_idx = match_ssim(rois, refs)
|
|
||||||
|
|
||||||
print("最终识别结果:", res)
|
|
||||||
display_matches(rois, match_idx, refs)
|
|
||||||
print(datetime.now() - now)
|
|
130
depot_test/训练.py
Normal file
130
depot_test/训练.py
Normal file
|
@ -0,0 +1,130 @@
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torchvision.transforms as transforms
|
||||||
|
from PIL import Image
|
||||||
|
from torch.utils.data import DataLoader, Dataset
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
# 1. 网络定义
|
||||||
|
class SiameseNetwork(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(SiameseNetwork, self).__init__()
|
||||||
|
self.cnn = nn.Sequential(
|
||||||
|
nn.Conv2d(3, 32, kernel_size=5), # (3,110,110) -> (32,106,106)
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.MaxPool2d(2), # -> (32,53,53)
|
||||||
|
nn.Conv2d(32, 64, kernel_size=5), # -> (64,49,49)
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.MaxPool2d(2), # -> (64,24,24)
|
||||||
|
)
|
||||||
|
self.fc = nn.Sequential(
|
||||||
|
nn.Linear(64 * 24 * 24, 512), nn.ReLU(), nn.Linear(512, 128)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward_once(self, x):
|
||||||
|
x = self.cnn(x)
|
||||||
|
x = x.view(x.size(0), -1)
|
||||||
|
return self.fc(x)
|
||||||
|
|
||||||
|
def forward(self, input1, input2):
|
||||||
|
output1 = self.forward_once(input1)
|
||||||
|
output2 = self.forward_once(input2)
|
||||||
|
return output1, output2
|
||||||
|
|
||||||
|
|
||||||
|
# 2. Contrastive Loss
|
||||||
|
class ContrastiveLoss(nn.Module):
|
||||||
|
def __init__(self, margin=2.0):
|
||||||
|
super(ContrastiveLoss, self).__init__()
|
||||||
|
self.margin = margin
|
||||||
|
|
||||||
|
def forward(self, out1, out2, label):
|
||||||
|
dist = F.pairwise_distance(out1, out2)
|
||||||
|
loss = label * torch.pow(dist, 2) + (1 - label) * torch.pow(
|
||||||
|
torch.clamp(self.margin - dist, min=0.0), 2
|
||||||
|
)
|
||||||
|
return loss.mean()
|
||||||
|
|
||||||
|
|
||||||
|
# 3. Dataset 构造器
|
||||||
|
class SiameseDataset(Dataset):
|
||||||
|
def __init__(self, folder_path, transform=None):
|
||||||
|
self.folder_path = folder_path
|
||||||
|
self.classes = os.listdir(folder_path)
|
||||||
|
self.transform = transform or transforms.ToTensor()
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
class1 = random.choice(self.classes)
|
||||||
|
class2 = (
|
||||||
|
class1
|
||||||
|
if random.random() < 0.5
|
||||||
|
else random.choice([c for c in self.classes if c != class1])
|
||||||
|
)
|
||||||
|
|
||||||
|
img1_path = os.path.join(
|
||||||
|
self.folder_path,
|
||||||
|
class1,
|
||||||
|
random.choice(os.listdir(os.path.join(self.folder_path, class1))),
|
||||||
|
)
|
||||||
|
img2_path = os.path.join(
|
||||||
|
self.folder_path,
|
||||||
|
class2,
|
||||||
|
random.choice(os.listdir(os.path.join(self.folder_path, class2))),
|
||||||
|
)
|
||||||
|
|
||||||
|
img1 = Image.open(img1_path).convert("RGB").resize((110, 110))
|
||||||
|
img2 = Image.open(img2_path).convert("RGB").resize((110, 110))
|
||||||
|
|
||||||
|
if self.transform:
|
||||||
|
img1 = self.transform(img1)
|
||||||
|
img2 = self.transform(img2)
|
||||||
|
|
||||||
|
label = 1.0 if class1 == class2 else 0.0
|
||||||
|
return img1, img2, torch.tensor([label], dtype=torch.float32)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return 5000
|
||||||
|
|
||||||
|
|
||||||
|
# 4. 训练入口
|
||||||
|
def train():
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
print(device)
|
||||||
|
transform = transforms.Compose(
|
||||||
|
[
|
||||||
|
transforms.ToTensor(),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
dataset = SiameseDataset("dataset/train", transform=transform)
|
||||||
|
dataloader = DataLoader(dataset, batch_size=128, shuffle=True)
|
||||||
|
|
||||||
|
model = SiameseNetwork().to(device)
|
||||||
|
criterion = ContrastiveLoss()
|
||||||
|
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
|
||||||
|
|
||||||
|
for epoch in range(100):
|
||||||
|
total_loss = 0
|
||||||
|
for img1, img2, label in tqdm(dataloader, desc=f"Epoch {epoch + 1}"):
|
||||||
|
img1, img2, label = img1.to(device), img2.to(device), label.to(device)
|
||||||
|
|
||||||
|
out1, out2 = model(img1, img2)
|
||||||
|
loss = criterion(out1, out2, label)
|
||||||
|
optimizer.zero_grad()
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
total_loss += loss.item()
|
||||||
|
|
||||||
|
print(f"Epoch {epoch + 1}, Loss: {total_loss / len(dataloader):.4f}")
|
||||||
|
|
||||||
|
torch.save(model.state_dict(), "siamese_model.pth")
|
||||||
|
print("Model saved as siamese_model.pth")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
train()
|
74
depot_test/预测.py
Normal file
74
depot_test/预测.py
Normal file
|
@ -0,0 +1,74 @@
|
||||||
|
import os
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from PIL import Image
|
||||||
|
from torchvision import transforms
|
||||||
|
from 训练 import SiameseNetwork # 确保和训练脚本在同一目录
|
||||||
|
|
||||||
|
|
||||||
|
# 加载模型
|
||||||
|
def load_model(model_path, device):
|
||||||
|
model = SiameseNetwork().to(device)
|
||||||
|
model.load_state_dict(torch.load(model_path, map_location=device))
|
||||||
|
model.eval()
|
||||||
|
return model
|
||||||
|
|
||||||
|
# 读取训练集中每个类别的图像(作为支持集)
|
||||||
|
def load_train_class_images(train_dir, transform, device):
|
||||||
|
class_images = {}
|
||||||
|
for class_name in os.listdir(train_dir):
|
||||||
|
class_path = os.path.join(train_dir, class_name)
|
||||||
|
if not os.path.isdir(class_path):
|
||||||
|
continue
|
||||||
|
images = []
|
||||||
|
for img_name in os.listdir(class_path):
|
||||||
|
img_path = os.path.join(class_path, img_name)
|
||||||
|
image = Image.open(img_path).convert("RGB").resize((110, 110))
|
||||||
|
image = transform(image).unsqueeze(0).to(device) # shape: (1,3,110,110)
|
||||||
|
images.append(image)
|
||||||
|
if images:
|
||||||
|
class_images[class_name] = images
|
||||||
|
return class_images
|
||||||
|
|
||||||
|
# 推理函数:返回最相似的类别名
|
||||||
|
def predict(model, test_img, class_images):
|
||||||
|
min_dist = float("inf")
|
||||||
|
predicted_class = None
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for class_name, ref_images in class_images.items():
|
||||||
|
for ref_img in ref_images:
|
||||||
|
out1, out2 = model(test_img, ref_img)
|
||||||
|
dist = F.pairwise_distance(out1, out2)
|
||||||
|
if dist.item() < min_dist:
|
||||||
|
min_dist = dist.item()
|
||||||
|
predicted_class = class_name
|
||||||
|
return predicted_class, min_dist
|
||||||
|
|
||||||
|
# 主推理流程
|
||||||
|
def infer():
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
model = load_model("siamese_model.pth", device)
|
||||||
|
|
||||||
|
transform = transforms.Compose([
|
||||||
|
transforms.ToTensor(),
|
||||||
|
])
|
||||||
|
|
||||||
|
train_dir = "dataset/train"
|
||||||
|
test_dir = "dataset/test"
|
||||||
|
class_images = load_train_class_images(train_dir, transform, device)
|
||||||
|
|
||||||
|
print("开始测试...")
|
||||||
|
for class_name in os.listdir(test_dir):
|
||||||
|
class_path = os.path.join(test_dir, class_name)
|
||||||
|
for img_name in os.listdir(class_path):
|
||||||
|
img_path = os.path.join(class_path, img_name)
|
||||||
|
img = Image.open(img_path).convert("RGB").resize((110, 110))
|
||||||
|
img_tensor = transform(img).unsqueeze(0).to(device)
|
||||||
|
|
||||||
|
predicted_class, dist = predict(model, img_tensor, class_images)
|
||||||
|
print(f"Test Image: {img_name} | True: {class_name} | Predicted: {predicted_class} | Distance: {dist:.4f}")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
infer()
|
Loading…
Add table
Add a link
Reference in a new issue