Upload some test files
This commit is contained in:
parent
cee44523c6
commit
44782a3117
10 changed files with 912 additions and 689 deletions
BIN
depot_test/output/matches_knn_only.png
(Stored with Git LFS)
Normal file
Binary file not shown.
BIN
depot_test/stitched_image_multi.png
(Stored with Git LFS)
Normal file
Binary file not shown.
324
depot_test/仓库识别MobileNetV3 KNN copy.py
Normal file
@ -0,0 +1,324 @@
import glob
import os
import re  # for parsing labels out of file names

import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms
from torchvision.datasets import ImageFolder
from tqdm import tqdm  # progress bars

# --- Configuration ---
TRAIN_DIR = "训练集"
VAL_TEST_DIR = "测试集"  # validation and test are the same directory here, with the same file format
IMAGE_SIZE = 224  # standard MobileNetV3 input size
BATCH_SIZE = 32
NUM_EPOCHS = 20  # adjust as needed
LEARNING_RATE = 0.001  # initial learning rate
SAVE_MODEL_PATH = "mobilenetv3_small_finetuned.pth"

# Auto-detect the device (fall back to CPU when CUDA is unavailable)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# --- Preprocessing and augmentation ---
# ImageNet mean and standard deviation
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# Training-set augmentation and preprocessing
train_transforms = transforms.Compose(
    [
        transforms.RandomResizedCrop(IMAGE_SIZE),  # random crop and rescale
        transforms.RandomHorizontalFlip(),  # random horizontal flip
        transforms.ToTensor(),  # convert to a tensor
        transforms.Normalize(mean, std),  # normalize
    ]
)

# Validation/test preprocessing (no augmentation; just resize, center-crop, normalize)
val_test_transforms = transforms.Compose(
    [
        transforms.Resize(int(IMAGE_SIZE * 256 / 224)),  # resize to a slightly larger size
        transforms.CenterCrop(IMAGE_SIZE),  # center-crop to the target size
        transforms.ToTensor(),  # convert to a tensor
        transforms.Normalize(mean, std),  # normalize
    ]
)


# --- Custom validation/test Dataset ---
# A custom Dataset is needed for the "测试集/{idx}_{label}.png" naming scheme
class CustomValTestDataset(Dataset):
    def __init__(self, root_dir, class_to_idx, transform=None):
        """
        Args:
            root_dir (string): dataset root directory (e.g. "测试集").
            class_to_idx (dict): class-name-to-index mapping, consistent with the training set.
            transform (callable, optional): transform applied to each image.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.class_to_idx = class_to_idx
        self.idx_to_class = {v: k for k, v in class_to_idx.items()}
        self.image_files = []
        self.labels = []

        # Collect every PNG under the directory
        filepaths = glob.glob(os.path.join(root_dir, "*.png"))

        # Parse file names to extract labels
        pattern = re.compile(r"^\d+_([^_]+)\.png$")  # matches number_label.png

        for filepath in filepaths:
            filename = os.path.basename(filepath)
            match = pattern.match(filename)
            if match:
                label_name = match.group(1)
                if label_name in self.class_to_idx:
                    self.image_files.append(filepath)
                    self.labels.append(self.class_to_idx[label_name])
                else:
                    print(
                        f"Warning: label '{label_name}' in file '{filename}' is not a training-set class; skipping."
                    )

        print(f"Loaded {len(self.image_files)} validation/test images.")

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_path = self.image_files[idx]
        label = self.labels[idx]

        # Open the image and force RGB (handles possible grayscale inputs)
        img = Image.open(img_path).convert("RGB")

        if self.transform:
            img = self.transform(img)

        return img, label


# --- Load the data ---
# ImageFolder derives the class list from the training directory names
if not os.path.exists(TRAIN_DIR):
    print(
        f"Error: training directory '{TRAIN_DIR}' does not exist. Create it and add one subdirectory of images per class."
    )
    exit()

if not os.path.exists(VAL_TEST_DIR):
    print(f"Error: validation/test directory '{VAL_TEST_DIR}' does not exist. Create it and add images.")
    exit()


train_dataset = ImageFolder(TRAIN_DIR, transform=train_transforms)
num_classes = len(train_dataset.classes)
class_to_idx = train_dataset.class_to_idx  # class-to-index mapping

print(f"Detected {num_classes} classes in the training set: {train_dataset.classes}")

# Load the validation/test set with the custom Dataset
val_test_dataset = CustomValTestDataset(
    VAL_TEST_DIR, class_to_idx, transform=val_test_transforms
)

# Create the DataLoaders
train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True
)  # tune num_workers to your machine
val_test_loader = DataLoader(val_test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# --- Load the pretrained MobileNetV3-Small model ---
# The weights argument selects the pretrained weights;
# MobileNet_V3_Small_Weights.IMAGENET1K_V1 was trained on ImageNet
try:
    model = models.mobilenet_v3_small(
        weights=models.MobileNet_V3_Small_Weights.IMAGENET1K_V1
    )
    print("Loaded pretrained MobileNetV3-Small (ImageNet weights).")
except Exception as e:
    print(f"Failed to load pretrained model: {e}")
    print("Falling back to an uninitialized model...")
    model = models.mobilenet_v3_small(weights=None)


# --- Replace the final fully connected layer to match the class count ---
# MobileNetV3's classifier head is model.classifier; the last linear layer is classifier[-1]
num_ftrs = model.classifier[-1].in_features
# Swap out the original fully connected layer
model.classifier[-1] = nn.Linear(num_ftrs, num_classes)

model = model.to(device)

# --- Loss and optimizer ---
criterion = nn.CrossEntropyLoss()  # cross-entropy for classification
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)  # Adam optimizer

# Optional: learning-rate scheduler
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)  # decay every 7 epochs

# --- Training and evaluation functions ---


def train_epoch(model, train_loader, criterion, optimizer, device):
    model.train()  # training mode
    running_loss = 0.0
    correct_predictions = 0
    total_samples = 0

    # tqdm progress bar
    for inputs, labels in tqdm(train_loader, desc="Training"):
        inputs, labels = inputs.to(device), labels.to(device)

        # Zero the gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        # Backward pass and optimizer step
        loss.backward()
        optimizer.step()

        # Statistics
        running_loss += loss.item() * inputs.size(0)  # accumulate batch loss * batch size
        _, predicted = torch.max(outputs, 1)  # predicted class indices
        correct_predictions += (predicted == labels).sum().item()
        total_samples += labels.size(0)

    epoch_loss = running_loss / total_samples
    epoch_accuracy = correct_predictions / total_samples
    return epoch_loss, epoch_accuracy


def evaluate(model, data_loader, criterion, device, desc="Evaluating"):
    model.eval()  # evaluation mode
    running_loss = 0.0
    correct_predictions = 0
    total_samples = 0

    # No gradients during evaluation
    with torch.no_grad():
        # tqdm progress bar
        for inputs, labels in tqdm(data_loader, desc=desc):
            inputs, labels = inputs.to(device), labels.to(device)

            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # Statistics
            running_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs, 1)
            correct_predictions += (predicted == labels).sum().item()
            total_samples += labels.size(0)

    epoch_loss = running_loss / total_samples
    epoch_accuracy = correct_predictions / total_samples
    return epoch_loss, epoch_accuracy


# --- Training loop ---
best_val_accuracy = 0.0

print("\nStarting training...")
for epoch in range(NUM_EPOCHS):
    print(f"\n--- Epoch {epoch + 1}/{NUM_EPOCHS} ---")

    # Training phase
    train_loss, train_accuracy = train_epoch(
        model, train_loader, criterion, optimizer, device
    )
    print(
        f"Epoch {epoch + 1} train loss: {train_loss:.4f}, accuracy: {train_accuracy:.4f}"
    )

    # Optional learning-rate schedule
    # if scheduler is not None:
    #     scheduler.step()

    # Validation/test phase
    val_loss, val_accuracy = evaluate(
        model, val_test_loader, criterion, device, desc="Validating/testing"
    )
    print(
        f"Epoch {epoch + 1} val/test loss: {val_loss:.4f}, accuracy: {val_accuracy:.4f}"
    )

    # Save the best model
    if val_accuracy > best_val_accuracy:
        best_val_accuracy = val_accuracy
        torch.save(model.state_dict(), SAVE_MODEL_PATH)
        print(
            f"Saved best model so far (accuracy: {best_val_accuracy:.4f}) to {SAVE_MODEL_PATH}"
        )

print("\nTraining finished!")
print(f"Best accuracy on the validation/test set: {best_val_accuracy:.4f}")

# --- Optional: reload the best model for a final evaluation ---
# If validation and test share one directory, this is the final result
print(f"\nLoading best model '{SAVE_MODEL_PATH}' for the final evaluation...")
loaded_model = models.mobilenet_v3_small(weights=None)  # bare architecture first
loaded_model.classifier[-1] = nn.Linear(
    loaded_model.classifier[-1].in_features, num_classes
)  # swap the classifier head
loaded_model.load_state_dict(
    torch.load(SAVE_MODEL_PATH, map_location=device)
)  # load the weights
loaded_model = loaded_model.to(device)

final_test_loss, final_test_accuracy = evaluate(
    loaded_model, val_test_loader, criterion, device, desc="Final test"
)
print(
    f"\nFinal test loss: {final_test_loss:.4f}, final test accuracy: {final_test_accuracy:.4f}"
)


# --- Optional: predict a single image ---


# Suppose you want to predict an image such as '测试集/some_image_X_label.png'
def predict_single_image(image_path, model, class_to_idx, device, transform):
    model.eval()
    idx_to_class = {v: k for k, v in class_to_idx.items()}

    try:
        img = Image.open(image_path).convert("RGB")
        img = transform(img).unsqueeze(0).to(device)  # add a batch dimension and move to the device

        with torch.no_grad():
            outputs = model(img)
            probabilities = torch.softmax(outputs, dim=1)[0]  # probability distribution
            _, predicted_idx = torch.max(probabilities, 0)  # index of the top probability
            predicted_label = idx_to_class[predicted_idx.item()]
            confidence = probabilities[predicted_idx].item()

        print(f"\nPredicting image: {image_path}")
        print(f"Predicted class: {predicted_label}, confidence: {confidence:.4f}")
        return predicted_label, confidence

    except FileNotFoundError:
        print(f"Error: image file '{image_path}' not found.")
        return None, None
    except Exception as e:
        print(f"Error while predicting image: {e}")
        return None, None


# Example prediction
example_image_path = r"测试集\27_基础作战记录.png"  # replace with an actual file from your test set
predict_single_image(
    example_image_path, loaded_model, class_to_idx, device, val_test_transforms
)
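Editor's note: a quick standalone check of the `{idx}_{label}.png` naming convention that CustomValTestDataset above expects; the regex and the sample file name are both taken from the script itself:

    import re

    pattern = re.compile(r"^\d+_([^_]+)\.png$")  # same pattern as in the dataset class
    print(pattern.match("27_基础作战记录.png").group(1))  # prints: 基础作战记录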
378
depot_test/仓库识别MobileNetV3 KNN.py
Normal file
@ -0,0 +1,378 @@
import json
import lzma
import os
import pickle

import cv2
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image, ImageDraw, ImageFont, ImageOps
from sklearn.neighbors import KNeighborsClassifier
from torchvision import models

CROP_SIZE = 130
BORDER = 26
size = CROP_SIZE * 2 - BORDER * 2


# Feature extractor: MobileNetV3-Small with the final classification layer
# removed, so the network outputs an embedding instead of class logits
model = models.mobilenet_v3_small(weights="DEFAULT")

features_part = model.features
avgpool = torch.nn.AdaptiveAvgPool2d(1)
classifier_part_excluding_last = torch.nn.Sequential(
    *list(model.classifier.children())[:-1]
)

feature_extractor = torch.nn.Sequential(
    features_part,
    avgpool,
    torch.nn.Flatten(start_dim=1),
    classifier_part_excluding_last,
)
feature_extractor.eval()  # switch to evaluation mode


def 提取特征点(模板):
    """Extract features with MobileNetV3 (PyTorch version)."""
    # Convert the input image from BGR to RGB
    img_rgb = cv2.cvtColor(模板, cv2.COLOR_BGR2RGB)

    # Preprocessing pipeline
    preprocess = transforms.Compose(
        [
            transforms.ToPILImage(),  # convert to a PIL image
            transforms.Resize(250),  # resize the shorter side to 250
            transforms.ToTensor(),  # convert to a tensor
            transforms.Normalize(  # normalize with ImageNet statistics
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
            ),
        ]
    )

    # Preprocess the image
    img_tensor = preprocess(img_rgb)
    img_tensor = img_tensor.unsqueeze(0)  # add a batch dimension

    # Extract features
    with torch.no_grad():
        features = feature_extractor(img_tensor)

    # Flatten the features to one dimension
    features = features.flatten().numpy()

    return features


class DepotMatcher:
    def __init__(
        self,
        ref_table_json="./ArknightsGameData/zh_CN/gamedata/excel/item_table.json",
        icon_dir="./ArknightsResource/items/",
        ref_dir="depot_test/output/test/origin",
        roi_dir="depot_test/output/test/result",
        img_path=r"depot_test\stitched_image_multi.png",
    ):
        # Path configuration
        self.REF_DIR = ref_dir
        self.ROI_DIR = roi_dir
        self.IMG_PATH = img_path
        self.REF_TABLE_JSON = ref_table_json
        self.ICON_DIR = icon_dir

        # Algorithm parameters
        self.HOUGH_PARAMS = dict(
            dp=5, minDist=230, param1=50, param2=30, minRadius=90, maxRadius=100
        )
        self.CROP_SIZE = 130
        self.BORDER = 26

        # Runtime state
        self.refs = None
        self.rois = []
        self.knn_results = []
        self.knn_model = None

    def load_references(self):
        """Load item-icon reference images (kept in color)."""
        data = json.load(open(self.REF_TABLE_JSON, encoding="utf-8"))
        self.refs = {}
        size = self.CROP_SIZE * 2 - self.BORDER * 2

        # First collect all items that carry a sortId
        items_with_sort = []
        for item in data.get("items", {}).values():
            if item.get("classifyType") not in {"NORMAL", "CONSUME", "MATERIAL"}:
                continue

            path = os.path.join(self.ICON_DIR, f"{item['iconId']}.png")
            if not os.path.exists(path):
                continue

            # Keep the image in color; force RGB so icons with an alpha
            # channel always yield 3-channel arrays downstream
            im = Image.open(path).convert("RGB").resize((size, size))
            items_with_sort.append(
                {
                    "name": item["name"],
                    "array": np.array(im),
                    "sortId": item.get("sortId", 0),
                }
            )

        # Sort by sortId
        items_with_sort.sort(key=lambda x: x["sortId"])

        # Build the final refs dict
        for item in items_with_sort:
            self.refs[item["name"]] = item["array"]

        print(f"Loaded {len(self.refs)} reference images (sorted by sortId)")
        # Save the training-set images
        os.makedirs("训练集", exist_ok=True)
        for name, array in self.refs.items():
            os.makedirs(f"训练集/{name}", exist_ok=True)
            path = os.path.join(f"训练集/{name}", f"{name}.png")
            im = Image.fromarray(array)
            cropped_im = im.crop((50, 30, 160, 140))
            cropped_im.save(path)

        return self

    def _process_circle(self, idx, circle, img):
        """Process one detected circle and return its color image data."""
        x, y, r = circle
        # Crop a larger square region containing the circle
        crop = img[
            max(0, y - self.CROP_SIZE) : min(img.shape[0], y + self.CROP_SIZE),
            max(0, x - self.CROP_SIZE) : min(img.shape[1], x + self.CROP_SIZE),
        ]

        # Trim the border to get the core color ROI
        color_roi = crop[self.BORDER : -self.BORDER, self.BORDER : -self.BORDER]

        # Color region used for matching
        color_sec = color_roi

        return idx, color_sec, color_roi

    def detect_and_crop(self):
        """Detect item circles in the screenshot and crop them."""
        img = cv2.imread(self.IMG_PATH)
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        circles = cv2.HoughCircles(gray, cv2.HOUGH_GRADIENT, **self.HOUGH_PARAMS)

        # Process the detected circles
        circles = np.round(circles[0]).astype(int)
        circles = sorted(circles, key=lambda c: (c[0], c[1]))  # sort by x, then y

        self.rois = []
        for idx, circle in enumerate(circles):
            result = self._process_circle(idx, circle, img)
            self.rois.append(result)

        return self

    def 训练并保存knn模型(self, images, labels, filename):
        """Train a KNN model and save it to disk."""
        knn_classifier = KNeighborsClassifier(
            weights="distance", n_neighbors=1, n_jobs=1
        )
        knn_classifier.fit(images, labels)

        with lzma.open(filename, "wb") as f:
            pickle.dump(knn_classifier, f)

        return knn_classifier

    def 训练knn模型(self, 模型保存路径="depot_knn_model.xz"):
        """Train and save the KNN model (on the color reference images)."""

        # Prepare the training data
        images = []
        labels = []

        for name, img_array in self.refs.items():
            features = 提取特征点(img_array)
            images.append(features)
            labels.append(name)

        # Train the model
        self.knn_model = self.训练并保存knn模型(images, labels, 模型保存路径)
        print(f"KNN model trained; saved to: {模型保存路径}")
        return self

    def 使用knn预测(self, 测试图像):
        features = 提取特征点(测试图像)

        # Predict
        预测结果 = self.knn_model.predict([features])

        return 预测结果[0]

    def match_items_knn_only(
        self,
        knn_model_path="depot_knn_model.xz",
    ):
        """Match items using the KNN method only."""
        self.knn_results = []
        newstart = None
        with lzma.open(knn_model_path, "rb") as f:
            self.knn_model = pickle.load(f)

        os.makedirs("测试集", exist_ok=True)

        for idx, color_sec_np, _ in self.rois:
            # KNN prediction. 提取特征点 expects a 3-channel BGR image, so the
            # color ROI is passed directly; converting to grayscale first would
            # make the cv2.cvtColor(..., COLOR_BGR2RGB) inside it fail.
            knn_name = self.使用knn预测(color_sec_np)

            self.knn_results.append((idx, knn_name))
            os.makedirs(f"测试集/{knn_name}", exist_ok=True)
            Image.fromarray(cv2.cvtColor(color_sec_np, cv2.COLOR_BGR2RGB)).crop(
                (50, 30, 160, 140)
            ).save(os.path.join(f"测试集/{knn_name}", f"{idx}_{knn_name}.png"))
            # Update newstart
            newstart = knn_name

            print(f"ROI {idx}: Hog+Knn={knn_name}, newstart={newstart}")

        return self

    def display_results(self):
        """Visualize the match results."""
        ROW_LIMIT = 9

        # Use one reference image's size as the basis for the blank placeholder
        blank_ref_np = next(iter(self.refs.values()))
        blank_img_pil = Image.new(
            "RGB", (blank_ref_np.shape[1], blank_ref_np.shape[0]), (200, 200, 200)
        )

        combined_images = []
        current_row_images = []
        current_row_width = 0
        max_row_height = 0

        for idx, color_sec_np, color_roi_data in self.rois:
            color_roi_data = Image.fromarray(
                cv2.cvtColor(color_roi_data, cv2.COLOR_BGR2RGB)
            )
            color_sec_np = Image.fromarray(
                cv2.cvtColor(color_sec_np, cv2.COLOR_BGR2RGB)
            )

            # Fetch the KNN match for this ROI
            k_res_details = next(
                (d for d in getattr(self, "knn_results", []) if d[0] == idx), None
            )
            k_res_name = k_res_details[1] if k_res_details else None
            k_ref_img = (
                Image.fromarray(self.refs[k_res_name]).convert("RGB")
                if k_res_name and k_res_name in self.refs
                else blank_img_pil.copy()
            )

            # Compute the combined size
            combined_width = color_roi_data.width + color_sec_np.width + k_ref_img.width

            combined_height = max(
                color_roi_data.height,
                color_sec_np.height,
                k_ref_img.height,
            )

            # Create the combined image
            combined = Image.new(
                "RGB", (combined_width, combined_height), (255, 255, 255)
            )
            x_offset = 0

            # Paste each part
            combined.paste(color_roi_data, (x_offset, 0))
            x_offset += color_roi_data.width

            combined.paste(color_sec_np, (x_offset, 0))
            x_offset += color_sec_np.width

            combined.paste(k_ref_img, (x_offset, 0))
            x_offset += k_ref_img.width

            # Add the label
            draw = ImageDraw.Draw(combined)
            font = ImageFont.truetype("msyh.ttc", 16)

            label = f"ROI {idx}\nHog+Knn: {k_res_name or 'None'}"

            text_color = (0, 0, 0)

            draw.text(
                (color_roi_data.width, color_sec_np.height),
                label,
                fill=text_color,
                font=font,
            )

            # Add a border
            combined_bordered = ImageOps.expand(combined, border=2, fill=(0, 0, 0))
            current_row_images.append(combined_bordered)
            current_row_width += combined_bordered.width
            max_row_height = max(max_row_height, combined_bordered.height)

            # Start a new row when the limit is reached
            if len(current_row_images) == ROW_LIMIT:
                row_img = Image.new(
                    "RGB", (current_row_width, max_row_height), (255, 255, 255)
                )
                x = 0
                for img in current_row_images:
                    row_img.paste(img, (x, 0))
                    x += img.width
                combined_images.append(row_img)
                current_row_images = []
                current_row_width = 0
                max_row_height = 0

        # Handle the last, partially filled row
        if current_row_images:
            row_img = Image.new(
                "RGB", (current_row_width, max_row_height), (255, 255, 255)
            )
            x = 0
            for img in current_row_images:
                row_img.paste(img, (x, 0))
                x += img.width
            combined_images.append(row_img)

        # Assemble the final image
        if combined_images:
            total_height = sum(img.height for img in combined_images)
            max_width = max(img.width for img in combined_images)
            final_img = Image.new("RGB", (max_width, total_height), (255, 255, 255))

            y = 0
            for img in combined_images:
                final_img.paste(img, (0, y))
                y += img.height

            output_path = "depot_test/output/matches_knn_only.png"
            os.makedirs(os.path.dirname(output_path), exist_ok=True)
            final_img.save(output_path)
            print(f"Result image saved to: {output_path}")

        return self


if __name__ == "__main__":
    # Usage example
    matcher = DepotMatcher()
    matcher.load_references()
    matcher.训练knn模型()
    from datetime import datetime

    now_time = datetime.now()

    matcher.detect_and_crop()
    matcher.match_items_knn_only()
    print(datetime.now() - now_time)
    matcher.display_results()
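Editor's note: a minimal sketch of reusing the saved classifier outside DepotMatcher, assuming the script above has already produced depot_knn_model.xz and that 提取特征点 is in scope (e.g. the snippet is appended to this file); the crop path is hypothetical:

    import lzma
    import pickle

    import cv2

    with lzma.open("depot_knn_model.xz", "rb") as f:
        knn = pickle.load(f)  # KNeighborsClassifier trained on MobileNetV3 embeddings

    crop = cv2.imread("depot_test/output/test/result/roi_0.png")  # hypothetical crop
    print(knn.predict([提取特征点(crop)])[0])  # predicted item name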
@ -1,188 +0,0 @@
import json
import os
from datetime import datetime
from multiprocessing import Pool

import cv2
import numpy as np
from PIL import Image, ImageDraw, ImageFont, ImageOps
from skimage.metrics import structural_similarity

now = datetime.now()
# Path configuration
REF_DIR = r"depot_test\output/test/origin"
ROI_DIR = r"depot_test\output/test/result"
IMG_PATH = r"depot_test\result_refined.png"
REF_TABLE_JSON = "./ArknightsGameData/zh_CN/gamedata/excel/item_table.json"
ICON_DIR = "./ArknightsResource/items/"

# SSIM threshold
SSIM_THRESHOLD = 0.01

# Circle-detection parameters
HOUGH_PARAMS = dict(
    dp=5, minDist=230, param1=50, param2=30, minRadius=90, maxRadius=100
)
CROP_SIZE = 130
BORDER = 26
SECONDARY_SLICE = (slice(30, 140), slice(50, 160))


def load_references(table_json, icon_dir):
    data = json.load(open(table_json, encoding="utf-8"))
    refs = {}
    size = CROP_SIZE * 2 - BORDER * 2
    for item in data.get("items", {}).values():
        t = item.get("classifyType")
        if t not in {"NORMAL", "CONSUME", "MATERIAL"}:
            continue
        path = os.path.join(icon_dir, f"{item['iconId']}.png")
        if not os.path.exists(path):
            continue
        im = Image.open(path).resize((size, size)).crop((50, 30, 160, 140)).convert("L")
        refs[item["name"]] = np.array(im)
    print(f"Loaded {len(refs)} reference images, stored under {REF_DIR}")
    return refs


def process_circle(idx, circle, img, rois, size, dr):
    x, y, r = circle
    crop = img[
        max(0, y - CROP_SIZE) : min(img.shape[0], y + CROP_SIZE),
        max(0, x - CROP_SIZE) : min(img.shape[1], x + CROP_SIZE),
    ]
    c = crop[BORDER:-BORDER, BORDER:-BORDER]
    sec = c[SECONDARY_SLICE[0], SECONDARY_SLICE[1]]
    gray_sec = cv2.cvtColor(sec, cv2.COLOR_BGR2GRAY)

    # Save the ROI image
    os.makedirs(ROI_DIR, exist_ok=True)
    roi_path = os.path.join(ROI_DIR, f"roi_{idx}.png")
    cv2.imwrite(roi_path, gray_sec)
    rois.append(gray_sec)
    return idx, gray_sec, roi_path


def detect_and_crop(image_path):
    img = cv2.imread(image_path)
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    circles = cv2.HoughCircles(gray, cv2.HOUGH_GRADIENT, **HOUGH_PARAMS)
    if circles is None:
        print("No circular regions detected")
        return []
    circles = np.round(circles[0]).astype(int)
    rois = []
    size = CROP_SIZE * 2 - BORDER * 2
    dr = img.max() - img.min()

    # Speed things up with multiple processes
    with Pool() as pool:
        results = pool.starmap(
            process_circle,
            [(idx, circle, img, rois, size, dr) for idx, circle in enumerate(circles)],
        )

    return [roi for idx, roi, path in results]


def process_match(idx, roi, refs, thresh):
    best, score = "Unknown", -1
    dr = roi.max() - roi.min()
    for name, ref in refs.items():
        if roi.shape != ref.shape:
            continue
        s = structural_similarity(roi, ref, data_range=dr)
        if s > score:
            best, score = name, s
    if score >= thresh:
        return idx, best, score
    return idx, None, score


def match_ssim(rois, refs, thresh=SSIM_THRESHOLD):
    from multiprocessing import Pool

    args = [(idx, roi, refs, thresh) for idx, roi in enumerate(rois)]
    with Pool(processes=5) as pool:
        results = pool.starmap(process_match, args)

    stats = {}
    match_idx = {}
    for idx, name, score in results:
        if name:
            stats[name] = stats.get(name, 0) + 1
            match_idx[idx] = name
        print(f"ROI {idx} match: {name if name else 'Unknown'} (SSIM={score:.3f})")
    return stats, match_idx


def display_matches(rois, match_idx, refs):
    ROW_LIMIT = 10  # at most 10 per row
    blank_ref = next(iter(refs.values()))  # use one reference image's size
    blank_img = Image.new(
        "RGB", (blank_ref.shape[1], blank_ref.shape[0]), (200, 200, 200)
    )  # gray placeholder

    combined_images = []
    row_images = []

    row_width = 0
    max_height = 0

    for idx in range(len(rois)):
        roi_img = Image.fromarray(rois[idx]).convert("RGB")
        ref_name = match_idx.get(idx)
        if ref_name:
            ref_img = Image.fromarray(refs[ref_name]).convert("RGB")
        else:
            ref_img = blank_img.copy()

        combined_width = roi_img.width + ref_img.width
        combined_height = max(roi_img.height, ref_img.height)

        combined = Image.new("RGB", (combined_width, combined_height), (255, 255, 255))
        combined.paste(roi_img, (0, 0))
        combined.paste(ref_img, (roi_img.width, 0))

        draw = ImageDraw.Draw(combined)

        font = ImageFont.truetype("msyh.ttc", 20)

        label = f"ROI {idx}: {ref_name if ref_name else 'Unknown'}"
        draw.text((5, 5), label, fill=(255, 0, 0), font=font)
        combined = ImageOps.expand(combined, border=2, fill=(0, 0, 0))
        row_images.append(combined)
        row_width += combined_width
        max_height = max(max_height, combined_height)

        if len(row_images) == ROW_LIMIT or idx == len(rois) - 1:
            row_img = Image.new("RGB", (row_width, max_height), (255, 255, 255))
            x_offset = 0
            for img in row_images:
                row_img.paste(img, (x_offset, 0))
                x_offset += img.width
            combined_images.append(row_img)
            row_images = []
            row_width = 0
            max_height = 0

    total_width = max(img.width for img in combined_images)
    total_height = sum(img.height for img in combined_images)

    final_img = Image.new("RGB", (total_width, total_height), (255, 255, 255))
    y_offset = 0
    for img in combined_images:
        final_img.paste(img, (0, y_offset))
        y_offset += img.height

    final_img.save("all_matches.png")


if __name__ == "__main__":
    refs = load_references(REF_TABLE_JSON, ICON_DIR)
    rois = detect_and_crop(IMG_PATH)
    res, match_idx = match_ssim(rois, refs)

    print("Final recognition results:", res)
    display_matches(rois, match_idx, refs)
    print(datetime.now() - now)
@ -1,309 +0,0 @@
import json
import os
from datetime import datetime

import cv2
import numpy as np
from PIL import Image, ImageDraw, ImageFont, ImageOps
from skimage.metrics import structural_similarity as ssim

# Path configuration
REF_DIR = r"depot_test\output/test/origin"
ROI_DIR = r"depot_test\output/test/result"
IMG_PATH = r"depot_test\output\result_template.png"
REF_TABLE_JSON = "./ArknightsGameData/zh_CN/gamedata/excel/item_table.json"
ICON_DIR = "./ArknightsResource/items/"

# Parameters
HOUGH_PARAMS = dict(
    dp=5, minDist=230, param1=50, param2=30, minRadius=90, maxRadius=100
)
CROP_SIZE = 130
BORDER = 26
SECONDARY_SLICE = (slice(30, 140), slice(50, 160))


def load_references(table_json, icon_dir):
    data = json.load(open(table_json, encoding="utf-8"))
    refs = {}
    size = CROP_SIZE * 2 - BORDER * 2
    for item in data.get("items", {}).values():
        if item.get("classifyType") not in {"NORMAL", "CONSUME", "MATERIAL"}:
            continue
        path = os.path.join(icon_dir, f"{item['iconId']}.png")
        if not os.path.exists(path):
            continue
        im = Image.open(path).resize((size, size)).crop((50, 30, 160, 140)).convert("L")
        refs[item["name"]] = np.array(im)
    print(f"Loaded {len(refs)} reference images")
    return refs


def process_circle(idx, circle, img):
    x, y, r = circle
    crop = img[
        max(0, y - CROP_SIZE) : min(img.shape[0], y + CROP_SIZE),
        max(0, x - CROP_SIZE) : min(img.shape[1], x + CROP_SIZE),
    ]
    # Save the color ROI
    os.makedirs(ROI_DIR, exist_ok=True)
    color_roi_path = os.path.join(ROI_DIR, f"color_roi_{idx}.png")
    cv2.imwrite(color_roi_path, crop[BORDER:-BORDER, BORDER:-BORDER])

    c = crop[BORDER:-BORDER, BORDER:-BORDER]
    sec = c[SECONDARY_SLICE[0], SECONDARY_SLICE[1]]
    gray_sec = cv2.cvtColor(sec, cv2.COLOR_BGR2GRAY)
    roi_path = os.path.join(ROI_DIR, f"roi_{idx}.png")
    cv2.imwrite(roi_path, gray_sec)
    return idx, gray_sec, roi_path, color_roi_path


def detect_and_crop(image_path):
    img = cv2.imread(image_path)
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    circles = cv2.HoughCircles(gray, cv2.HOUGH_GRADIENT, **HOUGH_PARAMS)
    if circles is None:
        print("No circular regions detected")
        return []

    # Convert and sort circles by x, then y coordinate
    circles = np.round(circles[0]).astype(int)
    circles = sorted(circles, key=lambda c: (c[0], c[1]))

    results = []
    for idx, circle in enumerate(circles):
        result = process_circle(idx, circle, img)
        results.append(result)

    return results


def match_template(rois, refs, thresh):
    results = []
    for idx, roi, _, _ in rois:
        best, max_val = "Unknown", -1.0
        roi_f = roi.astype(np.float32) / 255.0
        for name, ref in refs.items():
            ref_f = ref.astype(np.float32) / 255.0
            if roi_f.shape != ref_f.shape:
                continue
            res = cv2.matchTemplate(roi_f, ref_f, cv2.TM_CCOEFF_NORMED)
            val = float(res.max())
            if val > max_val:
                best, max_val = name, val
        if max_val >= thresh:
            results.append((idx, best, max_val))
        else:
            results.append((idx, None, max_val))

    return results


def match_ssim(rois, refs, thresh):
    results = []

    for idx, roi, _, _ in rois:
        best_match = "Unknown"
        max_combined_score = -1.0
        best_ssim = 0
        best_hist = 0
        best_edge = 0

        # Preprocess the ROI
        roi_gray = roi if len(roi.shape) == 2 else cv2.cvtColor(roi, cv2.COLOR_BGR2GRAY)
        roi_edges = cv2.Canny(roi_gray, 50, 150)
        roi_hist = cv2.calcHist([roi_gray], [0], None, [256], [0, 256])
        cv2.normalize(roi_hist, roi_hist, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX)

        for name, ref in refs.items():
            if roi_gray.shape != ref.shape:
                continue

            # 1. SSIM similarity
            ssim_score, _ = ssim(roi_gray, ref, full=True)

            # 2. Histogram similarity
            ref_hist = cv2.calcHist([ref], [0], None, [256], [0, 256])
            cv2.normalize(
                ref_hist, ref_hist, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX
            )
            hist_score = cv2.compareHist(roi_hist, ref_hist, cv2.HISTCMP_CORREL)

            # 3. Edge similarity
            ref_edges = cv2.Canny(ref, 50, 150)
            edge_intersection = np.sum(roi_edges * ref_edges)
            edge_union = np.sum(roi_edges + ref_edges)
            edge_score = edge_intersection / edge_union if edge_union > 0 else 0

            # Weighted combined score (weights are tunable)
            combined_score = 0.6 * ssim_score + 0.2 * hist_score + 0.2 * edge_score

            if combined_score > max_combined_score:
                best_match = name
                max_combined_score = combined_score
                best_ssim = ssim_score
                best_hist = hist_score
                best_edge = edge_score

        # Dynamic threshold adjustment (based on image complexity)
        roi_complexity = np.std(roi_gray) / 255.0
        dynamic_thresh = thresh * (1 + 0.3 * roi_complexity)

        if max_combined_score >= dynamic_thresh:
            results.append(
                (idx, best_match, max_combined_score, best_ssim, best_hist, best_edge)
            )
        else:
            results.append(
                (idx, None, max_combined_score, best_ssim, best_hist, best_edge)
            )

    return results


def display_matches(rois, template_results, ssim_results, refs):
    ROW_LIMIT = 9
    blank_ref = next(iter(refs.values()))  # use one reference image's size
    blank_img = Image.new(
        "RGB", (blank_ref.shape[1], blank_ref.shape[0]), (200, 200, 200)
    )  # gray placeholder

    combined_images = []
    row_images = []

    row_width = 0
    max_height = 0

    for item in rois:
        idx, _, _, color_path = item
        # Load the color ROI image
        roi_img = Image.open(color_path).convert("RGB")

        # Template matching result
        t_res = next((name for i, name, val in template_results if i == idx), None)
        t_val = next((val for i, name, val in template_results if i == idx), 0)
        if t_res is not None:
            t_ref_img = Image.fromarray(refs[t_res]).convert("RGB")
        else:
            t_ref_img = blank_img.copy()

        # SSIM result - the tuples carry six fields, so ignore the extras here
        s_res = next((name for i, name, val, *_ in ssim_results if i == idx), None)
        s_val = next((val for i, name, val, *_ in ssim_results if i == idx), 0)
        if s_res is not None:
            s_ref_img = Image.fromarray(refs[s_res]).convert("RGB")
        else:
            s_ref_img = blank_img.copy()

        # Combine the template-matching result (left) and the SSIM result (right)
        combined_width = roi_img.width + t_ref_img.width + s_ref_img.width
        combined_height = max(roi_img.height, t_ref_img.height, s_ref_img.height)

        combined = Image.new("RGB", (combined_width, combined_height), (255, 255, 255))
        combined.paste(roi_img, (0, 0))
        combined.paste(t_ref_img, (roi_img.width, 0))
        combined.paste(s_ref_img, (roi_img.width + t_ref_img.width, 0))

        draw = ImageDraw.Draw(combined)
        font = ImageFont.truetype("msyh.ttc", 20)

        # Pull the detailed scores out of the SSIM matching results
        ssim_details = next(
            (details for i, details in enumerate(ssim_results) if i == idx),
            (idx, None, 0, 0, 0, 0),  # defaults if not found
        )
        best_ssim = ssim_details[3] if len(ssim_details) > 3 else 0
        best_hist = ssim_details[4] if len(ssim_details) > 4 else 0
        best_edge = ssim_details[5] if len(ssim_details) > 5 else 0

        label = (
            f"ROI {idx} {best_ssim:.3f} {best_hist:.3f} {best_edge:.3f}\n"
            f"T({t_res if t_res else 'None'}, {t_val:.3f})\n"
            f"S({s_res if s_res else 'None'}, {s_val:.3f})"
        )

        if t_res == s_res:
            draw.text(
                (roi_img.width, t_ref_img.height), label, fill=(255, 0, 0), font=font
            )
        else:
            draw.text(
                (roi_img.width, t_ref_img.height), label, fill=(255, 0, 255), font=font
            )
        combined = ImageOps.expand(combined, border=2, fill=(0, 0, 0))
        row_images.append(combined)
        row_width += combined_width
        max_height = max(max_height, combined_height)

        if len(row_images) == ROW_LIMIT or idx == len(rois) - 1:
            row_img = Image.new("RGB", (row_width, max_height), (255, 255, 255))
            x_offset = 0
            for img in row_images:
                row_img.paste(img, (x_offset, 0))
                x_offset += img.width
            combined_images.append(row_img)
            row_images = []
            row_width = 0
            max_height = 0

    total_width = max(img.width for img in combined_images)
    total_height = sum(img.height for img in combined_images)

    final_img = Image.new("RGB", (total_width, total_height), (255, 255, 255))
    y_offset = 0
    for img in combined_images:
        final_img.paste(img, (0, y_offset))
        y_offset += img.height

    final_img.save("depot_test/output/matches_all.png")


if __name__ == "__main__":
    refs = load_references(REF_TABLE_JSON, ICON_DIR)
    rois = detect_and_crop(IMG_PATH)

    now = datetime.now()

    Template_MATCH_THRESHOLD = 0.2
    print("\n=== Template matching ===")
    results_template = match_template(rois, refs, Template_MATCH_THRESHOLD)
    print("\nTemplate matching took:", datetime.now() - now)
    SSIM_MATCH_THRESHOLD = 0.05

    now = datetime.now()
    print("\n=== SSIM ===")
    results_ssim = match_ssim(rois, refs, SSIM_MATCH_THRESHOLD)
    print("\nSSIM matching took:", datetime.now() - now)

    print("\n=== Comparison ===")
    for idx, _, _, _ in rois:
        # Template matching results
        t_res = next((name for i, name, val in results_template if i == idx), None)
        t_val = next((val for i, name, val in results_template if i == idx), 0)

        # SSIM results - now include the detailed metrics
        s_res = next(
            (name for i, name, val, ssim, hist, edge in results_ssim if i == idx), None
        )
        s_val = next(
            (val for i, name, val, ssim, hist, edge in results_ssim if i == idx), 0
        )
        s_ssim = next(
            (ssim for i, name, val, ssim, hist, edge in results_ssim if i == idx), 0
        )
        s_hist = next(
            (hist for i, name, val, ssim, hist, edge in results_ssim if i == idx), 0
        )
        s_edge = next(
            (edge for i, name, val, ssim, hist, edge in results_ssim if i == idx), 0
        )

        print(
            f"ROI {idx}:\n"
            f"  Template=({t_res if t_res else 'None'}, {t_val:.3f})\n"
            f"  SSIM=({s_res if s_res else 'None'}, {s_val:.3f})\n"
            f"  Details: SSIM={s_ssim:.3f}, Hist={s_hist:.3f}, Edge={s_edge:.3f}"
        )

    # Display the matches side by side using the updated function
    display_matches(rois, results_template, results_ssim, refs)
@ -1,192 +0,0 @@
import json
import os
from datetime import datetime
from multiprocessing import Pool

import cv2
import numpy as np
from PIL import Image, ImageDraw, ImageFont, ImageOps

now = datetime.now()
# Path configuration
REF_DIR = r"depot_test\output/test/origin"
ROI_DIR = r"depot_test\output/test/result"
IMG_PATH = r"depot_test\stitched_image_multi.png"
REF_TABLE_JSON = "./ArknightsGameData/zh_CN/gamedata/excel/item_table.json"
ICON_DIR = "./ArknightsResource/items/"

# Match threshold
MATCH_THRESHOLD = 0.01

# Circle-detection parameters
HOUGH_PARAMS = dict(
    dp=5, minDist=230, param1=50, param2=30, minRadius=90, maxRadius=100
)
CROP_SIZE = 130
BORDER = 26
SECONDARY_SLICE = (slice(30, 140), slice(50, 160))


def load_references(table_json, icon_dir):
    data = json.load(open(table_json, encoding="utf-8"))
    refs = {}
    size = CROP_SIZE * 2 - BORDER * 2
    for item in data.get("items", {}).values():
        t = item.get("classifyType")
        if t not in {"NORMAL", "CONSUME", "MATERIAL"}:
            continue
        path = os.path.join(icon_dir, f"{item['iconId']}.png")
        if not os.path.exists(path):
            continue
        im = Image.open(path).resize((size, size)).crop((50, 30, 160, 140)).convert("L")
        refs[item["name"]] = np.array(im)
    print(f"Loaded {len(refs)} reference images, stored under {REF_DIR}")
    return refs


def process_circle(idx, circle, img, rois, size, dr):
    x, y, r = circle
    crop = img[
        max(0, y - CROP_SIZE) : min(img.shape[0], y + CROP_SIZE),
        max(0, x - CROP_SIZE) : min(img.shape[1], x + CROP_SIZE),
    ]
    c = crop[BORDER:-BORDER, BORDER:-BORDER]
    sec = c[SECONDARY_SLICE[0], SECONDARY_SLICE[1]]
    gray_sec = cv2.cvtColor(sec, cv2.COLOR_BGR2GRAY)

    # Save the ROI image
    os.makedirs(ROI_DIR, exist_ok=True)
    roi_path = os.path.join(ROI_DIR, f"roi_{idx}.png")
    cv2.imwrite(roi_path, gray_sec)
    rois.append(gray_sec)
    return idx, gray_sec, roi_path


def detect_and_crop(image_path):
    img = cv2.imread(image_path)
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    circles = cv2.HoughCircles(gray, cv2.HOUGH_GRADIENT, **HOUGH_PARAMS)
    if circles is None:
        print("No circular regions detected")
        return []
    circles = np.round(circles[0]).astype(int)
    rois = []
    size = CROP_SIZE * 2 - BORDER * 2
    dr = img.max() - img.min()

    # Speed things up with multiple processes
    with Pool() as pool:
        results = pool.starmap(
            process_circle,
            [(idx, circle, img, rois, size, dr) for idx, circle in enumerate(circles)],
        )

    return [roi for idx, roi, path in results]


def process_match(idx, roi, refs, thresh):
    best, max_val = "Unknown", -1.0
    # Template matching needs float images
    roi_f = roi.astype(np.float32) / 255.0
    for name, ref in refs.items():
        # ref is already a grayscale NumPy array
        ref_f = ref.astype(np.float32) / 255.0
        if roi_f.shape != ref_f.shape:
            continue
        # TM_CCOEFF_NORMED: the closer to 1, the better the match
        res = cv2.matchTemplate(roi_f, ref_f, cv2.TM_CCOEFF_NORMED)
        val = float(res.max())
        if val > max_val:
            best, max_val = name, val
    if max_val >= thresh:
        return idx, best, max_val
    return idx, None, max_val


def match_ssim(rois, refs, thresh=MATCH_THRESHOLD):
    from multiprocessing import Pool

    args = [(idx, roi, refs, thresh) for idx, roi in enumerate(rois)]
    with Pool(processes=5) as pool:
        results = pool.starmap(process_match, args)

    stats = {}
    match_idx = {}
    for idx, name, score in results:
        if name:
            stats[name] = stats.get(name, 0) + 1
            match_idx[idx] = name
        print(f"ROI {idx} match: {name if name else 'Unknown'} (score={score:.3f})")
    return stats, match_idx


def display_matches(rois, match_idx, refs):
    ROW_LIMIT = 10  # at most 10 per row
    blank_ref = next(iter(refs.values()))  # use one reference image's size
    blank_img = Image.new(
        "RGB", (blank_ref.shape[1], blank_ref.shape[0]), (200, 200, 200)
    )  # gray placeholder

    combined_images = []
    row_images = []

    row_width = 0
    max_height = 0

    for idx in range(len(rois)):
        roi_img = Image.fromarray(rois[idx]).convert("RGB")
        ref_name = match_idx.get(idx)
        if ref_name:
            ref_img = Image.fromarray(refs[ref_name]).convert("RGB")
        else:
            ref_img = blank_img.copy()

        combined_width = roi_img.width + ref_img.width
        combined_height = max(roi_img.height, ref_img.height)

        combined = Image.new("RGB", (combined_width, combined_height), (255, 255, 255))
        combined.paste(roi_img, (0, 0))
        combined.paste(ref_img, (roi_img.width, 0))

        draw = ImageDraw.Draw(combined)

        font = ImageFont.truetype("msyh.ttc", 20)

        label = f"ROI {idx}: {ref_name if ref_name else 'Unknown'}"
        draw.text((5, 5), label, fill=(255, 0, 0), font=font)
        combined = ImageOps.expand(combined, border=2, fill=(0, 0, 0))
        row_images.append(combined)
        row_width += combined_width
        max_height = max(max_height, combined_height)

        if len(row_images) == ROW_LIMIT or idx == len(rois) - 1:
            row_img = Image.new("RGB", (row_width, max_height), (255, 255, 255))
            x_offset = 0
            for img in row_images:
                row_img.paste(img, (x_offset, 0))
                x_offset += img.width
            combined_images.append(row_img)
            row_images = []
            row_width = 0
            max_height = 0

    total_width = max(img.width for img in combined_images)
    total_height = sum(img.height for img in combined_images)

    final_img = Image.new("RGB", (total_width, total_height), (255, 255, 255))
    y_offset = 0
    for img in combined_images:
        final_img.paste(img, (0, y_offset))
        y_offset += img.height

    final_img.save("all_matches.png")


if __name__ == "__main__":
    refs = load_references(REF_TABLE_JSON, ICON_DIR)
    rois = detect_and_crop(IMG_PATH)
    res, match_idx = match_ssim(rois, refs)

    print("Final recognition results:", res)
    display_matches(rois, match_idx, refs)
    print(datetime.now() - now)
130
depot_test/训练.py
Normal file
@ -0,0 +1,130 @@
import os
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm


# 1. Network definition
class SiameseNetwork(nn.Module):
    def __init__(self):
        super(SiameseNetwork, self).__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=5),  # (3,110,110) -> (32,106,106)
            nn.ReLU(),
            nn.MaxPool2d(2),  # -> (32,53,53)
            nn.Conv2d(32, 64, kernel_size=5),  # -> (64,49,49)
            nn.ReLU(),
            nn.MaxPool2d(2),  # -> (64,24,24)
        )
        self.fc = nn.Sequential(
            nn.Linear(64 * 24 * 24, 512), nn.ReLU(), nn.Linear(512, 128)
        )

    def forward_once(self, x):
        x = self.cnn(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)

    def forward(self, input1, input2):
        output1 = self.forward_once(input1)
        output2 = self.forward_once(input2)
        return output1, output2


# 2. Contrastive loss
class ContrastiveLoss(nn.Module):
    def __init__(self, margin=2.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, out1, out2, label):
        dist = F.pairwise_distance(out1, out2)
        loss = label * torch.pow(dist, 2) + (1 - label) * torch.pow(
            torch.clamp(self.margin - dist, min=0.0), 2
        )
        return loss.mean()


# 3. Dataset builder
class SiameseDataset(Dataset):
    def __init__(self, folder_path, transform=None):
        self.folder_path = folder_path
        self.classes = os.listdir(folder_path)
        self.transform = transform or transforms.ToTensor()

    def __getitem__(self, index):
        # Sample a positive (same-class) pair half the time, negative otherwise
        class1 = random.choice(self.classes)
        class2 = (
            class1
            if random.random() < 0.5
            else random.choice([c for c in self.classes if c != class1])
        )

        img1_path = os.path.join(
            self.folder_path,
            class1,
            random.choice(os.listdir(os.path.join(self.folder_path, class1))),
        )
        img2_path = os.path.join(
            self.folder_path,
            class2,
            random.choice(os.listdir(os.path.join(self.folder_path, class2))),
        )

        img1 = Image.open(img1_path).convert("RGB").resize((110, 110))
        img2 = Image.open(img2_path).convert("RGB").resize((110, 110))

        if self.transform:
            img1 = self.transform(img1)
            img2 = self.transform(img2)

        label = 1.0 if class1 == class2 else 0.0
        return img1, img2, torch.tensor([label], dtype=torch.float32)

    def __len__(self):
        # Each epoch samples 5000 random pairs
        return 5000


# 4. Training entry point
def train():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
        ]
    )

    dataset = SiameseDataset("dataset/train", transform=transform)
    dataloader = DataLoader(dataset, batch_size=128, shuffle=True)

    model = SiameseNetwork().to(device)
    criterion = ContrastiveLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    for epoch in range(100):
        total_loss = 0
        for img1, img2, label in tqdm(dataloader, desc=f"Epoch {epoch + 1}"):
            img1, img2, label = img1.to(device), img2.to(device), label.to(device)

            out1, out2 = model(img1, img2)
            loss = criterion(out1, out2, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        print(f"Epoch {epoch + 1}, Loss: {total_loss / len(dataloader):.4f}")

    torch.save(model.state_dict(), "siamese_model.pth")
    print("Model saved as siamese_model.pth")


if __name__ == "__main__":
    train()
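Editor's note: for reference, the objective implemented by ContrastiveLoss above is the standard contrastive loss, with y = 1 for same-class pairs:

    L = y\,d^2 + (1 - y)\,\max(m - d,\, 0)^2, \qquad d = \lVert f(x_1) - f(x_2) \rVert_2, \quad m = 2.0

so embeddings of matching pairs are pulled together, while mismatched pairs are pushed apart until their distance reaches the margin m.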
74
depot_test/预测.py
Normal file
@ -0,0 +1,74 @@
import os

import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms

from 训练 import SiameseNetwork  # keep this script next to the training script


# Load the model
def load_model(model_path, device):
    model = SiameseNetwork().to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    return model


# Read the images of each training-set class (used as the support set)
def load_train_class_images(train_dir, transform, device):
    class_images = {}
    for class_name in os.listdir(train_dir):
        class_path = os.path.join(train_dir, class_name)
        if not os.path.isdir(class_path):
            continue
        images = []
        for img_name in os.listdir(class_path):
            img_path = os.path.join(class_path, img_name)
            image = Image.open(img_path).convert("RGB").resize((110, 110))
            image = transform(image).unsqueeze(0).to(device)  # shape: (1,3,110,110)
            images.append(image)
        if images:
            class_images[class_name] = images
    return class_images


# Inference: return the most similar class name
def predict(model, test_img, class_images):
    min_dist = float("inf")
    predicted_class = None

    with torch.no_grad():
        for class_name, ref_images in class_images.items():
            for ref_img in ref_images:
                out1, out2 = model(test_img, ref_img)
                dist = F.pairwise_distance(out1, out2)
                if dist.item() < min_dist:
                    min_dist = dist.item()
                    predicted_class = class_name
    return predicted_class, min_dist


# Main inference flow
def infer():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = load_model("siamese_model.pth", device)

    transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    train_dir = "dataset/train"
    test_dir = "dataset/test"
    class_images = load_train_class_images(train_dir, transform, device)

    print("Starting test...")
    for class_name in os.listdir(test_dir):
        class_path = os.path.join(test_dir, class_name)
        for img_name in os.listdir(class_path):
            img_path = os.path.join(class_path, img_name)
            img = Image.open(img_path).convert("RGB").resize((110, 110))
            img_tensor = transform(img).unsqueeze(0).to(device)

            predicted_class, dist = predict(model, img_tensor, class_images)
            print(f"Test Image: {img_name} | True: {class_name} | Predicted: {predicted_class} | Distance: {dist:.4f}")


if __name__ == "__main__":
    infer()