324 lines
11 KiB
Python
324 lines
11 KiB
Python
import glob
|
|
import os
|
|
import re # 用于解析文件名中的标签
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.optim as optim
|
|
from PIL import Image
|
|
from torch.utils.data import DataLoader, Dataset
|
|
from torchvision import models, transforms
|
|
from torchvision.datasets import ImageFolder
|
|
from tqdm import tqdm # 用于显示进度条
|
|
|
|
# --- Configuration ---
TRAIN_DIR = "训练集"
VAL_TEST_DIR = "测试集"  # validation and test data share one directory and one file-naming format
IMAGE_SIZE = 224  # input resolution used for MobileNetV3
BATCH_SIZE = 32
NUM_EPOCHS = 20  # adjust as needed
LEARNING_RATE = 0.001  # initial learning rate
SAVE_MODEL_PATH = "mobilenetv3_small_finetuned.pth"

# Pick the device automatically: use CUDA when it is available and fall back
# to the CPU otherwise, so the script also runs on machines without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")
|
|
|
|
# --- Preprocessing and augmentation ---
# Per-channel normalisation statistics (the standard ImageNet values).
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# Training pipeline: random crop + random horizontal flip as augmentation,
# then conversion to a tensor and normalisation.
train_transforms = transforms.Compose(
    [
        transforms.RandomResizedCrop(size=IMAGE_SIZE),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ]
)

# Validation/test pipeline: no augmentation. Resize to a slightly larger
# edge (256 for a 224 target), take the centre crop, then normalise.
val_test_transforms = transforms.Compose(
    [
        transforms.Resize(size=int(IMAGE_SIZE * 256 / 224)),
        transforms.CenterCrop(size=IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ]
)
|
|
|
|
|
|
# --- Custom validation/test dataset ---
# Handles the flat "测试集/{idx}_{label}.png" file-naming layout.
class CustomValTestDataset(Dataset):
    """Dataset over a flat directory of ``{idx}_{label}.png`` images.

    The label is parsed from each file name and translated to a class index
    through ``class_to_idx``. Files whose name does not follow the pattern,
    or whose label is not a key of ``class_to_idx``, are skipped with a
    printed warning.

    Args:
        root_dir: Directory that directly contains the ``.png`` files.
        class_to_idx: Mapping from class name to integer class index.
        transform: Optional callable applied to the RGB PIL image.
    """

    # Digits, an underscore, the label, then ".png". The label group is
    # greedy so class names that themselves contain underscores
    # (e.g. "3_hot_dog.png" -> "hot_dog") are parsed instead of being dropped.
    _FILENAME_PATTERN = re.compile(r"^\d+_(.+)\.png$")

    def __init__(self, root_dir, class_to_idx, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.class_to_idx = class_to_idx
        self.idx_to_class = {v: k for k, v in class_to_idx.items()}
        self.image_files = []
        self.labels = []

        # glob returns entries in filesystem order; sort so that the sample
        # order is the same on every run and every machine.
        for filepath in sorted(glob.glob(os.path.join(root_dir, "*.png"))):
            filename = os.path.basename(filepath)
            match = self._FILENAME_PATTERN.match(filename)
            if match is None:
                print(f"警告: 文件 '{filename}' 不符合 '数字_标签.png' 命名格式,将跳过。")
                continue

            label_name = match.group(1)
            if label_name not in self.class_to_idx:
                print(
                    f"警告: 文件 '{filename}' 中的标签 '{label_name}' 不在训练集的类别中,将跳过。"
                )
                continue

            self.image_files.append(filepath)
            self.labels.append(self.class_to_idx[label_name])

        print(f"加载了 {len(self.image_files)} 张验证/测试图片。")

    def __len__(self):
        """Return the number of usable images found in ``root_dir``."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return ``(image, label_index)`` for the sample at position ``idx``."""
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_path = self.image_files[idx]
        label = self.labels[idx]

        # Force RGB so grayscale/paletted files also yield three channels.
        # convert() returns a new image, so the file handle can be released
        # as soon as the with-block ends.
        with Image.open(img_path) as raw_image:
            img = raw_image.convert("RGB")

        if self.transform is not None:
            img = self.transform(img)

        return img, label
|
|
|
|
|
|
# --- Load data ---
# Both data directories must exist before anything else runs. On failure the
# script stops with a non-zero exit status (raise SystemExit rather than the
# interactive exit() helper, which reports success and is not always defined).
if not os.path.isdir(TRAIN_DIR):
    print(
        f"错误: 训练集目录 '{TRAIN_DIR}' 不存在。请创建该目录并放入分类好的图片子目录。"
    )
    raise SystemExit(1)

if not os.path.isdir(VAL_TEST_DIR):
    print(f"错误: 验证/测试集目录 '{VAL_TEST_DIR}' 不存在。请创建该目录并放入图片。")
    raise SystemExit(1)

# ImageFolder derives the class names from the sub-directory names.
train_dataset = ImageFolder(TRAIN_DIR, transform=train_transforms)
num_classes = len(train_dataset.classes)
class_to_idx = train_dataset.class_to_idx  # class name -> class index

print(f"从训练集检测到 {num_classes} 个类别: {train_dataset.classes}")

# The validation/test files are labelled through their file names, so they
# are loaded with the custom dataset, reusing the training class mapping.
val_test_dataset = CustomValTestDataset(
    VAL_TEST_DIR, class_to_idx, transform=val_test_transforms
)

# An empty evaluation set cannot produce any metric; stop here with a clear
# message instead of failing later in the middle of the first epoch.
if len(val_test_dataset) == 0:
    print(f"错误: 验证/测试集目录 '{VAL_TEST_DIR}' 中没有可用的图片。")
    raise SystemExit(1)

# Data loaders (pass num_workers here to tune loading for your machine).
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_test_loader = DataLoader(val_test_dataset, batch_size=BATCH_SIZE, shuffle=False)
|
|
|
|
# --- Build the MobileNetV3-Small network ---
# Ask torchvision for the IMAGENET1K_V1 pretrained weights. If that fails for
# any reason, report it and continue with an uninitialised (weights=None)
# network instead of aborting.
try:
    model = models.mobilenet_v3_small(
        weights=models.MobileNet_V3_Small_Weights.IMAGENET1K_V1
    )
    print("成功加载预训练的 MobileNetV3-Small 模型 (ImageNet weights)。")
except Exception as e:
    print(f"加载预训练模型失败: {e}")
    print("尝试加载不带权重的模型...")
    model = models.mobilenet_v3_small(weights=None)

# --- Resize the classification head to the number of training classes ---
# The last entry of model.classifier is the final linear layer; keep its
# input width and swap in a new layer with num_classes outputs.
num_ftrs = model.classifier[-1].in_features
model.classifier[-1] = nn.Linear(in_features=num_ftrs, out_features=num_classes)

model = model.to(device)
|
|
|
|
# --- Loss function and optimizer ---
# Cross-entropy for the multi-class classification objective.
criterion = nn.CrossEntropyLoss()
# Adam over every model parameter (the whole network is fine-tuned).
optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

# Optional: a learning-rate scheduler, e.g. decay by 10x every 7 epochs:
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
|
|
|
|
# --- Training and evaluation helpers ---


def train_epoch(model, train_loader, criterion, optimizer, device):
    """Run one optimisation pass over ``train_loader``.

    Args:
        model: Network to train; switched to training mode.
        train_loader: Iterable yielding ``(inputs, labels)`` batches.
        criterion: Loss callable taking ``(outputs, labels)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to before the forward pass.

    Returns:
        ``(epoch_loss, epoch_accuracy)``: the sample-weighted mean loss and
        the fraction of correctly classified samples.

    Raises:
        ValueError: If ``train_loader`` yields no samples, since the
            averages would otherwise be a division by zero.
    """
    model.train()
    running_loss = 0.0
    correct_predictions = 0
    total_samples = 0

    for inputs, labels in tqdm(train_loader, desc="训练中"):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        # Weight the batch loss by the batch size so that a smaller final
        # batch does not skew the epoch average.
        batch_size = labels.size(0)
        running_loss += loss.item() * batch_size
        predicted = outputs.argmax(dim=1)
        correct_predictions += (predicted == labels).sum().item()
        total_samples += batch_size

    if total_samples == 0:
        raise ValueError("训练数据加载器没有产生任何样本,无法计算平均损失和准确率。")

    epoch_loss = running_loss / total_samples
    epoch_accuracy = correct_predictions / total_samples
    return epoch_loss, epoch_accuracy
|
|
|
|
|
|
def evaluate(model, data_loader, criterion, device, desc="评估中"):
    """Compute mean loss and accuracy of ``model`` over ``data_loader``.

    Args:
        model: Network to evaluate; switched to evaluation mode.
        data_loader: Iterable yielding ``(inputs, labels)`` batches.
        criterion: Loss callable taking ``(outputs, labels)``.
        device: Device the batches are moved to before the forward pass.
        desc: Label shown on the progress bar.

    Returns:
        ``(epoch_loss, epoch_accuracy)``: the sample-weighted mean loss and
        the fraction of correctly classified samples.

    Raises:
        ValueError: If ``data_loader`` yields no samples, since the
            averages would otherwise be a division by zero.
    """
    model.eval()
    running_loss = 0.0
    correct_predictions = 0
    total_samples = 0

    # No gradients are needed for evaluation.
    with torch.no_grad():
        for inputs, labels in tqdm(data_loader, desc=desc):
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # Weight the batch loss by the batch size so that a smaller
            # final batch does not skew the average.
            batch_size = labels.size(0)
            running_loss += loss.item() * batch_size
            predicted = outputs.argmax(dim=1)
            correct_predictions += (predicted == labels).sum().item()
            total_samples += batch_size

    if total_samples == 0:
        raise ValueError("评估数据加载器没有产生任何样本,无法计算平均损失和准确率。")

    epoch_loss = running_loss / total_samples
    epoch_accuracy = correct_predictions / total_samples
    return epoch_loss, epoch_accuracy
|
|
|
|
|
|
# --- Training loop ---
# NOTE(review): the checkpoint is selected on the same validation/test data
# that the accuracy below is reported on, so that figure is optimistic.
best_val_accuracy = 0.0

print("\n开始训练...")
for epoch in range(NUM_EPOCHS):
    print(f"\n--- Epoch {epoch + 1}/{NUM_EPOCHS} ---")

    # Training phase
    train_loss, train_accuracy = train_epoch(
        model, train_loader, criterion, optimizer, device
    )
    print(
        f"Epoch {epoch + 1} 训练 Loss: {train_loss:.4f}, 准确率: {train_accuracy:.4f}"
    )

    # Optional: if a learning-rate scheduler is configured, call
    # scheduler.step() here, once per epoch.

    # Validation/test phase
    val_loss, val_accuracy = evaluate(
        model, val_test_loader, criterion, device, desc="验证/测试中"
    )
    print(
        f"Epoch {epoch + 1} 验证/测试 Loss: {val_loss:.4f}, 准确率: {val_accuracy:.4f}"
    )

    # Keep the best checkpoint. The first epoch is always written, even at
    # 0 accuracy, so SAVE_MODEL_PATH is guaranteed to hold weights from THIS
    # run; otherwise a run that never beats 0.0 would leave no file (or a
    # stale one from an earlier run) for later loading.
    if epoch == 0 or val_accuracy > best_val_accuracy:
        best_val_accuracy = val_accuracy
        torch.save(model.state_dict(), SAVE_MODEL_PATH)
        print(
            f"保存了验证/测试集上最优的模型 (准确率: {best_val_accuracy:.4f}) 到 {SAVE_MODEL_PATH}"
        )

print("\n训练完成!")
print(f"在验证/测试集上的最高准确率: {best_val_accuracy:.4f}")
|
|
|
|
# --- Optional: reload the best checkpoint and evaluate it ---
# Validation and test share one loader here, so this is the final figure.
print(f"\n加载最优模型 '{SAVE_MODEL_PATH}' 进行最终评估...")

# Rebuild the architecture without weights and give it the same resized
# classification head, so the saved state dict fits.
loaded_model = models.mobilenet_v3_small(weights=None)
loaded_model.classifier[-1] = nn.Linear(
    in_features=loaded_model.classifier[-1].in_features,
    out_features=num_classes,
)

# Restore the saved parameters, mapping storages onto the active device.
loaded_model.load_state_dict(torch.load(SAVE_MODEL_PATH, map_location=device))
loaded_model = loaded_model.to(device)

final_test_loss, final_test_accuracy = evaluate(
    loaded_model,
    val_test_loader,
    criterion,
    device,
    desc="最终测试中",
)
print(
    f"\n最终测试 Loss: {final_test_loss:.4f}, 最终测试准确率: {final_test_accuracy:.4f}"
)
|
|
|
|
|
|
# --- Optional: classify a single image ---
def predict_single_image(image_path, model, class_to_idx, device, transform):
    """Classify one image file and print the result.

    Args:
        image_path: Path of the image to classify.
        model: Network used for the prediction; switched to evaluation mode.
        class_to_idx: Mapping from class name to class index; inverted here
            to turn the predicted index back into a name.
        device: Device the image tensor is moved to.
        transform: Callable turning the RGB PIL image into a tensor.

    Returns:
        ``(predicted_label, confidence)`` on success, or ``(None, None)`` if
        the file is missing or any other error occurs (the error is printed,
        not raised).
    """
    model.eval()
    idx_to_class = {index: name for name, index in class_to_idx.items()}

    try:
        rgb_image = Image.open(image_path).convert("RGB")
        # Add the batch dimension the model expects, then move to the device.
        batch = transform(rgb_image).unsqueeze(0).to(device)

        with torch.no_grad():
            logits = model(batch)
            probabilities = torch.softmax(logits, dim=1)[0]
            # max over the class axis yields the top probability and its index.
            top_probability, top_index = probabilities.max(dim=0)
            predicted_label = idx_to_class[top_index.item()]
            confidence = top_probability.item()

        print(f"\n预测图片: {image_path}")
        print(f"预测类别: {predicted_label}, 置信度: {confidence:.4f}")
        return predicted_label, confidence

    except FileNotFoundError:
        print(f"错误: 图片文件 '{image_path}' 未找到。")
    except Exception as e:
        print(f"预测图片时发生错误: {e}")

    # Reached only when one of the handlers above ran.
    return None, None
|
|
|
|
|
|
# Example prediction on one file from the validation/test directory.
# Replace the file name with one that actually exists in your data.
# The path is joined with os.path.join so it resolves on POSIX systems as
# well as on Windows (a literal "\" separator only works on Windows).
example_image_path = os.path.join(VAL_TEST_DIR, "27_基础作战记录.png")
predict_single_image(
    example_image_path, loaded_model, class_to_idx, device, val_test_transforms
)
|