TrafficWheel/trainer/STEP_Trainer.py

352 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import math
import os
import time
import copy
import psutil
from tqdm import tqdm
import torch
from lib.logger import get_logger
from lib.loss_function import all_metrics
from model.STEP.step_loss import step_loss
class TrainingStats:
    """Collect per-step timings and memory samples for one training run.

    The object keeps plain Python lists of samples plus a wall-clock start
    and end timestamp, and can print a summary through a logger once the
    run has finished.

    Args:
        device: Device whose CUDA peak-memory counters are sampled by
            ``record_memory_usage``. A CPU device is accepted; GPU memory is
            then recorded as ``0.0``.
    """

    def __init__(self, device):
        self.device = device
        self.reset()

    def reset(self):
        """Drop every recorded sample, the iteration counter and both timestamps."""
        self.gpu_mem_usage_list = []
        self.cpu_mem_usage_list = []
        self.train_time_list = []
        self.infer_time_list = []
        self.total_iters = 0
        self.start_time = None
        self.end_time = None

    def start_training(self):
        """Store the current wall-clock time as the start of the run."""
        self.start_time = time.time()

    def end_training(self):
        """Store the current wall-clock time as the end of the run."""
        self.end_time = time.time()

    def record_step_time(self, duration, mode):
        """Record the duration of a single step and count the iteration.

        Args:
            duration: Step duration in seconds.
            mode: ``'train'`` files the sample under training steps; any
                other value (e.g. ``'val'`` or ``'test'``) files it under
                inference steps.
        """
        if mode == 'train':
            self.train_time_list.append(duration)
        else:
            self.infer_time_list.append(duration)
        self.total_iters += 1

    def record_memory_usage(self):
        """Sample the current CPU RSS and the CUDA peak allocation, in MiB.

        The CUDA peak counter is reset after each sample, so every entry in
        ``gpu_mem_usage_list`` is the peak since the previous sample.
        """
        process = psutil.Process(os.getpid())
        cpu_mem = process.memory_info().rss / (1024 ** 2)

        # torch.cuda memory queries reject non-CUDA devices, so a model that
        # lives on the CPU of a CUDA-capable machine must skip them.
        on_cuda = torch.cuda.is_available() and not str(self.device).startswith('cpu')
        if on_cuda:
            gpu_mem = torch.cuda.max_memory_allocated(device=self.device) / (1024 ** 2)
            torch.cuda.reset_peak_memory_stats(device=self.device)
        else:
            gpu_mem = 0.0

        self.cpu_mem_usage_list.append(cpu_mem)
        self.gpu_mem_usage_list.append(gpu_mem)

    @staticmethod
    def _mean(samples):
        """Return the arithmetic mean of ``samples``, or 0 when it is empty."""
        return sum(samples) / len(samples) if samples else 0

    def report(self, logger):
        """Log a summary of the whole run.

        Args:
            logger: Object exposing ``info`` and ``warning`` methods.

        If either timestamp is missing, a warning is logged and nothing else
        is reported.
        """
        # Compare against None explicitly: 0.0 is a valid timestamp and must
        # not be mistaken for "never recorded".
        if self.start_time is None or self.end_time is None:
            logger.warning("TrainingStats: start/end time not recorded properly.")
            return

        total_time = self.end_time - self.start_time
        avg_gpu_mem = self._mean(self.gpu_mem_usage_list)
        avg_cpu_mem = self._mean(self.cpu_mem_usage_list)
        avg_train_time = self._mean(self.train_time_list)
        avg_infer_time = self._mean(self.infer_time_list)
        iters_per_sec = self.total_iters / total_time if total_time > 0 else 0

        logger.info("===== Training Summary =====")
        logger.info(f"Total training time: {total_time:.2f} s")
        logger.info(f"Total iterations: {self.total_iters}")
        logger.info(f"Average iterations per second: {iters_per_sec:.2f}")
        logger.info(f"Average GPU Memory Usage: {avg_gpu_mem:.2f} MB")
        logger.info(f"Average CPU Memory Usage: {avg_cpu_mem:.2f} MB")
        # Step-time lines are only printed when at least one sample exists.
        if avg_train_time:
            logger.info(f"Average training step time: {avg_train_time*1000:.2f} ms")
        if avg_infer_time:
            logger.info(f"Average inference step time: {avg_infer_time*1000:.2f} ms")
class Trainer:
    """Run the train / validation / test loop for a STEP model.

    Args:
        model: Module to optimise; its first parameter determines the device
            every batch is moved to.
        loss: Either a criterion called as ``loss(prediction, target)`` or a
            plain-function factory called as ``loss(None, None)`` that returns
            such a criterion (see ``_compute_loss``).
        optimizer: Optimizer stepped once per training batch.
        train_loader: Iterable of ``(data, target)`` batches; must support ``len``.
        val_loader: Optional validation loader; skipped when falsy.
        test_loader: Optional test loader; skipped when falsy.
        scaler: Stored on the instance; not used by this class.
        args: Dict-like configuration. Reads ``'log_dir'`` (default
            ``'./logs/STEP'``), ``'epochs'`` and optionally ``'clip_grad_norm'``.
        lr_scheduler: Optional scheduler stepped once per epoch.
    """

    def __init__(self, model, loss, optimizer, train_loader, val_loader, test_loader,
                 scaler, args, lr_scheduler=None):
        self.model = model
        self.loss = loss
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.scaler = scaler
        self.args = args
        self.lr_scheduler = lr_scheduler
        self.train_per_epoch = len(train_loader)
        self.val_per_epoch = len(val_loader) if val_loader else 0

        # Paths for saving models and logs
        log_dir = args.get('log_dir', './logs/STEP')
        os.makedirs(log_dir, exist_ok=True)  # make sure the directory exists
        self.best_path = os.path.join(log_dir, 'best_model.pth')
        self.best_test_path = os.path.join(log_dir, 'best_test_model.pth')
        self.loss_figure_path = os.path.join(log_dir, 'loss.png')

        # Initialize logger
        self.logger = get_logger(log_dir, name='STEP_Trainer')

        # Initialize training stats
        self.device = next(model.parameters()).device
        self.stats = TrainingStats(self.device)

    # ------------------------------------------------------------------
    # Private helpers shared by the train / validation / test loops
    # ------------------------------------------------------------------
    @staticmethod
    def _extract_prediction(output):
        """Return the prediction tensor from a model output.

        When the model returns a tuple, its first element is taken as the
        prediction and the remaining elements are ignored.
        """
        return output[0] if isinstance(output, tuple) else output

    def _compute_loss(self, prediction, target):
        """Evaluate ``self.loss`` on one batch.

        A callable whose type name contains ``'function'`` (or that has a
        ``func_name`` attribute) is treated as a factory: it is first called
        as ``loss(None, None)`` and the returned callable computes the loss.
        Anything else is called directly as ``loss(prediction, target)``.
        """
        criterion = self.loss
        is_factory = callable(criterion) and (
            hasattr(criterion, 'func_name') or 'function' in str(type(criterion))
        )
        if is_factory:
            # Build the actual loss function from the factory.
            criterion = criterion(None, None)
        return criterion(prediction, target)

    @staticmethod
    def _batch_metrics(prediction, target):
        """Return MAE / RMSE / MAPE of one batch as plain floats."""
        # Metrics are only reported, never back-propagated, so keep autograd off.
        with torch.no_grad():
            mae, rmse, mape = all_metrics(prediction, target, None, 0.0)
        return {'mae': mae.item(), 'rmse': rmse.item(), 'mape': mape.item()}

    @staticmethod
    def _accumulate(totals, metrics):
        """Add every value in ``metrics`` onto the running ``totals`` in place."""
        for key, value in metrics.items():
            totals[key] = totals.get(key, 0) + value

    @staticmethod
    def _postfix(loss_value, metrics):
        """Build the progress-bar postfix for one batch."""
        return {
            'Loss': f'{loss_value:.4f}',
            'MAE': f'{metrics.get("mae", 0):.4f}',
            'RMSE': f'{metrics.get("rmse", 0):.4f}'
        }

    @staticmethod
    def _average(total_loss, total_metrics, num_batches):
        """Turn per-epoch sums into per-batch means.

        The divisor is floored at 1 so an empty loader yields zeros rather
        than raising ZeroDivisionError.
        """
        divisor = max(num_batches, 1)
        avg_metrics = {key: value / divisor for key, value in total_metrics.items()}
        return total_loss / divisor, avg_metrics

    def _run_eval_epoch(self, loader, desc, mode):
        """Evaluate the model on ``loader`` without gradients.

        Args:
            loader: Iterable of ``(data, target)`` batches.
            desc: Progress-bar description.
            mode: Label forwarded to ``TrainingStats.record_step_time``.

        Returns:
            Tuple ``(avg_loss, avg_metrics)`` averaged over the batches seen.
        """
        self.model.eval()
        total_loss = 0
        total_metrics = {}
        num_batches = 0
        with torch.no_grad():
            with tqdm(loader, desc=desc) as pbar:
                for data, target in pbar:
                    start_time = time.time()
                    data = data.to(self.device)
                    target = target.to(self.device)

                    # Forward pass of the STEP model
                    prediction = self._extract_prediction(self.model(data))
                    loss = self._compute_loss(prediction, target)

                    # Record statistics
                    self.stats.record_step_time(time.time() - start_time, mode)
                    loss_value = loss.item()
                    total_loss += loss_value

                    # Compute metrics
                    metrics = self._batch_metrics(prediction, target)
                    self._accumulate(total_metrics, metrics)
                    num_batches += 1

                    # Update the progress bar
                    pbar.set_postfix(self._postfix(loss_value, metrics))
        return self._average(total_loss, total_metrics, num_batches)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def train_epoch(self, epoch):
        """Run one optimisation pass over ``self.train_loader``.

        Args:
            epoch: Epoch index, used only for the progress-bar label.

        Returns:
            Tuple ``(avg_loss, avg_metrics)`` averaged over the batches seen.
        """
        self.model.train()
        total_loss = 0
        total_metrics = {}
        num_batches = 0
        with tqdm(self.train_loader, desc=f'Epoch {epoch}') as pbar:
            for batch_idx, (data, target) in enumerate(pbar):
                start_time = time.time()
                data = data.to(self.device)
                target = target.to(self.device)
                self.optimizer.zero_grad()

                # Forward pass of the STEP model. It may return several
                # outputs; only the first one is used as the prediction.
                prediction = self._extract_prediction(self.model(data))
                loss = self._compute_loss(prediction, target)
                loss.backward()

                # Gradient clipping
                if self.args.get('clip_grad_norm', 0) > 0:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args['clip_grad_norm'])
                self.optimizer.step()

                # Record statistics
                self.stats.record_step_time(time.time() - start_time, 'train')
                loss_value = loss.item()
                total_loss += loss_value

                # Compute metrics
                metrics = self._batch_metrics(prediction, target)
                self._accumulate(total_metrics, metrics)
                num_batches += 1

                # Update the progress bar
                pbar.set_postfix(self._postfix(loss_value, metrics))

                # Sample memory usage every 100 batches
                if batch_idx % 100 == 0:
                    self.stats.record_memory_usage()
        return self._average(total_loss, total_metrics, num_batches)

    def val_epoch(self, epoch):
        """Evaluate on ``self.val_loader``; returns ``(avg_loss, avg_metrics)``."""
        return self._run_eval_epoch(self.val_loader, f'Validation {epoch}', 'val')

    def test_epoch(self, epoch):
        """Evaluate on ``self.test_loader``; returns ``(avg_loss, avg_metrics)``."""
        return self._run_eval_epoch(self.test_loader, f'Test {epoch}', 'test')

    def train(self):
        """Train for ``args['epochs']`` epochs and checkpoint the best weights.

        After every epoch the model is evaluated on the validation and test
        loaders (when present). The state dict is saved to ``best_path`` when
        the validation loss improves and to ``best_test_path`` when the test
        loss improves.

        Returns:
            Tuple ``(best_val_loss, best_test_loss)``; a value stays ``inf``
            when the corresponding loader was never evaluated.
        """
        self.stats.start_training()
        best_val_loss = float('inf')
        best_test_loss = float('inf')
        for epoch in range(self.args['epochs']):
            # Training
            train_loss, train_metrics = self.train_epoch(epoch)

            # Validation
            if self.val_loader:
                val_loss, val_metrics = self.val_epoch(epoch)
                # Save the best model
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    torch.save(self.model.state_dict(), self.best_path)
                    self.logger.info(f'Epoch {epoch}: Best validation loss: {val_loss:.4f}')

            # Test
            if self.test_loader:
                test_loss, test_metrics = self.test_epoch(epoch)
                # Save the best test model
                if test_loss < best_test_loss:
                    best_test_loss = test_loss
                    torch.save(self.model.state_dict(), self.best_test_path)
                    self.logger.info(f'Epoch {epoch}: Best test loss: {test_loss:.4f}')

            # Learning-rate scheduling
            if self.lr_scheduler:
                self.lr_scheduler.step()

            # Logging
            self.logger.info(f'Epoch {epoch}: Train Loss: {train_loss:.4f}, Train MAE: {train_metrics.get("mae", 0):.4f}')
            if self.val_loader:
                self.logger.info(f'Epoch {epoch}: Val Loss: {val_loss:.4f}, Val MAE: {val_metrics.get("mae", 0):.4f}')
            if self.test_loader:
                self.logger.info(f'Epoch {epoch}: Test Loss: {test_loss:.4f}, Test MAE: {test_metrics.get("mae", 0):.4f}')

        self.stats.end_training()
        self.stats.report(self.logger)
        return best_val_loss, best_test_loss