All checks were successful
ci/woodpecker/push/check_format Pipeline was successful
82 lines
2.6 KiB
Python
82 lines
2.6 KiB
Python
import os
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from PIL import Image
|
|
from torchvision import transforms
|
|
from 训练 import SiameseNetwork # 确保和训练脚本在同一目录
|
|
|
|
|
|
def load_model(model_path, device):
    """Build a ``SiameseNetwork``, load saved weights into it and return it.

    Args:
        model_path: Path to a file holding the network's ``state_dict``.
        device: Device the network is moved to; the checkpoint tensors are
            also mapped onto it via ``map_location``.

    Returns:
        The network on *device*, switched to evaluation mode (this affects
        layers such as dropout / batch-norm, if the network contains any).
    """
    net = SiameseNetwork().to(device)
    # NOTE(review): torch.load unpickles the file - only load checkpoints
    # from trusted sources (recent PyTorch versions default to
    # weights_only=True for this reason).
    weights = torch.load(model_path, map_location=device)
    net.load_state_dict(weights)
    net.eval()
    return net
|
|
|
|
|
|
def load_train_class_images(train_dir, transform, device, size=(110, 110)):
    """Load every training image as a support set, grouped by class.

    Each immediate sub-directory of *train_dir* is one class. Every regular
    file inside it is opened as an image, converted to RGB, resized to
    *size*, passed through *transform*, given a leading batch dimension of 1
    and moved to *device*.

    Directory listings are sorted so the support-set order (and therefore
    tie-breaking in nearest-neighbour search) is deterministic.

    Args:
        train_dir: Root directory containing one sub-directory per class.
        transform: Callable applied to each PIL image (e.g. ``ToTensor``);
            its output must support ``.unsqueeze(0).to(device)``.
        device: Device the resulting tensors are moved to.
        size: ``(width, height)`` passed to ``PIL.Image.resize``; defaults to
            the 110x110 size the script has always used.

    Returns:
        dict mapping class name -> list of image tensors. With ``ToTensor``
        each tensor has shape ``(1, 3, height, width)``. Classes whose
        directory contains no regular files are omitted.
    """
    class_images = {}
    for class_name in sorted(os.listdir(train_dir)):
        class_path = os.path.join(train_dir, class_name)
        # Stray files at the top level are not classes.
        if not os.path.isdir(class_path):
            continue

        images = []
        for img_name in sorted(os.listdir(class_path)):
            img_path = os.path.join(class_path, img_name)
            # Nested directories and other non-files would make Image.open fail.
            if not os.path.isfile(img_path):
                continue
            # The context manager closes the file handle once the pixels have
            # been copied out by convert()/resize().
            with Image.open(img_path) as raw:
                image = raw.convert("RGB").resize(size)
            images.append(transform(image).unsqueeze(0).to(device))

        if images:
            class_images[class_name] = images
    return class_images
|
|
|
|
|
|
def predict(model, test_img, class_images):
    """Classify *test_img* by its nearest reference image.

    ``model`` is called as ``model(test_img, reference)`` and its result is
    unpacked into two outputs (presumably the two embeddings - confirm
    against SiameseNetwork.forward). Their ``F.pairwise_distance`` is the
    similarity score; smaller means closer.

    Args:
        model: The Siamese network.
        test_img: Query image tensor.
        class_images: Mapping class name -> list of reference image tensors.

    Returns:
        ``(class_name, distance)`` for the closest reference. On ties the
        first reference in iteration order wins. ``(None, inf)`` is returned
        when there are no references at all.
    """
    best_label = None
    best_distance = float("inf")

    with torch.no_grad():
        for label, references in class_images.items():
            for reference in references:
                emb_query, emb_ref = model(test_img, reference)
                distance = F.pairwise_distance(emb_query, emb_ref).item()
                # Strict '<' keeps the earliest match on ties.
                if distance < best_distance:
                    best_label = label
                    best_distance = distance

    return best_label, best_distance
|
|
|
|
|
|
def infer(
    model_path="siamese_model.pth",
    train_dir="dataset/train",
    test_dir="dataset/test",
):
    """Classify every test image and print one result line per image.

    Each sub-directory of *test_dir* is treated as the true class of the
    files inside it. Every such file is matched, via ``predict``, against
    the support set built from *train_dir*.

    Args:
        model_path: Path to the saved ``state_dict`` (see ``load_model``).
        train_dir: Root of the support set, one sub-directory per class.
        test_dir: Root of the test set, one sub-directory per class.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = load_model(model_path, device)

    # ToTensor converts a PIL image to a float (C, H, W) tensor in [0, 1].
    # NOTE(review): presumably mirrors the training preprocessing - confirm.
    transform = transforms.Compose([transforms.ToTensor()])

    class_images = load_train_class_images(train_dir, transform, device)

    print("开始测试...")
    for class_name in sorted(os.listdir(test_dir)):
        class_path = os.path.join(test_dir, class_name)
        # Skip stray top-level files (os.listdir on them would raise
        # NotADirectoryError), mirroring the support-set loader.
        if not os.path.isdir(class_path):
            continue

        for img_name in sorted(os.listdir(class_path)):
            img_path = os.path.join(class_path, img_name)
            if not os.path.isfile(img_path):
                continue
            # Same 110x110 RGB preprocessing as the support images; the
            # context manager closes the file handle promptly.
            with Image.open(img_path) as raw:
                img = raw.convert("RGB").resize((110, 110))
            img_tensor = transform(img).unsqueeze(0).to(device)

            predicted_class, dist = predict(model, img_tensor, class_images)
            print(
                f"Test Image: {img_name} | True: {class_name} | Predicted: {predicted_class} | Distance: {dist:.4f}"
            )
|
|
|
|
|
|
# Script entry point: run the evaluation only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    infer()
|