mower-ng/depot_test/预测.py
zhbaor eb149b0fbf
All checks were successful
ci/woodpecker/push/check_format Pipeline was successful
自动格式化代码
2025-05-24 11:56:25 +08:00

82 lines
2.6 KiB
Python

import os
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
from 训练 import SiameseNetwork # 确保和训练脚本在同一目录
# Load the trained model.
def load_model(model_path, device):
    """Instantiate ``SiameseNetwork`` on *device* and load saved weights into it.

    Args:
        model_path: Path to a saved ``state_dict`` for ``SiameseNetwork``
            (read with ``torch.load``).
        device: Device the network is moved to; also passed as
            ``map_location`` so tensors saved on another device (e.g. GPU)
            are remapped onto *device* while loading.

    Returns:
        The network switched to evaluation mode via ``eval()``.
    """
    net = SiameseNetwork().to(device)
    weights = torch.load(model_path, map_location=device)
    net.load_state_dict(weights)
    # nn.Module.eval() returns the module itself, so it can be returned directly.
    return net.eval()
# Load every training image, grouped by class, to serve as the support set.
def load_train_class_images(train_dir, transform, device, size=(110, 110)):
    """Load the support set: every image file under ``train_dir/<class_name>/``.

    Each immediate subdirectory of *train_dir* is treated as one class. Every
    regular file inside it is opened with PIL, converted to RGB, resized to
    *size*, passed through *transform*, given a leading batch dimension via
    ``unsqueeze(0)`` and moved to *device*.

    Args:
        train_dir: Root directory containing one subdirectory per class.
        transform: Callable mapping a PIL image to a tensor
            (``transforms.ToTensor()`` in this script).
        device: Device the resulting tensors are moved to.
        size: ``(width, height)`` passed to ``Image.resize``. Defaults to the
            110x110 used throughout this script.

    Returns:
        dict mapping class name -> list of image tensors. Top-level entries
        that are not directories, and class directories containing no files,
        are omitted. Classes and images are visited in sorted name order, so
        the result (and therefore tie-breaking in ``predict``) is reproducible.
    """
    class_images = {}
    for class_name in sorted(os.listdir(train_dir)):
        class_path = os.path.join(train_dir, class_name)
        if not os.path.isdir(class_path):
            continue  # stray files at the top level are not classes
        images = []
        for img_name in sorted(os.listdir(class_path)):
            img_path = os.path.join(class_path, img_name)
            if not os.path.isfile(img_path):
                continue  # nested directories would make Image.open fail
            # The context manager releases the file handle deterministically,
            # including for formats PIL keeps open after loading.
            with Image.open(img_path) as img:
                rgb = img.convert("RGB").resize(size)
            # With ToTensor this is (1, 3, H, W); shape depends on `transform`.
            images.append(transform(rgb).unsqueeze(0).to(device))
        if images:
            class_images[class_name] = images
    return class_images
# Inference: return the name of the most similar class.
def predict(model, test_img, class_images):
    """Nearest-neighbour classification of *test_img* against a support set.

    Every reference tensor of every class is paired with *test_img* and fed
    through *model*. The two returned outputs are compared with
    ``F.pairwise_distance``, and the class owning the single closest
    reference wins. Only a strictly smaller distance replaces the current
    best, so on ties the earliest pair in iteration order is kept. Runs
    under ``torch.no_grad()``.

    Args:
        model: Callable taking ``(test_img, ref_img)`` and returning a pair
            of outputs.
        test_img: Query tensor. ``.item()`` is called on the distance, so it
            presumably holds a single sample (batch of 1).
        class_images: dict mapping class name -> list of reference tensors.

    Returns:
        ``(predicted_class, min_dist)``. If there are no references, returns
        ``(None, float("inf"))``.
    """
    best_label, best_dist = None, float("inf")
    with torch.no_grad():
        for label, references in class_images.items():
            for reference in references:
                emb_query, emb_ref = model(test_img, reference)
                distance = F.pairwise_distance(emb_query, emb_ref).item()
                if distance < best_dist:
                    best_label, best_dist = label, distance
    return best_label, best_dist
# Main inference pipeline.
def infer(
    model_path="siamese_model.pth",
    train_dir="dataset/train",
    test_dir="dataset/test",
):
    """Classify every test image against the training support set and print results.

    Expects ``test_dir/<true_class>/<image_file>``; the subdirectory name is
    reported as the ground-truth label. Each test image gets the same
    preprocessing as the support set: RGB conversion, 110x110 resize,
    ``ToTensor``, batch dimension, move to device.

    Args:
        model_path: Saved ``state_dict`` for the Siamese network.
        train_dir: Root of the support set (one subdirectory per class).
        test_dir: Root of the test set (one subdirectory per true class).

    All arguments default to the values previously hard-coded here, so a
    bare ``infer()`` call behaves as before. Top-level entries that are not
    directories, and non-file entries inside class directories, are skipped
    instead of raising. Directories are visited in sorted order so the output
    is reproducible.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = load_model(model_path, device)
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
        ]
    )
    class_images = load_train_class_images(train_dir, transform, device)
    print("开始测试...")
    for class_name in sorted(os.listdir(test_dir)):
        class_path = os.path.join(test_dir, class_name)
        if not os.path.isdir(class_path):
            continue  # only subdirectories are classes (os.listdir would fail on a file)
        for img_name in sorted(os.listdir(class_path)):
            img_path = os.path.join(class_path, img_name)
            if not os.path.isfile(img_path):
                continue  # nested directories would make Image.open fail
            # The context manager guarantees the file handle is released.
            with Image.open(img_path) as img:
                rgb = img.convert("RGB").resize((110, 110))
            img_tensor = transform(rgb).unsqueeze(0).to(device)
            predicted_class, dist = predict(model, img_tensor, class_images)
            print(
                f"Test Image: {img_name} | True: {class_name} | Predicted: {predicted_class} | Distance: {dist:.4f}"
            )
# Run the evaluation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    infer()