Project-I/trainer/trainer.py

520 lines
22 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import math
import os
import time
import copy
from tqdm import tqdm
import torch
class Trainer:
    """Train, validate and test a forecasting model, and plot its predictions.

    ``config['train']`` supplies the training hyper-parameters (``self.args``);
    ``config['basic']['device']`` is copied into it under the key ``'device'``.
    The metric and plotting code index prediction/target tensors as
    [sample, horizon, node, feature].
    """

    def __init__(self, config, model, loss, optimizer, train_loader, val_loader, test_loader,
                 scalers, logger, lr_scheduler=None):
        """Store the training components and derive checkpoint/figure paths.

        Args:
            config: Nested config dict; ``config['train']`` becomes ``self.args``
                and ``config['basic']['device']`` is written into it.
            model: The ``torch.nn.Module`` being trained.
            loss: Callable ``loss(output, label)`` returning a scalar tensor.
            optimizer: Optimizer stepped once per training batch (may be ``None``).
            train_loader: Iterable of ``(data, target)`` training batches.
            val_loader: Validation batches, or ``None`` to validate on ``test_loader``.
            test_loader: Iterable of ``(data, target)`` test batches.
            scalers: List of scalers; ``scalers[i].inverse_transform`` restores
                output feature ``i``.
            logger: Project logger wrapper exposing ``.logger`` (used for all
                log calls), ``.dir_path`` and ``.all_metrics``.
            lr_scheduler: Optional learning-rate scheduler.
        """
        self.model = model
        self.loss = loss
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.scalers = scalers  # list of scalers, one per output feature
        self.args = config['train']
        self.logger = logger
        # NOTE: this mutates the caller's config['train'] dict; test() later
        # reads args['device'] from it.
        self.args['device'] = config['basic']['device']
        # NOTE(review): the scheduler is stored but never stepped anywhere in
        # this class — confirm the caller steps it, otherwise it has no effect.
        self.lr_scheduler = lr_scheduler
        self.train_per_epoch = len(train_loader)
        self.val_per_epoch = len(val_loader) if val_loader else 0
        self.best_path = os.path.join(logger.dir_path, 'best_model.pth')
        self.best_test_path = os.path.join(logger.dir_path, 'best_test_model.pth')
        self.loss_figure_path = os.path.join(logger.dir_path, 'loss.png')

    def _run_epoch(self, epoch, dataloader, mode):
        """Run one pass over ``dataloader`` and return the mean batch loss.

        In ``'train'`` mode the model is in train mode, gradients are enabled
        and the optimizer (if any) is stepped every batch; any other mode
        evaluates the model with gradients disabled.
        """
        is_train = mode == 'train'
        if is_train:
            self.model.train()
        else:
            self.model.eval()
        total_loss = 0.0
        epoch_start = time.time()
        num_batches = len(dataloader)
        with torch.set_grad_enabled(is_train):
            with tqdm(total=num_batches, desc=f'{mode.capitalize()} Epoch {epoch}') as pbar:
                for batch_idx, (data, target) in enumerate(dataloader):
                    label = target[..., :self.args['output_dim']]
                    output = self.model(data).to(self.args['device'])
                    if self.args['real_value']:
                        # De-normalise only the output feature channels; the
                        # label is compared as-is.
                        output = self._inverse_transform_output(output)
                    loss = self.loss(output, label)
                    if is_train and self.optimizer is not None:
                        self.optimizer.zero_grad()
                        loss.backward()
                        if self.args['grad_norm']:
                            torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                                           self.args['max_grad_norm'])
                        self.optimizer.step()
                    # Read the scalar once: every .item() forces a device sync.
                    batch_loss = loss.item()
                    total_loss += batch_loss
                    if is_train and (batch_idx + 1) % self.args['log_step'] == 0:
                        # Log through the wrapped logger like every other call here.
                        self.logger.logger.info(
                            f'Train Epoch {epoch}: {batch_idx + 1}/{num_batches} Loss: {batch_loss:.6f}')
                    pbar.update(1)
                    pbar.set_postfix(loss=batch_loss)
        avg_loss = total_loss / num_batches
        self.logger.logger.info(
            f'{mode.capitalize()} Epoch {epoch}: average Loss: {avg_loss:.6f}, time: {time.time() - epoch_start:.2f} s')
        return avg_loss

    def _inverse_transform_output(self, output):
        """De-normalise the output feature channels using ``self.scalers``.

        Delegates to :meth:`_inverse_transform_output_static` so the training
        and test paths share one implementation.
        """
        return Trainer._inverse_transform_output_static(output, self.args, self.scalers)

    def train_epoch(self, epoch):
        """Run one training epoch and return its mean loss."""
        return self._run_epoch(epoch, self.train_loader, 'train')

    def val_epoch(self, epoch):
        """Run one validation epoch (falls back to the test loader) and return its mean loss."""
        return self._run_epoch(epoch, self.val_loader or self.test_loader, 'val')

    def test_epoch(self, epoch):
        """Run one evaluation epoch on the test loader and return its mean loss."""
        return self._run_epoch(epoch, self.test_loader, 'test')

    def train(self):
        """Run the training loop, save the best checkpoints, then test them.

        Each epoch trains, validates and tests. The state dict with the lowest
        validation loss is saved to ``self.best_path``; the one with the lowest
        test loss is saved to ``self.best_test_path``. Training stops early when
        the training loss is non-finite or above 1e6, or (with ``early_stop``)
        when validation loss has not improved for ``early_stop_patience`` epochs.
        """
        best_model, best_test_model = None, None
        best_loss, best_test_loss = float('inf'), float('inf')
        not_improved_count = 0
        self.logger.logger.info("Training process started")
        for epoch in range(1, self.args['epochs'] + 1):
            train_epoch_loss = self.train_epoch(epoch)
            # Check right after training: NaN/inf or a huge loss means the run
            # diverged, and evaluating the diverged model is wasted work.
            if not math.isfinite(train_epoch_loss) or train_epoch_loss > 1e6:
                self.logger.logger.warning('Gradient explosion detected. Ending...')
                break
            val_epoch_loss = self.val_epoch(epoch)
            test_epoch_loss = self.test_epoch(epoch)
            if val_epoch_loss < best_loss:
                best_loss = val_epoch_loss
                not_improved_count = 0
                best_model = copy.deepcopy(self.model.state_dict())
                torch.save(best_model, self.best_path)
                self.logger.logger.info('Best validation model saved!')
            else:
                not_improved_count += 1
                if self.args['early_stop'] and not_improved_count == self.args['early_stop_patience']:
                    self.logger.logger.info(
                        f"Validation performance didn't improve for {self.args['early_stop_patience']} epochs. Training stops.")
                    break
            if test_epoch_loss < best_test_loss:
                best_test_loss = test_epoch_loss
                best_test_model = copy.deepcopy(self.model.state_dict())
                torch.save(best_test_model, self.best_test_path)
        if not self.args['debug']:
            # Never overwrite a checkpoint with None when no epoch improved.
            if best_model is not None:
                torch.save(best_model, self.best_path)
            if best_test_model is not None:
                torch.save(best_test_model, self.best_test_path)
            self.logger.logger.info(f"Best models saved at {self.best_path} and {self.best_test_path}")
        self._finalize_training(best_model, best_test_model)

    def _finalize_training(self, best_model, best_test_model):
        """Load each recorded best state dict and run the full test on it.

        Plots are generated only for the best-test-loss model. A state dict that
        is ``None`` (no epoch improved) is skipped with a warning instead of
        crashing in ``load_state_dict``.
        """
        if best_model is not None:
            self.model.load_state_dict(best_model)
            self.logger.logger.info("Testing on best validation model")
            self.test(self.model, self.args, self.test_loader, self.scalers, self.logger, generate_viz=False)
        else:
            self.logger.logger.warning("No best validation model was recorded; skipping its test.")
        if best_test_model is not None:
            self.model.load_state_dict(best_test_model)
            self.logger.logger.info("Testing on best test model")
            self.test(self.model, self.args, self.test_loader, self.scalers, self.logger, generate_viz=True)
        else:
            self.logger.logger.warning("No best test model was recorded; skipping its test.")

    @staticmethod
    def test(model, args, data_loader, scalers, logger, path=None, generate_viz=True):
        """Evaluate ``model`` on ``data_loader`` and log per-horizon metrics.

        Args:
            model: Model to evaluate (switched to eval mode here).
            args: Training-args dict; reads ``'device'``, ``'output_dim'``,
                ``'real_value'``, ``'mae_thresh'`` and ``'mape_thresh'``.
            data_loader: Iterable of ``(data, target)`` batches.
            scalers: Scalers used to de-normalise predictions when
                ``args['real_value']`` is true.
            logger: Wrapper exposing ``.logger`` and ``.all_metrics``.
            path: Optional checkpoint to load first — either a bare state dict
                (what ``train`` saves) or a dict holding it under ``'state_dict'``.
            generate_viz: Whether to write prediction plots under the logger's
                ``dir_path`` (``./logs`` when the logger has none).
        """
        if path:
            # args is a dict, so the device must be read by key, not attribute.
            checkpoint = torch.load(path, map_location=args['device'])
            if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
                checkpoint = checkpoint['state_dict']
            model.load_state_dict(checkpoint)
            model.to(args['device'])
        model.eval()
        y_pred, y_true = [], []
        with torch.no_grad():
            for data, target in data_loader:
                label = target[..., :args['output_dim']]
                output = model(data)
                y_pred.append(output)
                y_true.append(label)
        if args['real_value']:
            # De-normalise only the output feature channels.
            y_pred = Trainer._inverse_transform_output_static(torch.cat(y_pred, dim=0), args, scalers)
        else:
            y_pred = torch.cat(y_pred, dim=0)
        y_true = torch.cat(y_true, dim=0)
        # Metrics for each step along axis 1 (the horizon), then averaged over all.
        for t in range(y_true.shape[1]):
            mae, rmse, mape = logger.all_metrics(y_pred[:, t, ...], y_true[:, t, ...],
                                                 args['mae_thresh'], args['mape_thresh'])
            logger.logger.info(f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}")
        mae, rmse, mape = logger.all_metrics(y_pred, y_true, args['mae_thresh'], args['mape_thresh'])
        logger.logger.info(f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}")
        if generate_viz:
            save_dir = logger.dir_path if hasattr(logger, 'dir_path') else './logs'
            Trainer._generate_node_visualizations(y_pred, y_true, logger, save_dir)
            Trainer._generate_input_output_comparison(y_pred, y_true, data_loader, logger, save_dir,
                                                      target_node=1, num_samples=10, scalers=scalers)

    @staticmethod
    def _inverse_transform_output_static(output, args, scalers):
        """Return ``output`` with its leading feature channels de-normalised.

        Channel ``i`` of the last axis goes through
        ``scalers[i].inverse_transform`` for every
        ``i < min(args['output_dim'], len(scalers))``; other channels are
        unchanged. If ``args['real_value']`` is false, ``output`` is returned
        as-is.

        The work is done on a clone, so the caller's tensor is never modified
        in place. This matters during training: overwriting a model output that
        autograd saved for backward (e.g. a sigmoid result) raises at
        ``backward()``.
        """
        if not args['real_value']:
            return output
        num_channels = min(args['output_dim'], len(scalers))
        if num_channels <= 0:
            return output
        restored = output.clone()
        for idx in range(num_channels):
            restored[..., idx:idx + 1] = scalers[idx].inverse_transform(restored[..., idx:idx + 1])
        return restored

    @staticmethod
    def _generate_node_visualizations(y_pred, y_true, logger, save_dir):
        """Plot true vs. predicted series for up to the first 10 nodes.

        Writes one detail figure per node into ``<save_dir>/pic``. Each detail
        figure shows the first 500 samples, feature 0, at a fixed horizon step.
        Also writes a 2x5 summary grid (first 100 samples, horizon step 0).

        Args:
            y_pred: Predictions indexed [sample, horizon, node, feature].
            y_true: Ground truth with the same layout.
            logger: Unused; kept for interface compatibility.
            save_dir: Directory under which ``pic/`` is created.
        """
        import matplotlib.pyplot as plt
        import numpy as np
        import os
        import matplotlib
        from tqdm import tqdm

        # Silence matplotlib's font-lookup chatter and use its bundled font.
        matplotlib.set_loglevel('error')
        plt.rcParams['font.family'] = 'DejaVu Sans'
        if y_pred is None or y_true is None:
            return
        pic_dir = os.path.join(save_dir, 'pic')
        os.makedirs(pic_dir, exist_ok=True)

        num_nodes_to_plot = 10
        # The indexing below needs 4 axes; anything smaller has nothing to plot
        # (the old code crashed with IndexError on it instead).
        has_node_axis = len(y_pred.shape) >= 4 and y_pred.shape[1] > 0
        num_nodes = y_pred.shape[-2] if has_node_axis else 0
        # NOTE(review): the original comment called this "t=1", but the code
        # read horizon index 12. Index 12 is kept and clamped to the last
        # available step, so a 12-step horizon no longer raises IndexError.
        detail_horizon = min(12, y_pred.shape[1] - 1) if has_node_axis else 0

        # One detailed figure per node.
        with tqdm(total=num_nodes_to_plot, desc="Generating node visualizations") as pbar:
            for node_id in range(num_nodes_to_plot):
                if node_id < num_nodes:
                    node_pred = y_pred[:, detail_horizon, node_id, 0].cpu().numpy()
                    node_true = y_true[:, detail_horizon, node_id, 0].cpu().numpy()
                    has_nan = np.isnan(node_pred).any() or np.isnan(node_true).any()
                    max_steps = min(500, len(node_pred))  # first 500 samples at most
                    if max_steps > 0 and not has_nan:
                        time_steps = np.arange(max_steps)
                        plt.figure(figsize=(12, 6))
                        plt.plot(time_steps, node_true[:max_steps], 'b-', label='True Values', linewidth=2, alpha=0.8)
                        plt.plot(time_steps, node_pred[:max_steps], 'r-', label='Predictions', linewidth=2, alpha=0.8)
                        plt.xlabel('Time Steps')
                        plt.ylabel('Values')
                        plt.title(f'Node {node_id + 1}: True vs Predicted Values (First {max_steps} Time Steps)')
                        plt.legend()
                        plt.grid(True, alpha=0.3)
                        save_path = os.path.join(pic_dir, f'node{node_id + 1:02d}_prediction_first500.png')
                        plt.savefig(save_path, dpi=300, bbox_inches='tight')
                        plt.close()
                pbar.update(1)

        # Summary grid over the first 100 samples at horizon step 0.
        plot_steps = min(100, y_pred.shape[0])
        if plot_steps <= 0:
            return
        fig, axes = plt.subplots(2, 5, figsize=(20, 8))
        axes = axes.flatten()
        for node_id in range(num_nodes_to_plot):
            ax = axes[node_id]
            if node_id >= num_nodes:
                ax.text(0.5, 0.5, f'Node {node_id + 1}\nNo Data',
                        ha='center', va='center', transform=ax.transAxes)
                continue
            node_pred = y_pred[:plot_steps, 0, node_id, 0].cpu().numpy()
            node_true = y_true[:plot_steps, 0, node_id, 0].cpu().numpy()
            if np.isnan(node_pred).any() or np.isnan(node_true).any():
                ax.text(0.5, 0.5, f'Node {node_id + 1}\nNo Data',
                        ha='center', va='center', transform=ax.transAxes)
                continue
            time_steps = np.arange(plot_steps)
            ax.plot(time_steps, node_true, 'b-', label='True', linewidth=1.5, alpha=0.8)
            ax.plot(time_steps, node_pred, 'r-', label='Pred', linewidth=1.5, alpha=0.8)
            ax.set_title(f'Node {node_id + 1}')
            ax.grid(True, alpha=0.3)
            ax.legend(fontsize=8)
            if node_id >= 5:  # bottom row gets the x-axis label
                ax.set_xlabel('Time Steps')
            if node_id % 5 == 0:  # left column gets the y-axis label
                ax.set_ylabel('Values')
        plt.tight_layout()
        summary_path = os.path.join(pic_dir, 'all_nodes_summary.png')
        plt.savefig(summary_path, dpi=300, bbox_inches='tight')
        plt.close(fig)

    @staticmethod
    def _generate_input_output_comparison(y_pred, y_true, data_loader, logger, save_dir,
                                          target_node=1, num_samples=10, scalers=None):
        """Plot each sample's input history followed by its true/predicted horizon.

        The first sample of each of the first ``num_samples`` batches is used.
        Its prediction and target are looked up in ``y_pred``/``y_true`` at that
        sample's position in the loader (the sum of the preceding batch sizes).
        Previously the code read row ``k`` of ``y_pred`` for batch ``k``, which
        paired unrelated inputs and outputs. The pairing is only meaningful if
        ``data_loader`` yields samples in the same order as when ``y_pred`` was
        produced (i.e. it is not shuffled).

        Figures go to ``<save_dir>/pic/compare``: one per sample, plus a 2x5
        summary of the first 10.

        Args:
            y_pred: Predictions indexed [sample, horizon, node, feature].
            y_true: Ground truth with the same layout.
            data_loader: Iterable of ``(data, target)`` batches; ``data`` is
                assumed to be [batch, seq_len, node, feature].
            logger: Unused; kept for interface compatibility.
            save_dir: Base output directory.
            target_node: 1-based id of the node to plot.
            num_samples: Number of batches (hence samples) to plot.
            scalers: If non-empty, ``scalers[0]`` de-normalises input feature 0.
        """
        import matplotlib.pyplot as plt
        import numpy as np
        import os
        import matplotlib
        from tqdm import tqdm

        matplotlib.set_loglevel('error')
        plt.rcParams['font.family'] = 'DejaVu Sans'
        compare_dir = os.path.join(save_dir, 'pic', 'compare')
        os.makedirs(compare_dir, exist_ok=True)

        # Entries are (sample number, row in y_pred/y_true, that sample's input).
        picked = []
        row_offset = 0
        for batch_idx, (data, _target) in enumerate(data_loader):
            if batch_idx >= num_samples:
                break
            picked.append((batch_idx, row_offset, data[0].cpu().numpy()))
            row_offset += len(data)
        if not picked:
            return
        node_idx = target_node - 1  # 0-based node index
        if node_idx < 0 or node_idx >= y_pred.shape[-2]:
            return

        def _series(row, sample_input):
            # Return (input, pred, true) 1-D arrays for one sample, or None when
            # the row is missing or any of the series contains NaN.
            if row >= y_pred.shape[0]:
                return None
            input_seq = sample_input[:, node_idx, 0]
            if scalers is not None and len(scalers) > 0:
                # Assumes input feature 0 shares the scale of output feature 0.
                input_seq = scalers[0].inverse_transform(input_seq.reshape(-1, 1)).flatten()
            pred_seq = y_pred[row, :, node_idx, 0].cpu().numpy()
            true_seq = y_true[row, :, node_idx, 0].cpu().numpy()
            if np.isnan(input_seq).any() or np.isnan(pred_seq).any() or np.isnan(true_seq).any():
                return None
            return input_seq, pred_seq, true_seq

        # Build every series once and reuse it for both figure kinds.
        series = [(sample_idx, _series(row, sample_input)) for sample_idx, row, sample_input in picked]

        with tqdm(total=len(series), desc="Generating input-output comparisons") as pbar:
            for sample_idx, item in series:
                if item is not None:
                    input_seq, pred_seq, true_seq = item
                    n_in = len(input_seq)
                    # Input and output share one continuous time axis.
                    total_time = np.arange(n_in + len(pred_seq))
                    output_time = np.arange(n_in, n_in + len(pred_seq))
                    plt.figure(figsize=(14, 8))
                    plt.plot(total_time, np.concatenate([input_seq, true_seq]), 'b',
                             label='True Values (Input + Output)', linewidth=2.5, alpha=0.9, linestyle='-')
                    plt.plot(output_time, pred_seq, 'r', label='Predicted Values',
                             linewidth=2, alpha=0.8, linestyle='-')
                    plt.axvline(x=n_in - 0.5, color='gray', linestyle=':', alpha=0.7,
                                label='Input/Output Boundary')
                    plt.xlabel('Time Steps')
                    plt.ylabel('Values')
                    plt.title(f'Sample {sample_idx + 1}: Input-Output Comparison (Node {target_node})')
                    plt.legend()
                    plt.grid(True, alpha=0.3)
                    plt.tight_layout()
                    save_path = os.path.join(compare_dir, f'sample{sample_idx + 1:02d}_node{target_node:02d}_comparison.png')
                    plt.savefig(save_path, dpi=300, bbox_inches='tight')
                    plt.close()
                pbar.update(1)

        # Summary grid: at most 10 subplots.
        fig, axes = plt.subplots(2, 5, figsize=(20, 8))
        axes = axes.flatten()
        for sample_idx, item in series[:10]:
            ax = axes[sample_idx]
            if item is None:
                ax.text(0.5, 0.5, f'Sample {sample_idx + 1}\nNo Data',
                        ha='center', va='center', transform=ax.transAxes)
                continue
            input_seq, pred_seq, true_seq = item
            n_in = len(input_seq)
            total_time = np.arange(n_in + len(pred_seq))
            output_time = np.arange(n_in, n_in + len(pred_seq))
            ax.plot(total_time, np.concatenate([input_seq, true_seq]), 'b', label='True',
                    linewidth=2, alpha=0.9, linestyle='-')
            ax.plot(output_time, pred_seq, 'r', label='Pred', linewidth=1.5, alpha=0.8, linestyle='-')
            ax.axvline(x=n_in - 0.5, color='gray', linestyle=':', alpha=0.5)
            ax.set_title(f'Sample {sample_idx + 1}')
            ax.grid(True, alpha=0.3)
            ax.legend(fontsize=8)
            if sample_idx >= 5:  # bottom row gets the x-axis label
                ax.set_xlabel('Time Steps')
            if sample_idx % 5 == 0:  # left column gets the y-axis label
                ax.set_ylabel('Values')
        # Hide the unused subplots.
        for i in range(len(series), 10):
            axes[i].set_visible(False)
        plt.tight_layout()
        summary_path = os.path.join(compare_dir, f'all_samples_node{target_node:02d}_summary.png')
        plt.savefig(summary_path, dpi=300, bbox_inches='tight')
        plt.close(fig)

    @staticmethod
    def _compute_sampling_threshold(global_step, k):
        """Inverse-sigmoid decay ``k / (k + exp(global_step / k))``.

        Returns 0.0 once ``math.exp`` overflows. The exact value is already
        indistinguishable from 0 there, and the old code raised OverflowError.
        """
        try:
            return k / (k + math.exp(global_step / k))
        except OverflowError:
            return 0.0