新增多个PEMS数据集配置文件,包含PEMSD3、PEMSD4、PEMSD7、PEMSD8及STAWnet、STGNRDE、ST_SSL模型的相关配置,优化模型训练参数设置。

This commit is contained in:
czzhangheng 2025-08-19 00:07:37 +08:00
parent b820b867fb
commit 29fd709c8c
17 changed files with 216 additions and 34 deletions

View File

@ -32,7 +32,7 @@ pip install pyyaml tqdm statsmodels h5py kagglehub torch torchvision torchaudio
python run.py --model {model_name} --dataset {dataset_name} --mode {train, test} --device {cuda:0}
```
- model_name: 目前支持DSANET、STGCN、DCRNN、 GWN(GraphWaveNet)、STSGCN、AGCRN、STFGNN、STGODE、STGNCDE、DDGCRN、TWDGCN
- model_name: 目前支持DSANET、STGCN、DCRNN、 GWN(GraphWaveNet)、STSGCN、AGCRN、STFGNN、STGODE、STGNCDE、DDGCRN、TWDGCN、STAWnet
- dataset_name目前支持PEMSD3PEMSD4、PEMSD7、PEMSD8
- modetrain为训练模型test为测试模型。测试模型需要在pre-train文件中找到模型的pth存档。
- device: 支持'cpu'、'cuda:0'、cuda:1 ... 取决于机器卡数

View File

@ -6,7 +6,7 @@ data:
test_ratio: 0.2
tod: False
normalizer: std
column_wise: False
column_wise: True
default_graph: True
add_time_in_day: True
add_day_in_week: True
@ -21,26 +21,26 @@ model:
max_diffusion_step: 2
cl_decay_steps: 1000
filter_type: dual_random_walk
num_rnn_layers: 1
num_rnn_layers: 2
rnn_units: 64
seq_len: 12
use_curriculum_learning: True
train:
loss_func: mae
loss_func: mask_mae
seed: 10
batch_size: 64
epochs: 300
lr_init: 0.003
weight_decay: 0
lr_decay: False
lr_decay_rate: 0.3
lr_decay_step: "5,20,40,70"
lr_init: 0.001
weight_decay: 0.0001
lr_decay: True
lr_decay_rate: 0.1
lr_decay_step: "10,20,40,80"
early_stop: True
early_stop_patience: 15
grad_norm: False
early_stop_patience: 25
grad_norm: True
max_grad_norm: 5
real_value: True
real_value: False
test:
mae_thresh: null

View File

@ -2,10 +2,12 @@ from dataloader.cde_loader.cdeDataloader import get_dataloader as cde_loader
from dataloader.PeMSDdataloader import get_dataloader as normal_loader
from dataloader.DCRNNdataloader import get_dataloader as DCRNN_loader
from dataloader.EXPdataloader import get_dataloader as EXP_loader
from dataloader.cde_loader.cdeDataloader import get_dataloader as nrde_loader
def get_dataloader(config, normalizer, single):
match config['model']['type']:
case 'STGNCDE': return cde_loader(config['data'], normalizer, single)
case 'STGNRDE': return nrde_loader(config['data'], normalizer, single)
case 'DCRNN': return DCRNN_loader(config['data'], normalizer, single)
case 'EXP': return EXP_loader(config['data'], normalizer, single)
case _: return normal_loader(config['data'], normalizer, single)

View File

@ -1,9 +1,9 @@
def masked_mae_loss(scaler, mask_value):
def loss(preds, labels):
# 仅对预测反归一化;标签在数据管道中保持原始量纲
if scaler:
preds = scaler.inverse_transform(preds)
labels = scaler.inverse_transform(labels)
return mae_torch(pred=preds, true=labels, mask_value=mask_value)
return loss
@ -11,7 +11,8 @@ def masked_mae_loss(scaler, mask_value):
def get_loss_function(args, scaler):
if args['loss_func'] == 'mask_mae':
return masked_mae_loss(scaler, mask_value=0.0).to(args['device'])
# Return callable loss (no .to for function closures); disable masking by default
return masked_mae_loss(scaler, mask_value=None)
elif args['loss_func'] == 'mae':
return torch.nn.L1Loss().to(args['device'])
elif args['loss_func'] == 'mse':

View File

@ -34,7 +34,7 @@ class LayerParams:
class DCGRUCell(torch.nn.Module):
def __init__(self, num_units, adj_mx, max_diffusion_step, num_nodes, nonlinearity='tanh',
def __init__(self, num_units, adj_mx, max_diffusion_step, num_nodes, input_dim=None, nonlinearity='tanh',
filter_type="laplacian", use_gc_for_ru=True):
"""
@ -55,6 +55,7 @@ class DCGRUCell(torch.nn.Module):
self._max_diffusion_step = max_diffusion_step
self._supports = []
self._use_gc_for_ru = use_gc_for_ru
self._input_dim = input_dim # optional; if None, will be inferred at first forward
supports = []
if filter_type == "laplacian":
supports.append(utils.calculate_scaled_laplacian(adj_mx, lambda_max=None))
@ -71,6 +72,22 @@ class DCGRUCell(torch.nn.Module):
self._fc_params = LayerParams(self, 'fc')
self._gconv_params = LayerParams(self, 'gconv')
# Pre-create parameters if input_dim is known
if self._input_dim is not None:
num_matrices = len(self._supports) * self._max_diffusion_step + 1
input_size = self._input_dim + self._num_units
# FC weights/biases for RU gates (2 * num_units)
self._fc_params.get_weights((input_size, 2 * self._num_units))
self._fc_params.get_biases(2 * self._num_units, bias_start=1.0)
# Optionally for candidate (num_units) if FC path is used
self._fc_params.get_weights((input_size, self._num_units))
self._fc_params.get_biases(self._num_units, bias_start=0.0)
# GConv weights/biases for RU gates and candidate
self._gconv_params.get_weights((input_size * num_matrices, 2 * self._num_units))
self._gconv_params.get_biases(2 * self._num_units, bias_start=1.0)
self._gconv_params.get_weights((input_size * num_matrices, self._num_units))
self._gconv_params.get_biases(self._num_units, bias_start=0.0)
@staticmethod
def _build_sparse_matrix(L):
L = L.tocoo()

View File

@ -15,6 +15,10 @@ class Seq2SeqAttrs:
self.num_nodes = args.get('num_nodes', 1)
self.num_rnn_layers = args.get('num_rnn_layers', 1)
self.rnn_units = args.get('rnn_units')
self.input_dim = args.get('input_dim', 1)
self.output_dim = args.get('output_dim', 1)
self.horizon = args.get('horizon', 12)
self.seq_len = args.get('seq_len', 12)
self.hidden_state_size = self.num_nodes * self.rnn_units
@ -26,7 +30,7 @@ class EncoderModel(nn.Module, Seq2SeqAttrs):
self.seq_len = args.get('seq_len') # for the encoder
self.dcgru_layers = nn.ModuleList(
[DCGRUCell(self.rnn_units, adj_mx, self.max_diffusion_step, self.num_nodes,
filter_type=self.filter_type) for _ in range(self.num_rnn_layers)])
input_dim=self.input_dim, filter_type=self.filter_type) for _ in range(self.num_rnn_layers)])
def forward(self, inputs, hidden_state=None):
"""
@ -63,7 +67,7 @@ class DecoderModel(nn.Module, Seq2SeqAttrs):
self.projection_layer = nn.Linear(self.rnn_units, self.output_dim)
self.dcgru_layers = nn.ModuleList(
[DCGRUCell(self.rnn_units, adj_mx, self.max_diffusion_step, self.num_nodes,
filter_type=self.filter_type) for _ in range(self.num_rnn_layers)])
input_dim=self.output_dim, filter_type=self.filter_type) for _ in range(self.num_rnn_layers)])
def forward(self, inputs, hidden_state=None):
"""
@ -146,17 +150,19 @@ class DCRNNModel(nn.Module, Seq2SeqAttrs):
def forward(self, inputs, labels=None):
"""
seq2seq forward pass 64 12 307 3
:param inputs: shape (seq_len, batch_size, num_sensor * input_dim) 12 64 307 * 1
:param labels: shape (horizon, batch_size, num_sensor * output) 12 64 307 1
:param batches_seen: batches seen till now
:return: output: (self.horizon, batch_size, self.num_nodes * self.output_dim)
seq2seq forward pass. inputs: [B, T, N, C]
"""
inputs = inputs[..., 0].permute(1, 0, 2)
labels = labels[..., 0].permute(1, 0, 2)
encoder_hidden_state = self.encoder(inputs)
outputs = self.decoder(encoder_hidden_state, labels, batches_seen=self.batch_seen)
x = inputs[..., :self.input_dim]
x = x.permute(1, 0, 2, 3).contiguous().view(self.seq_len, -1, self.num_nodes * self.input_dim)
y = None
if labels is not None:
y = labels[..., :self.output_dim]
y = y.permute(1, 0, 2, 3).contiguous().view(self.horizon, -1, self.num_nodes * self.output_dim)
encoder_hidden_state = self.encoder(x)
outputs = self.decoder(encoder_hidden_state, y, batches_seen=self.batch_seen)
self.batch_seen += 1
outputs = outputs.unsqueeze(dim=-1) # [12,64,307,1]
outputs = outputs.permute(1, 0, 2, 3) # [64,12,307,1]
outputs = outputs.view(self.horizon, -1, self.num_nodes, self.output_dim)
outputs = outputs.permute(1, 0, 2, 3).contiguous()
return outputs

View File

@ -7,6 +7,7 @@ import numpy as np
from lib.logger import get_logger
from lib.metrics import All_Metrics
from lib.TrainInits import print_model_parameters
from lib.training_stats import TrainingStats
class Trainer(object):
def __init__(self, model, vector_field_f, vector_field_g, loss, optimizer, train_loader, val_loader, test_loader,
@ -42,6 +43,8 @@ class Trainer(object):
self.device = device
self.times = times.to(self.device, dtype=torch.float)
self.w = w
# Stats tracker
self.stats = TrainingStats(device=device)
def val_epoch(self, epoch, val_dataloader):
self.model.eval()
@ -49,6 +52,7 @@ class Trainer(object):
with torch.no_grad():
for batch_idx, batch in enumerate(self.val_loader):
start_time = time.time()
# for iter, batch in enumerate(val_dataloader):
batch = tuple(b.to(self.device, dtype=torch.float) for b in batch)
*valid_coeffs, target = batch
@ -61,8 +65,11 @@ class Trainer(object):
#a whole batch of Metr_LA is filtered
if not torch.isnan(loss):
total_val_loss += loss.item()
step_time = time.time() - start_time
self.stats.record_step_time(step_time, 'val')
val_loss = total_val_loss / len(val_dataloader)
self.logger.info('**********Val Epoch {}: average Loss: {:.6f}'.format(epoch, val_loss))
self.stats.record_memory_usage()
if self.args.tensorboard:
self.w.add_scalar(f'valid/loss', val_loss, epoch)
return val_loss
@ -73,6 +80,7 @@ class Trainer(object):
# for batch_idx, (data, target) in enumerate(self.train_loader):
# for batch_idx, (data, target) in enumerate(self.train_loader):
for batch_idx, batch in enumerate(self.train_loader):
start_time = time.time()
batch = tuple(b.to(self.device, dtype=torch.float) for b in batch)
*train_coeffs, target = batch
# data = data[..., :self.args.input_dim]
@ -101,6 +109,8 @@ class Trainer(object):
if self.args.grad_norm:
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm)
self.optimizer.step()
step_time = time.time() - start_time
self.stats.record_step_time(step_time, 'train')
total_loss += loss.item()
#log information
@ -109,6 +119,7 @@ class Trainer(object):
epoch, batch_idx, self.train_per_epoch, loss.item()))
train_epoch_loss = total_loss/self.train_per_epoch
self.logger.info('**********Train Epoch {}: averaged Loss: {:.6f}'.format(epoch, train_epoch_loss))
self.stats.record_memory_usage()
if self.args.tensorboard:
self.w.add_scalar(f'train/loss', train_epoch_loss, epoch)
@ -123,6 +134,7 @@ class Trainer(object):
not_improved_count = 0
train_loss_list = []
val_loss_list = []
self.stats.start_training()
start_time = time.time()
for epoch in range(1, self.args.epochs + 1):
#epoch_time = time.time()
@ -168,6 +180,13 @@ class Trainer(object):
training_time = time.time() - start_time
self.logger.info("Total training time: {:.4f}min, best loss: {:.6f}".format((training_time / 60), best_loss))
self.stats.end_training()
self.stats.report(self.logger)
try:
total_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
self.logger.info(f"Trainable params: {total_params}")
except Exception:
pass
#save the best model to file
if not self.args.debug:

View File

@ -18,6 +18,10 @@ from model.STIDGCN.STIDGCN import STIDGCN
from model.STID.STID import STID
from model.STAEFormer.STAEFormer import STAEformer
from model.EXP.EXP32 import EXP as EXP
from model.MegaCRN.MegaCRNModel import MegaCRNModel
from model.ST_SSL.ST_SSL import STSSLModel
from model.STGNRDE.Make_model import make_model as make_nrde_model
from model.STAWnet.STAWnet import STAWnet
def model_selector(model):
match model['type']:
@ -41,4 +45,8 @@ def model_selector(model):
case 'STID': return STID(model)
case 'STAEFormer': return STAEformer(model)
case 'EXP': return EXP(model)
case 'MegaCRN': return MegaCRNModel(model)
case 'ST_SSL': return STSSLModel(model)
case 'STGNRDE': return make_nrde_model(model)
case 'STAWnet': return STAWnet(model)

View File

@ -7,6 +7,7 @@ from tqdm import tqdm
import torch
from lib.logger import get_logger
from lib.loss_function import all_metrics
from lib.training_stats import TrainingStats
class Trainer:
@ -34,6 +35,8 @@ class Trainer:
os.makedirs(args['log_dir'], exist_ok=True)
self.logger = get_logger(args['log_dir'], name=self.model.__class__.__name__, debug=args['debug'])
self.logger.info(f"Experiment log path in: {args['log_dir']}")
# Stats tracker
self.stats = TrainingStats(device=args['device'])
def _run_epoch(self, epoch, dataloader, mode):
if mode == 'train':
@ -49,12 +52,13 @@ class Trainer:
with torch.set_grad_enabled(optimizer_step):
with tqdm(total=len(dataloader), desc=f'{mode.capitalize()} Epoch {epoch}') as pbar:
for batch_idx, (data, target) in enumerate(dataloader):
start_time = time.time()
label = target[..., :self.args['output_dim']]
# label = target[..., :self.args['output_dim']]
output = self.model(data, labels=label.clone()).to(self.args['device'])
if self.args['real_value']:
output = self.scaler.inverse_transform(output)
label = self.scaler.inverse_transform(label)
loss = self.loss(output, label)
if optimizer_step and self.optimizer is not None:
@ -65,6 +69,8 @@ class Trainer:
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm'])
self.optimizer.step()
step_time = time.time() - start_time
self.stats.record_step_time(step_time, mode)
total_loss += loss.item()
if mode == 'train' and (batch_idx + 1) % self.args['log_step'] == 0:
@ -78,6 +84,8 @@ class Trainer:
avg_loss = total_loss / len(dataloader)
self.logger.info(
f'{mode.capitalize()} Epoch {epoch}: average Loss: {avg_loss:.6f}, time: {time.time() - epoch_time:.2f} s')
# 记录内存
self.stats.record_memory_usage()
return avg_loss
def train_epoch(self, epoch):
@ -94,6 +102,7 @@ class Trainer:
best_loss, best_test_loss = float('inf'), float('inf')
not_improved_count = 0
self.stats.start_training()
self.logger.info("Training process started")
for epoch in range(1, self.args['epochs'] + 1):
train_epoch_loss = self.train_epoch(epoch)
@ -126,6 +135,14 @@ class Trainer:
torch.save(best_test_model, self.best_test_path)
self.logger.info(f"Best models saved at {self.best_path} and {self.best_test_path}")
# 输出统计与参数
self.stats.end_training()
self.stats.report(self.logger)
try:
total_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
self.logger.info(f"Trainable params: {total_params}")
except Exception:
pass
self._finalize_training(best_model, best_test_model)
def _finalize_training(self, best_model, best_test_model):
@ -154,11 +171,11 @@ class Trainer:
y_pred.append(output)
y_true.append(label)
if args['real_value']:
y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0))
else:
y_pred = torch.cat(y_pred, dim=0)
y_true = torch.cat(y_true, dim=0)
if args['real_value']:
y_pred = scaler.inverse_transform(y_pred)
y_true = scaler.inverse_transform(y_true)
for t in range(y_true.shape[1]):
mae, rmse, mape = all_metrics(y_pred[:, t, ...], y_true[:, t, ...],

View File

@ -7,6 +7,7 @@ from tqdm import tqdm
import torch
from lib.logger import get_logger
from lib.loss_function import all_metrics
from lib.training_stats import TrainingStats
class Trainer:
@ -34,6 +35,8 @@ class Trainer:
os.makedirs(args['log_dir'], exist_ok=True)
self.logger = get_logger(args['log_dir'], name=self.model.__class__.__name__, debug=args['debug'])
self.logger.info(f"Experiment log path in: {args['log_dir']}")
# Stats tracker
self.stats = TrainingStats(device=args['device'])
def _run_epoch(self, epoch, dataloader, mode):
is_train = (mode == 'train')
@ -45,6 +48,7 @@ class Trainer:
tqdm(total=len(dataloader), desc=f'{mode.capitalize()} Epoch {epoch}') as pbar:
for batch_idx, batch in enumerate(dataloader):
start_time = time.time()
# unpack the new cycle_index
data, target, cycle_index = batch
data = data.to(self.args['device'])
@ -72,6 +76,8 @@ class Trainer:
self.args['max_grad_norm'])
self.optimizer.step()
step_time = time.time() - start_time
self.stats.record_step_time(step_time, mode)
total_loss += loss.item()
# logging
@ -86,6 +92,8 @@ class Trainer:
avg_loss = total_loss / len(dataloader)
self.logger.info(
f'{mode.capitalize()} Epoch {epoch}: average Loss: {avg_loss:.6f}, time: {time.time() - epoch_time:.2f} s')
# 记录内存
self.stats.record_memory_usage()
return avg_loss
def train_epoch(self, epoch):
@ -102,6 +110,7 @@ class Trainer:
best_loss, best_test_loss = float('inf'), float('inf')
not_improved_count = 0
self.stats.start_training()
self.logger.info("Training process started")
for epoch in range(1, self.args['epochs'] + 1):
train_epoch_loss = self.train_epoch(epoch)
@ -134,6 +143,14 @@ class Trainer:
torch.save(best_test_model, self.best_test_path)
self.logger.info(f"Best models saved at {self.best_path} and {self.best_test_path}")
# 输出统计与参数
self.stats.end_training()
self.stats.report(self.logger)
try:
total_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
self.logger.info(f"Trainable params: {total_params}")
except Exception:
pass
self._finalize_training(best_model, best_test_model)
def _finalize_training(self, best_model, best_test_model):

View File

@ -7,6 +7,7 @@ from tqdm import tqdm
import torch
from lib.logger import get_logger
from lib.loss_function import all_metrics
from lib.training_stats import TrainingStats
class Trainer:
@ -34,6 +35,8 @@ class Trainer:
os.makedirs(args['log_dir'], exist_ok=True)
self.logger = get_logger(args['log_dir'], name=self.model.__class__.__name__, debug=args['debug'])
self.logger.info(f"Experiment log path in: {args['log_dir']}")
# Stats tracker
self.stats = TrainingStats(device=args['device'])
def _run_epoch(self, epoch, dataloader, mode):
if mode == 'train':
@ -49,6 +52,7 @@ class Trainer:
with torch.set_grad_enabled(optimizer_step):
with tqdm(total=len(dataloader), desc=f'{mode.capitalize()} Epoch {epoch}') as pbar:
for batch_idx, (data, target) in enumerate(dataloader):
start_time = time.time()
label = target[..., :self.args['output_dim']]
output = self.model(data).to(self.args['device'])
@ -64,6 +68,9 @@ class Trainer:
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm'])
self.optimizer.step()
step_time = time.time() - start_time
self.stats.record_step_time(step_time, mode)
total_loss += loss.item()
if mode == 'train' and (batch_idx + 1) % self.args['log_step'] == 0:
@ -77,6 +84,8 @@ class Trainer:
avg_loss = total_loss / len(dataloader)
self.logger.info(
f'{mode.capitalize()} Epoch {epoch}: average Loss: {avg_loss:.6f}, time: {time.time() - epoch_time:.2f} s')
# 记录内存
self.stats.record_memory_usage()
return avg_loss
def train_epoch(self, epoch):
@ -93,6 +102,7 @@ class Trainer:
best_loss, best_test_loss = float('inf'), float('inf')
not_improved_count = 0
self.stats.start_training()
self.logger.info("Training process started")
for epoch in range(1, self.args['epochs'] + 1):
train_epoch_loss = self.train_epoch(epoch)
@ -124,7 +134,14 @@ class Trainer:
torch.save(best_model, self.best_path)
torch.save(best_test_model, self.best_test_path)
self.logger.info(f"Best models saved at {self.best_path} and {self.best_test_path}")
# 输出统计与参数
self.stats.end_training()
self.stats.report(self.logger)
try:
total_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
self.logger.info(f"Trainable params: {total_params}")
except Exception:
pass
self._finalize_training(best_model, best_test_model)
def _finalize_training(self, best_model, best_test_model):

View File

@ -7,6 +7,7 @@ from tqdm import tqdm
import torch
from lib.logger import get_logger
from lib.loss_function import all_metrics
from lib.training_stats import TrainingStats
class Trainer:
@ -35,6 +36,8 @@ class Trainer:
os.makedirs(args['log_dir'], exist_ok=True)
self.logger = get_logger(args['log_dir'], name=self.model.__class__.__name__, debug=args['debug'])
self.logger.info(f"Experiment log path in: {args['log_dir']}")
# Stats tracker
self.stats = TrainingStats(device=args['device'])
def _run_epoch(self, epoch, dataloader, mode):
if mode == 'train':
@ -50,6 +53,7 @@ class Trainer:
with torch.set_grad_enabled(optimizer_step):
with tqdm(total=len(dataloader), desc=f'{mode.capitalize()} Epoch {epoch}') as pbar:
for batch_idx, (data, target) in enumerate(dataloader):
start_time = time.time()
self.batches_seen += 1
label = target[..., :self.args['output_dim']].clone()
output = self.model(data, target, self.batches_seen).to(self.args['device'])
@ -66,6 +70,9 @@ class Trainer:
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm'])
self.optimizer.step()
# record step time
step_time = time.time() - start_time
self.stats.record_step_time(step_time, mode)
total_loss += loss.item()
if mode == 'train' and (batch_idx + 1) % self.args['log_step'] == 0:
@ -79,6 +86,8 @@ class Trainer:
avg_loss = total_loss / len(dataloader)
self.logger.info(
f'{mode.capitalize()} Epoch {epoch}: average Loss: {avg_loss:.6f}, time: {time.time() - epoch_time:.2f} s')
# 记录内存
self.stats.record_memory_usage()
return avg_loss
def train_epoch(self, epoch):
@ -95,6 +104,7 @@ class Trainer:
best_loss, best_test_loss = float('inf'), float('inf')
not_improved_count = 0
self.stats.start_training()
self.logger.info("Training process started")
for epoch in range(1, self.args['epochs'] + 1):
train_epoch_loss = self.train_epoch(epoch)
@ -127,6 +137,14 @@ class Trainer:
torch.save(best_test_model, self.best_test_path)
self.logger.info(f"Best models saved at {self.best_path} and {self.best_test_path}")
# 输出统计与参数
self.stats.end_training()
self.stats.report(self.logger)
try:
total_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
self.logger.info(f"Trainable params: {total_params}")
except Exception:
pass
self._finalize_training(best_model, best_test_model)
def _finalize_training(self, best_model, best_test_model):

View File

@ -11,6 +11,7 @@ from tqdm import tqdm
from lib.logger import get_logger
from lib.loss_function import all_metrics
from model.STMLP.STMLP import STMLP
from lib.training_stats import TrainingStats
class Trainer:
@ -43,6 +44,8 @@ class Trainer:
os.makedirs(self.pretrain_dir, exist_ok=True)
self.logger = get_logger(self.args['log_dir'], name=self.model.__class__.__name__, debug=self.args['debug'])
self.logger.info(f"Experiment log path in: {self.args['log_dir']}")
# Stats tracker
self.stats = TrainingStats(device=args['device'])
if self.args['teacher_stu']:
self.tmodel = self.loadTeacher(args)
@ -67,6 +70,7 @@ class Trainer:
with torch.set_grad_enabled(optimizer_step):
with tqdm(total=len(dataloader), desc=f'{mode.capitalize()} Epoch {epoch}') as pbar:
for batch_idx, (data, target) in enumerate(dataloader):
start_time = time.time()
if self.args['teacher_stu']:
label = target[..., :self.args['output_dim']]
output, out_, _ = self.model(data)
@ -100,6 +104,8 @@ class Trainer:
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm'])
self.optimizer.step()
step_time = time.time() - start_time
self.stats.record_step_time(step_time, mode)
total_loss += loss.item()
if mode == 'train' and (batch_idx + 1) % self.args['log_step'] == 0:
@ -113,6 +119,8 @@ class Trainer:
avg_loss = total_loss / len(dataloader)
self.logger.info(
f'{mode.capitalize()} Epoch {epoch}: average Loss: {avg_loss:.6f}, time: {time.time() - epoch_time:.2f} s')
# 记录内存
self.stats.record_memory_usage()
return avg_loss
def train_epoch(self, epoch):
@ -129,6 +137,7 @@ class Trainer:
best_loss, best_test_loss = float('inf'), float('inf')
not_improved_count = 0
self.stats.start_training()
self.logger.info("Training process started")
for epoch in range(1, self.args['epochs'] + 1):
train_epoch_loss = self.train_epoch(epoch)
@ -165,6 +174,14 @@ class Trainer:
torch.save(best_test_model, self.best_test_path)
self.logger.info(f"Best models saved at {self.best_path} and {self.best_test_path}")
# 输出统计与参数
self.stats.end_training()
self.stats.report(self.logger)
try:
total_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
self.logger.info(f"Trainable params: {total_params}")
except Exception:
pass
self._finalize_training(best_model, best_test_model)
def _finalize_training(self, best_model, best_test_model):

View File

@ -211,6 +211,13 @@ class Trainer:
self._finalize_training(best_model, best_test_model)
# 输出参数量
try:
total_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
self.logger.info(f"Trainable params: {total_params}")
except Exception:
pass
def _finalize_training(self, best_model, best_test_model):
self.model.load_state_dict(best_model)
self.logger.info("Testing on best validation model")

View File

@ -7,6 +7,7 @@ from tqdm import tqdm
import torch
from lib.logger import get_logger
from lib.loss_function import all_metrics
from lib.training_stats import TrainingStats
class Trainer:
@ -34,6 +35,8 @@ class Trainer:
os.makedirs(args['log_dir'], exist_ok=True)
self.logger = get_logger(args['log_dir'], name=self.model.__class__.__name__, debug=args['debug'])
self.logger.info(f"Experiment log path in: {args['log_dir']}")
# Stats tracker
self.stats = TrainingStats(device=args['device'])
def _run_epoch(self, epoch, dataloader, mode):
if mode == 'train':
@ -49,6 +52,7 @@ class Trainer:
with torch.set_grad_enabled(optimizer_step):
with tqdm(total=len(dataloader), desc=f'{mode.capitalize()} Epoch {epoch}') as pbar:
for batch_idx, (data, target) in enumerate(dataloader):
start_time = time.time()
label = target[..., :self.args['output_dim']]
output = self.model(data).to(self.args['device'])
@ -64,6 +68,8 @@ class Trainer:
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm'])
self.optimizer.step()
step_time = time.time() - start_time
self.stats.record_step_time(step_time, mode)
total_loss += loss.item()
if mode == 'train' and (batch_idx + 1) % self.args['log_step'] == 0:
@ -77,6 +83,8 @@ class Trainer:
avg_loss = total_loss / len(dataloader)
self.logger.info(
f'{mode.capitalize()} Epoch {epoch}: average Loss: {avg_loss:.6f}, time: {time.time() - epoch_time:.2f} s')
# 记录内存
self.stats.record_memory_usage()
return avg_loss
def train_epoch(self, epoch):
@ -93,6 +101,7 @@ class Trainer:
best_loss, best_test_loss = float('inf'), float('inf')
not_improved_count = 0
self.stats.start_training()
self.logger.info("Training process started")
for epoch in range(1, self.args['epochs'] + 1):
train_epoch_loss = self.train_epoch(epoch)
@ -125,6 +134,14 @@ class Trainer:
torch.save(best_test_model, self.best_test_path)
self.logger.info(f"Best models saved at {self.best_path} and {self.best_test_path}")
# 输出统计与参数
self.stats.end_training()
self.stats.report(self.logger)
try:
total_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
self.logger.info(f"Trainable params: {total_params}")
except Exception:
pass
self._finalize_training(best_model, best_test_model)
def _finalize_training(self, best_model, best_test_model):

View File

@ -7,6 +7,7 @@ from tqdm import tqdm
import torch
from lib.logger import get_logger
from lib.loss_function import all_metrics
from lib.training_stats import TrainingStats
class Trainer:
@ -35,6 +36,8 @@ class Trainer:
os.makedirs(args['log_dir'], exist_ok=True)
self.logger = get_logger(args['log_dir'], name=self.model.__class__.__name__, debug=args['debug'])
self.logger.info(f"Experiment log path in: {args['log_dir']}")
# Stats tracker
self.stats = TrainingStats(device=args['device'])
self.times = times.to(self.device, dtype=torch.float)
self.w = w
@ -52,6 +55,7 @@ class Trainer:
with torch.set_grad_enabled(optimizer_step):
with tqdm(total=len(dataloader), desc=f'{mode.capitalize()} Epoch {epoch}') as pbar:
for batch_idx, batch in enumerate(dataloader):
start_time = time.time()
batch = tuple(b.to(self.device, dtype=torch.float) for b in batch)
*train_coeffs, target = batch
label = target[..., :self.args['output_dim']]
@ -69,6 +73,8 @@ class Trainer:
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm'])
self.optimizer.step()
step_time = time.time() - start_time
self.stats.record_step_time(step_time, mode)
total_loss += loss.item()
if mode == 'train' and (batch_idx + 1) % self.args['log_step'] == 0:
@ -82,6 +88,8 @@ class Trainer:
avg_loss = total_loss / len(dataloader)
self.logger.info(
f'{mode.capitalize()} Epoch {epoch}: average Loss: {avg_loss:.6f}, time: {time.time() - epoch_time:.2f} s')
# 记录内存
self.stats.record_memory_usage()
return avg_loss
def train_epoch(self, epoch):
return self._run_epoch(epoch, self.train_loader, 'train')
@ -97,6 +105,7 @@ class Trainer:
best_loss, best_test_loss = float('inf'), float('inf')
not_improved_count = 0
self.stats.start_training()
self.logger.info("Training process started")
for epoch in range(1, self.args['epochs'] + 1):
train_epoch_loss = self.train_epoch(epoch)
@ -129,6 +138,14 @@ class Trainer:
torch.save(best_test_model, self.best_test_path)
self.logger.info(f"Best models saved at {self.best_path} and {self.best_test_path}")
# 输出统计与参数
self.stats.end_training()
self.stats.report(self.logger)
try:
total_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
self.logger.info(f"Trainable params: {total_params}")
except Exception:
pass
self._finalize_training(best_model, best_test_model)
def _finalize_training(self, best_model, best_test_model):

View File

@ -11,6 +11,8 @@ def select_trainer(model, loss, optimizer, train_loader, val_loader, test_loader
match args['model']['type']:
case "STGNCDE": return cdeTrainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args['train'],
lr_scheduler, kwargs[0], None)
case "STGNRDE": return cdeTrainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args['train'],
lr_scheduler, kwargs[0], None)
case 'DCRNN': return DCRNN_Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args['train'],
lr_scheduler)
case 'PDG2SEQ': return PDG2SEQ_Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args['train'],