#!/usr/bin/env python3
"""
Training script for the STEP model.
"""

import torch
import yaml
import os
import sys
import argparse

# Add the project root directory to the import path
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from model.model_selector import model_selector
from dataloader.loader_selector import get_dataloader
from trainer.trainer_selector import select_trainer
from lib.loss_function import masked_mae_loss


def train_step_model(config_path, epochs=None):
    """Train the STEP model using the given YAML config."""
    print(f"Starting STEP model training, config file: {config_path}")

    # Load the configuration
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)

    # If epochs was given, override the value from the config file
    if epochs is not None:
        config['train']['epochs'] = epochs

    # Select the device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Create the log directory
    log_dir = f'./logs/STEP_{config["data"]["type"]}'
    os.makedirs(log_dir, exist_ok=True)

    try:
        # Build the model
        print("Building the STEP model...")
        model = model_selector(config['model'])
        model = model.to(device)
        print(f"Number of model parameters: {sum(p.numel() for p in model.parameters())}")

        # Build the data loaders
        print("Building the data loaders...")
        train_loader, val_loader, test_loader, scaler = get_dataloader(
            config, normalizer='std', single=True
        )
        print(f"Training batches: {len(train_loader)}")
        print(f"Validation batches: {len(val_loader)}")
        print(f"Test batches: {len(test_loader)}")

        # Build the optimizer
        print("Building the optimizer...")
        optimizer = torch.optim.Adam(
            model.parameters(),
            lr=config['train']['lr_init'],
            weight_decay=config['train']['weight_decay']
        )

        # Build the learning-rate scheduler
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=config['train']['lr_decay_step'],
            gamma=config['train']['lr_decay_rate']
        )

        # Build the trainer
        print("Building the trainer...")
        trainer = select_trainer(
            model=model,
            loss=masked_mae_loss,
            optimizer=optimizer,
            train_loader=train_loader,
            val_loader=val_loader,
            test_loader=test_loader,
            scaler=scaler,
            args=config,
            lr_scheduler=lr_scheduler,
            kwargs=[]
        )

        # Start training
        print(f"Starting training, total epochs: {config['train']['epochs']}")
        best_val_loss, best_test_loss = trainer.train()

        print("Training finished!")
        print(f"Best validation loss: {best_val_loss:.4f}")
        print(f"Best test loss: {best_test_loss:.4f}")

        return True

    except Exception as e:
        print(f"STEP model training failed: {e}")
        import traceback
        traceback.print_exc()
        return False
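

# Sketch of the YAML config keys that train_step_model() above actually reads.
# The key names come straight from the code; the example values are only
# illustrative assumptions, not values taken from the real config/STEP/*.yaml files:
#
#   data:
#     type: PEMS04          # used for the ./logs/STEP_<type> directory name
#   model:
#     ...                   # passed as-is to model_selector()
#   train:
#     epochs: 100
#     lr_init: 0.001
#     weight_decay: 0.0001
#     lr_decay_step: [50, 80]
#     lr_decay_rate: 0.1

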
def main():
    parser = argparse.ArgumentParser(description='Train the STEP model')
    parser.add_argument('--config', type=str, default='config/STEP/STEP_PEMS04.yaml',
                        help='Path to the config file')
    parser.add_argument('--epochs', type=int, default=None,
                        help='Number of training epochs (overrides the config file)')

    args = parser.parse_args()

    success = train_step_model(args.config, args.epochs)
    if success:
        print("\n✅ STEP model training finished!")
    else:
        print("\n❌ STEP model training failed!")


if __name__ == "__main__":
    main()
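

# Example invocation (illustrative only: this file's name is not shown here, so
# "train_step.py" is a placeholder, and the config path is simply the argparse
# default defined above):
#
#   python train_step.py --config config/STEP/STEP_PEMS04.yaml --epochs 100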