# mower-ng/depot_test/仓库识别MobileNetV3 KNN copy.py
import glob
import os
import re  # used to parse the label embedded in each filename
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms
from torchvision.datasets import ImageFolder
from tqdm import tqdm  # progress bars
# --- Configuration ---
TRAIN_DIR = "训练集"
VAL_TEST_DIR = "测试集"  # per the original setup, validation and test share the same directory and file format
IMAGE_SIZE = 224  # standard input size for MobileNetV3
BATCH_SIZE = 32
NUM_EPOCHS = 20  # adjust as needed
LEARNING_RATE = 0.001  # initial learning rate
SAVE_MODEL_PATH = "mobilenetv3_small_finetuned.pth"
# Automatically pick the device (fall back to CPU when CUDA is unavailable)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# --- Preprocessing and augmentation ---
# ImageNet mean and standard deviation
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
# Training-set augmentation and preprocessing
train_transforms = transforms.Compose(
    [
        transforms.RandomResizedCrop(IMAGE_SIZE),  # random crop and resize
        transforms.RandomHorizontalFlip(),  # random horizontal flip
        transforms.ToTensor(),  # convert to tensor
        transforms.Normalize(mean, std),  # normalize
    ]
)
# Validation/test preprocessing (no augmentation; just resize, center crop and normalize)
val_test_transforms = transforms.Compose(
    [
        transforms.Resize(int(IMAGE_SIZE * 256 / 224)),  # resize the shorter side
        transforms.CenterCrop(IMAGE_SIZE),  # center-crop to the target size
        transforms.ToTensor(),  # convert to tensor
        transforms.Normalize(mean, std),  # normalize
    ]
)
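# Note: Resize(int(224 * 256 / 224)) evaluates to Resize(256), so this follows the
# standard ImageNet evaluation pipeline (resize the shorter side to 256, then center-crop to 224).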
# --- Custom validation/test Dataset ---
# A custom Dataset is needed to handle the "测试集/{idx}_{label}.png" naming scheme
class CustomValTestDataset(Dataset):
def __init__(self, root_dir, class_to_idx, transform=None):
"""
Args:
root_dir (string): 数据集根目录 (e.g., "测试集").
class_to_idx (dict): 从类别名称到索引的映射,与训练集一致。
transform (callable, optional): 应用于图像的转换.
"""
self.root_dir = root_dir
self.transform = transform
self.class_to_idx = class_to_idx
self.idx_to_class = {v: k for k, v in class_to_idx.items()}
self.image_files = []
self.labels = []
        # Collect every .png file in the directory
        filepaths = glob.glob(os.path.join(root_dir, "*.png"))
        # Parse each filename to extract its label
        pattern = re.compile(r"^\d+_([^_]+)\.png$")  # matches "<number>_<label>.png" (label may not contain "_")
for filepath in filepaths:
filename = os.path.basename(filepath)
match = pattern.match(filename)
if match:
label_name = match.group(1)
if label_name in self.class_to_idx:
self.image_files.append(filepath)
self.labels.append(self.class_to_idx[label_name])
else:
                    print(
                        f"Warning: label '{label_name}' in file '{filename}' is not among the training classes; skipping."
                    )
        print(f"Loaded {len(self.image_files)} validation/test images.")
def __len__(self):
return len(self.image_files)
def __getitem__(self, idx):
if torch.is_tensor(idx):
idx = idx.tolist()
img_path = self.image_files[idx]
label = self.labels[idx]
        # Open the image and force RGB (handles possible grayscale images)
img = Image.open(img_path).convert("RGB")
if self.transform:
img = self.transform(img)
return img, label
# --- Load the data ---
# Use ImageFolder for the training set; it infers classes from the sub-directory names
if not os.path.exists(TRAIN_DIR):
    print(
        f"Error: training directory '{TRAIN_DIR}' does not exist. Create it and add sub-directories of images, one per class."
    )
    exit()
if not os.path.exists(VAL_TEST_DIR):
    print(f"Error: validation/test directory '{VAL_TEST_DIR}' does not exist. Create it and add images.")
    exit()
train_dataset = ImageFolder(TRAIN_DIR, transform=train_transforms)
num_classes = len(train_dataset.classes)
class_to_idx = train_dataset.class_to_idx  # mapping from class name to index
print(f"Detected {num_classes} classes in the training set: {train_dataset.classes}")
# Load the validation/test set with the custom Dataset
val_test_dataset = CustomValTestDataset(
VAL_TEST_DIR, class_to_idx, transform=val_test_transforms
)
# Create DataLoaders
train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True
)  # adjust num_workers to match your machine (see the commented sketch below)
val_test_loader = DataLoader(val_test_dataset, batch_size=BATCH_SIZE, shuffle=False)
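# A minimal sketch (illustrative values, not from the original script) of enabling
# worker processes and pinned memory for faster host-to-GPU transfer:
# train_loader = DataLoader(
#     train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True
# )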
# --- Load a pretrained MobileNetV3-Small model ---
# The weights argument selects the pretrained weights;
# MobileNet_V3_Small_Weights.IMAGENET1K_V1 was trained on ImageNet
try:
model = models.mobilenet_v3_small(
weights=models.MobileNet_V3_Small_Weights.IMAGENET1K_V1
)
print("成功加载预训练的 MobileNetV3-Small 模型 (ImageNet weights)。")
except Exception as e:
print(f"加载预训练模型失败: {e}")
print("尝试加载不带权重的模型...")
model = models.mobilenet_v3_small(weights=None)
# --- Replace the classification head to match the new number of classes ---
# The MobileNetV3 classifier is model.classifier;
# its last linear layer is classifier[-1]
num_ftrs = model.classifier[-1].in_features
# Swap in a new linear layer with num_classes outputs
model.classifier[-1] = nn.Linear(num_ftrs, num_classes)
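# Optional sketch (an assumption, not part of the original training flow): freeze the
# pretrained backbone and train only the new classifier head, e.g.
# for param in model.features.parameters():
#     param.requires_grad = False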
model = model.to(device)
# --- Loss function and optimizer ---
criterion = nn.CrossEntropyLoss()  # cross-entropy loss for classification
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)  # Adam optimizer
# Optional: learning-rate scheduler
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)  # multiply the lr by 0.1 every 7 epochs
# --- Training and evaluation functions ---
def train_epoch(model, train_loader, criterion, optimizer, device):
    model.train()  # training mode
    running_loss = 0.0
    correct_predictions = 0
    total_samples = 0
    # tqdm shows a progress bar
    for inputs, labels in tqdm(train_loader, desc="Training"):
        inputs, labels = inputs.to(device), labels.to(device)
        # Zero the gradients
        optimizer.zero_grad()
        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        # Backward pass and optimizer step
        loss.backward()
        optimizer.step()
        # Statistics
        running_loss += loss.item() * inputs.size(0)  # accumulate batch loss * batch size
        _, predicted = torch.max(outputs, 1)  # predicted class indices
        correct_predictions += (predicted == labels).sum().item()
        total_samples += labels.size(0)
epoch_loss = running_loss / total_samples
epoch_accuracy = correct_predictions / total_samples
return epoch_loss, epoch_accuracy
def evaluate(model, data_loader, criterion, device, desc="Evaluating"):
    model.eval()  # evaluation mode
    running_loss = 0.0
    correct_predictions = 0
    total_samples = 0
    # No gradients are needed during evaluation
    with torch.no_grad():
        # tqdm shows a progress bar
        for inputs, labels in tqdm(data_loader, desc=desc):
            inputs, labels = inputs.to(device), labels.to(device)
            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            # Statistics
            running_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs, 1)
correct_predictions += (predicted == labels).sum().item()
total_samples += labels.size(0)
epoch_loss = running_loss / total_samples
epoch_accuracy = correct_predictions / total_samples
return epoch_loss, epoch_accuracy
# --- Training loop ---
best_val_accuracy = 0.0
print("\nStarting training...")
for epoch in range(NUM_EPOCHS):
    print(f"\n--- Epoch {epoch + 1}/{NUM_EPOCHS} ---")
    # Training phase
    train_loss, train_accuracy = train_epoch(
        model, train_loader, criterion, optimizer, device
    )
    print(
        f"Epoch {epoch + 1} train loss: {train_loss:.4f}, accuracy: {train_accuracy:.4f}"
    )
    # Optional: step the learning-rate scheduler
    # if scheduler is not None:
    #     scheduler.step()
    # Validation/test phase
    val_loss, val_accuracy = evaluate(
        model, val_test_loader, criterion, device, desc="Validating/Testing"
    )
    print(
        f"Epoch {epoch + 1} validation/test loss: {val_loss:.4f}, accuracy: {val_accuracy:.4f}"
    )
    # Save the best model so far
    if val_accuracy > best_val_accuracy:
        best_val_accuracy = val_accuracy
        torch.save(model.state_dict(), SAVE_MODEL_PATH)
        print(
            f"Saved the best validation/test model so far (accuracy: {best_val_accuracy:.4f}) to {SAVE_MODEL_PATH}"
        )
print("\n训练完成!")
print(f"在验证/测试集上的最高准确率: {best_val_accuracy:.4f}")
# --- Optional: load and evaluate the best model ---
# Reload the best checkpoint for a final evaluation (since validation and test share one set, this is the final result)
print(f"\nLoading best model '{SAVE_MODEL_PATH}' for final evaluation...")
loaded_model = models.mobilenet_v3_small(weights=None)  # build the bare architecture first
loaded_model.classifier[-1] = nn.Linear(
    loaded_model.classifier[-1].in_features, num_classes
)  # rebuild the classifier head
loaded_model.load_state_dict(
    torch.load(SAVE_MODEL_PATH, map_location=device)
)  # load the saved weights
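# Note (depends on the installed PyTorch version): on recent releases,
# torch.load(SAVE_MODEL_PATH, map_location=device, weights_only=True) restricts
# unpickling to plain tensors, which is safer when loading checkpoints.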
loaded_model = loaded_model.to(device)
final_test_loss, final_test_accuracy = evaluate(
    loaded_model, val_test_loader, criterion, device, desc="Final test"
)
print(
    f"\nFinal test loss: {final_test_loss:.4f}, final test accuracy: {final_test_accuracy:.4f}"
)
# --- Optional: predict a single image ---
# For example, to predict an image such as '测试集/some_image_X_label.png'
def predict_single_image(image_path, model, class_to_idx, device, transform):
    model.eval()
    idx_to_class = {v: k for k, v in class_to_idx.items()}
    try:
        img = Image.open(image_path).convert("RGB")
        img = transform(img).unsqueeze(0).to(device)  # add a batch dimension and move to the device
        with torch.no_grad():
            outputs = model(img)
            probabilities = torch.softmax(outputs, dim=1)[0]  # probability distribution over classes
            _, predicted_idx = torch.max(probabilities, 0)  # index of the highest probability
        predicted_label = idx_to_class[predicted_idx.item()]
        confidence = probabilities[predicted_idx].item()
        print(f"\nPredicting image: {image_path}")
        print(f"Predicted class: {predicted_label}, confidence: {confidence:.4f}")
        return predicted_label, confidence
    except FileNotFoundError:
        print(f"Error: image file '{image_path}' not found.")
        return None, None
    except Exception as e:
        print(f"Error while predicting the image: {e}")
        return None, None
# Example prediction
example_image_path = r"测试集\27_基础作战记录.png"  # replace with an actual file from your test set
predict_single_image(
example_image_path, loaded_model, class_to_idx, device, val_test_transforms
)