Add configuration files for several PEMS datasets (PEMSD3, PEMSD4, PEMSD7, PEMSD8) and for the STAWnet, STGNRDE and ST_SSL models; tune model training parameters.

This commit is contained in:
czzhangheng 2025-08-19 00:07:37 +08:00
parent b820b867fb
commit 29fd709c8c
17 changed files with 216 additions and 34 deletions

View File

@ -32,7 +32,7 @@ pip install pyyaml tqdm statsmodels h5py kagglehub torch torchvision torchaudio
python run.py --model {model_name} --dataset {dataset_name} --mode {train, test} --device {cuda:0}
```
- model_name: currently supports DSANET, STGCN, DCRNN, GWN (GraphWaveNet), STSGCN, AGCRN, STFGNN, STGODE, STGNCDE, DDGCRN, TWDGCN and the newly added STAWnet
- dataset_name: currently supports PEMSD3, PEMSD4, PEMSD7, PEMSD8
- mode: train trains a model; test evaluates one and requires the model's .pth checkpoint in the pre-train directory
- device: 'cpu', 'cuda:0', 'cuda:1', ... depending on how many GPUs the machine has
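For example, a training run with the newly supported STAWnet model on PEMSD4 might look like this (illustrative; pick the device that matches your machine):
```
python run.py --model STAWnet --dataset PEMSD4 --mode train --device cuda:0
```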

View File

@ -6,7 +6,7 @@ data:
  test_ratio: 0.2
  tod: False
  normalizer: std
  column_wise: False → True
  default_graph: True
  add_time_in_day: True
  add_day_in_week: True
@ -21,26 +21,26 @@ model:
  max_diffusion_step: 2
  cl_decay_steps: 1000
  filter_type: dual_random_walk
  num_rnn_layers: 1 → 2
  rnn_units: 64
  seq_len: 12
  use_curriculum_learning: True
train:
  loss_func: mae → mask_mae
  seed: 10
  batch_size: 64
  epochs: 300
  lr_init: 0.003 → 0.001
  weight_decay: 0 → 0.0001
  lr_decay: False → True
  lr_decay_rate: 0.3 → 0.1
  lr_decay_step: "5,20,40,70" → "10,20,40,80"
  early_stop: True
  early_stop_patience: 15 → 25
  grad_norm: False → True
  max_grad_norm: 5
  real_value: True → False
test:
  mae_thresh: null
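The updated train block switches on milestone-based learning-rate decay. Below is a minimal sketch, under the assumption that the trainer maps these keys onto an Adam optimizer plus PyTorch's MultiStepLR; the repo's actual wiring may differ:
```
import torch

# Assumed mapping of the YAML train keys onto optimizer + scheduler; illustrative only.
model = torch.nn.Linear(12, 12)  # placeholder model
train_cfg = {"lr_init": 0.001, "weight_decay": 0.0001,
             "lr_decay": True, "lr_decay_rate": 0.1, "lr_decay_step": "10,20,40,80"}

optimizer = torch.optim.Adam(model.parameters(),
                             lr=train_cfg["lr_init"],
                             weight_decay=train_cfg["weight_decay"])
scheduler = None
if train_cfg["lr_decay"]:
    milestones = [int(s) for s in train_cfg["lr_decay_step"].split(",")]
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=milestones, gamma=train_cfg["lr_decay_rate"])
```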

View File

@ -2,10 +2,12 @@ from dataloader.cde_loader.cdeDataloader import get_dataloader as cde_loader
from dataloader.PeMSDdataloader import get_dataloader as normal_loader
from dataloader.DCRNNdataloader import get_dataloader as DCRNN_loader
from dataloader.EXPdataloader import get_dataloader as EXP_loader
from dataloader.cde_loader.cdeDataloader import get_dataloader as nrde_loader

def get_dataloader(config, normalizer, single):
    match config['model']['type']:
        case 'STGNCDE': return cde_loader(config['data'], normalizer, single)
        case 'STGNRDE': return nrde_loader(config['data'], normalizer, single)
        case 'DCRNN': return DCRNN_loader(config['data'], normalizer, single)
        case 'EXP': return EXP_loader(config['data'], normalizer, single)
        case _: return normal_loader(config['data'], normalizer, single)

View File

@ -1,9 +1,9 @@
def masked_mae_loss(scaler, mask_value):
    def loss(preds, labels):
        # Only de-normalize the predictions; labels keep their original scale in the data pipeline
        if scaler:
            preds = scaler.inverse_transform(preds)
        return mae_torch(pred=preds, true=labels, mask_value=mask_value)
    return loss
@ -11,7 +11,8 @@ def masked_mae_loss(scaler, mask_value):
def get_loss_function(args, scaler):
    if args['loss_func'] == 'mask_mae':
        # Return callable loss (no .to for function closures); disable masking by default
        return masked_mae_loss(scaler, mask_value=None)
    elif args['loss_func'] == 'mae':
        return torch.nn.L1Loss().to(args['device'])
    elif args['loss_func'] == 'mse':
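mae_torch lives elsewhere in lib and is not part of this diff. A minimal sketch of how such a masked MAE is commonly implemented, to show what mask_value controls (assumed, not the repo's exact code):
```
import torch

def masked_mae(pred, true, mask_value=None):
    # Ignore positions where the ground truth is at or below mask_value
    # (e.g. 0.0 for missing sensor readings); with mask_value=None this is plain MAE.
    if mask_value is not None:
        keep = true > mask_value
        pred, true = pred[keep], true[keep]
    return torch.mean(torch.abs(pred - true))
```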

View File

@ -34,7 +34,7 @@ class LayerParams:
class DCGRUCell(torch.nn.Module): class DCGRUCell(torch.nn.Module):
    def __init__(self, num_units, adj_mx, max_diffusion_step, num_nodes, input_dim=None, nonlinearity='tanh',
                 filter_type="laplacian", use_gc_for_ru=True):
""" """
@ -55,6 +55,7 @@ class DCGRUCell(torch.nn.Module):
self._max_diffusion_step = max_diffusion_step self._max_diffusion_step = max_diffusion_step
self._supports = [] self._supports = []
self._use_gc_for_ru = use_gc_for_ru self._use_gc_for_ru = use_gc_for_ru
self._input_dim = input_dim # optional; if None, will be inferred at first forward
supports = [] supports = []
if filter_type == "laplacian": if filter_type == "laplacian":
supports.append(utils.calculate_scaled_laplacian(adj_mx, lambda_max=None)) supports.append(utils.calculate_scaled_laplacian(adj_mx, lambda_max=None))
@ -71,6 +72,22 @@ class DCGRUCell(torch.nn.Module):
self._fc_params = LayerParams(self, 'fc') self._fc_params = LayerParams(self, 'fc')
self._gconv_params = LayerParams(self, 'gconv') self._gconv_params = LayerParams(self, 'gconv')
# Pre-create parameters if input_dim is known
if self._input_dim is not None:
num_matrices = len(self._supports) * self._max_diffusion_step + 1
input_size = self._input_dim + self._num_units
# FC weights/biases for RU gates (2 * num_units)
self._fc_params.get_weights((input_size, 2 * self._num_units))
self._fc_params.get_biases(2 * self._num_units, bias_start=1.0)
# Optionally for candidate (num_units) if FC path is used
self._fc_params.get_weights((input_size, self._num_units))
self._fc_params.get_biases(self._num_units, bias_start=0.0)
# GConv weights/biases for RU gates and candidate
self._gconv_params.get_weights((input_size * num_matrices, 2 * self._num_units))
self._gconv_params.get_biases(2 * self._num_units, bias_start=1.0)
self._gconv_params.get_weights((input_size * num_matrices, self._num_units))
self._gconv_params.get_biases(self._num_units, bias_start=0.0)
@staticmethod @staticmethod
def _build_sparse_matrix(L): def _build_sparse_matrix(L):
L = L.tocoo() L = L.tocoo()
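The parameter pre-creation added above only pays off if LayerParams caches parameters by shape and simply returns them on later calls. A minimal sketch of that assumed behaviour (class name and registration scheme here are illustrative, not the repo's exact code):
```
import torch

class LayerParamsSketch:
    # Assumed behaviour: parameters are created lazily, cached by shape, and
    # registered on the owning module, so pre-creating them in __init__ simply
    # warms this cache before the first forward pass.
    def __init__(self, rnn_network, layer_type):
        self._rnn_network, self._type = rnn_network, layer_type
        self._weights, self._biases = {}, {}

    def get_weights(self, shape):
        if shape not in self._weights:
            w = torch.nn.Parameter(torch.empty(*shape))
            torch.nn.init.xavier_normal_(w)
            self._rnn_network.register_parameter(f'{self._type}_weight_{shape}', w)
            self._weights[shape] = w
        return self._weights[shape]

    def get_biases(self, length, bias_start=0.0):
        if length not in self._biases:
            b = torch.nn.Parameter(torch.full((length,), bias_start))
            self._rnn_network.register_parameter(f'{self._type}_bias_{length}', b)
            self._biases[length] = b
        return self._biases[length]
```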

View File

@ -15,6 +15,10 @@ class Seq2SeqAttrs:
self.num_nodes = args.get('num_nodes', 1) self.num_nodes = args.get('num_nodes', 1)
self.num_rnn_layers = args.get('num_rnn_layers', 1) self.num_rnn_layers = args.get('num_rnn_layers', 1)
self.rnn_units = args.get('rnn_units') self.rnn_units = args.get('rnn_units')
self.input_dim = args.get('input_dim', 1)
self.output_dim = args.get('output_dim', 1)
self.horizon = args.get('horizon', 12)
self.seq_len = args.get('seq_len', 12)
self.hidden_state_size = self.num_nodes * self.rnn_units self.hidden_state_size = self.num_nodes * self.rnn_units
@ -26,7 +30,7 @@ class EncoderModel(nn.Module, Seq2SeqAttrs):
self.seq_len = args.get('seq_len') # for the encoder self.seq_len = args.get('seq_len') # for the encoder
        self.dcgru_layers = nn.ModuleList(
            [DCGRUCell(self.rnn_units, adj_mx, self.max_diffusion_step, self.num_nodes,
                       input_dim=self.input_dim, filter_type=self.filter_type) for _ in range(self.num_rnn_layers)])
def forward(self, inputs, hidden_state=None): def forward(self, inputs, hidden_state=None):
""" """
@ -63,7 +67,7 @@ class DecoderModel(nn.Module, Seq2SeqAttrs):
self.projection_layer = nn.Linear(self.rnn_units, self.output_dim) self.projection_layer = nn.Linear(self.rnn_units, self.output_dim)
        self.dcgru_layers = nn.ModuleList(
            [DCGRUCell(self.rnn_units, adj_mx, self.max_diffusion_step, self.num_nodes,
                       input_dim=self.output_dim, filter_type=self.filter_type) for _ in range(self.num_rnn_layers)])
def forward(self, inputs, hidden_state=None): def forward(self, inputs, hidden_state=None):
""" """
@ -146,17 +150,19 @@ class DCRNNModel(nn.Module, Seq2SeqAttrs):
    def forward(self, inputs, labels=None):
        """
        seq2seq forward pass. inputs: [B, T, N, C]
        :param inputs: (batch_size, seq_len, num_nodes, input_dim)
        :param labels: (batch_size, horizon, num_nodes, output_dim), optional during inference
        :return: outputs of shape (batch_size, horizon, num_nodes, output_dim)
        """
        x = inputs[..., :self.input_dim]
        x = x.permute(1, 0, 2, 3).contiguous().view(self.seq_len, -1, self.num_nodes * self.input_dim)

        y = None
        if labels is not None:
            y = labels[..., :self.output_dim]
            y = y.permute(1, 0, 2, 3).contiguous().view(self.horizon, -1, self.num_nodes * self.output_dim)

        encoder_hidden_state = self.encoder(x)
        outputs = self.decoder(encoder_hidden_state, y, batches_seen=self.batch_seen)
        self.batch_seen += 1
        outputs = outputs.view(self.horizon, -1, self.num_nodes, self.output_dim)
        outputs = outputs.permute(1, 0, 2, 3).contiguous()
        return outputs
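A quick sanity check of the reshaping above with dummy tensors (shapes follow the PEMSD4 defaults of 307 nodes and 12-step windows; illustrative only):
```
import torch

B, T, N, C = 64, 12, 307, 1          # batch, seq_len, num_nodes, input_dim
inputs = torch.randn(B, T, N, C)

x = inputs[..., :C]
x = x.permute(1, 0, 2, 3).contiguous().view(T, B, N * C)
print(x.shape)                        # torch.Size([12, 64, 307])

# The decoder output is folded back the same way at the end of forward():
out = torch.randn(T, B, N * C)
out = out.view(T, B, N, C).permute(1, 0, 2, 3).contiguous()
print(out.shape)                      # torch.Size([64, 12, 307, 1])
```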

View File

@ -7,6 +7,7 @@ import numpy as np
from lib.logger import get_logger from lib.logger import get_logger
from lib.metrics import All_Metrics from lib.metrics import All_Metrics
from lib.TrainInits import print_model_parameters from lib.TrainInits import print_model_parameters
from lib.training_stats import TrainingStats
class Trainer(object): class Trainer(object):
def __init__(self, model, vector_field_f, vector_field_g, loss, optimizer, train_loader, val_loader, test_loader, def __init__(self, model, vector_field_f, vector_field_g, loss, optimizer, train_loader, val_loader, test_loader,
@ -42,6 +43,8 @@ class Trainer(object):
self.device = device self.device = device
self.times = times.to(self.device, dtype=torch.float) self.times = times.to(self.device, dtype=torch.float)
self.w = w self.w = w
# Stats tracker
self.stats = TrainingStats(device=device)
def val_epoch(self, epoch, val_dataloader): def val_epoch(self, epoch, val_dataloader):
self.model.eval() self.model.eval()
@ -49,6 +52,7 @@ class Trainer(object):
with torch.no_grad(): with torch.no_grad():
for batch_idx, batch in enumerate(self.val_loader): for batch_idx, batch in enumerate(self.val_loader):
start_time = time.time()
# for iter, batch in enumerate(val_dataloader): # for iter, batch in enumerate(val_dataloader):
batch = tuple(b.to(self.device, dtype=torch.float) for b in batch) batch = tuple(b.to(self.device, dtype=torch.float) for b in batch)
*valid_coeffs, target = batch *valid_coeffs, target = batch
@ -61,8 +65,11 @@ class Trainer(object):
#a whole batch of Metr_LA is filtered #a whole batch of Metr_LA is filtered
if not torch.isnan(loss): if not torch.isnan(loss):
total_val_loss += loss.item() total_val_loss += loss.item()
step_time = time.time() - start_time
self.stats.record_step_time(step_time, 'val')
val_loss = total_val_loss / len(val_dataloader) val_loss = total_val_loss / len(val_dataloader)
self.logger.info('**********Val Epoch {}: average Loss: {:.6f}'.format(epoch, val_loss)) self.logger.info('**********Val Epoch {}: average Loss: {:.6f}'.format(epoch, val_loss))
self.stats.record_memory_usage()
if self.args.tensorboard: if self.args.tensorboard:
self.w.add_scalar(f'valid/loss', val_loss, epoch) self.w.add_scalar(f'valid/loss', val_loss, epoch)
return val_loss return val_loss
@ -73,6 +80,7 @@ class Trainer(object):
# for batch_idx, (data, target) in enumerate(self.train_loader): # for batch_idx, (data, target) in enumerate(self.train_loader):
# for batch_idx, (data, target) in enumerate(self.train_loader): # for batch_idx, (data, target) in enumerate(self.train_loader):
for batch_idx, batch in enumerate(self.train_loader): for batch_idx, batch in enumerate(self.train_loader):
start_time = time.time()
batch = tuple(b.to(self.device, dtype=torch.float) for b in batch) batch = tuple(b.to(self.device, dtype=torch.float) for b in batch)
*train_coeffs, target = batch *train_coeffs, target = batch
# data = data[..., :self.args.input_dim] # data = data[..., :self.args.input_dim]
@ -101,6 +109,8 @@ class Trainer(object):
if self.args.grad_norm: if self.args.grad_norm:
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm) torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm)
self.optimizer.step() self.optimizer.step()
step_time = time.time() - start_time
self.stats.record_step_time(step_time, 'train')
total_loss += loss.item() total_loss += loss.item()
#log information #log information
@ -109,6 +119,7 @@ class Trainer(object):
epoch, batch_idx, self.train_per_epoch, loss.item())) epoch, batch_idx, self.train_per_epoch, loss.item()))
train_epoch_loss = total_loss/self.train_per_epoch train_epoch_loss = total_loss/self.train_per_epoch
self.logger.info('**********Train Epoch {}: averaged Loss: {:.6f}'.format(epoch, train_epoch_loss)) self.logger.info('**********Train Epoch {}: averaged Loss: {:.6f}'.format(epoch, train_epoch_loss))
self.stats.record_memory_usage()
if self.args.tensorboard: if self.args.tensorboard:
self.w.add_scalar(f'train/loss', train_epoch_loss, epoch) self.w.add_scalar(f'train/loss', train_epoch_loss, epoch)
@ -123,6 +134,7 @@ class Trainer(object):
not_improved_count = 0 not_improved_count = 0
train_loss_list = [] train_loss_list = []
val_loss_list = [] val_loss_list = []
self.stats.start_training()
start_time = time.time() start_time = time.time()
for epoch in range(1, self.args.epochs + 1): for epoch in range(1, self.args.epochs + 1):
#epoch_time = time.time() #epoch_time = time.time()
@ -168,6 +180,13 @@ class Trainer(object):
training_time = time.time() - start_time training_time = time.time() - start_time
self.logger.info("Total training time: {:.4f}min, best loss: {:.6f}".format((training_time / 60), best_loss)) self.logger.info("Total training time: {:.4f}min, best loss: {:.6f}".format((training_time / 60), best_loss))
self.stats.end_training()
self.stats.report(self.logger)
try:
total_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
self.logger.info(f"Trainable params: {total_params}")
except Exception:
pass
#save the best model to file #save the best model to file
if not self.args.debug: if not self.args.debug:
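lib/training_stats.py itself is not included in this diff. A minimal sketch of the interface the trainers rely on (start_training/end_training, record_step_time, record_memory_usage, report); the actual implementation may track more detail:
```
import time
import torch

class TrainingStats:
    # Sketch of the assumed helper; only the methods called by the trainers are shown.
    def __init__(self, device='cpu'):
        self.device = torch.device(device)
        self.step_times = {'train': [], 'val': []}
        self.peak_mem_mb = 0.0
        self._t0 = self._t1 = None

    def start_training(self):
        self._t0 = time.time()

    def end_training(self):
        self._t1 = time.time()

    def record_step_time(self, seconds, mode):
        self.step_times.setdefault(mode, []).append(seconds)

    def record_memory_usage(self):
        if self.device.type == 'cuda':
            mem = torch.cuda.max_memory_allocated(self.device) / 2 ** 20
            self.peak_mem_mb = max(self.peak_mem_mb, mem)

    def report(self, logger):
        for mode, times in self.step_times.items():
            if times:
                logger.info(f"{mode} avg step time: {sum(times) / len(times):.4f}s over {len(times)} steps")
        if self._t0 is not None and self._t1 is not None:
            logger.info(f"wall-clock training time: {(self._t1 - self._t0) / 60:.2f} min")
        if self.peak_mem_mb:
            logger.info(f"peak CUDA memory: {self.peak_mem_mb:.1f} MB")
```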

View File

@ -18,6 +18,10 @@ from model.STIDGCN.STIDGCN import STIDGCN
from model.STID.STID import STID from model.STID.STID import STID
from model.STAEFormer.STAEFormer import STAEformer from model.STAEFormer.STAEFormer import STAEformer
from model.EXP.EXP32 import EXP as EXP from model.EXP.EXP32 import EXP as EXP
from model.MegaCRN.MegaCRNModel import MegaCRNModel
from model.ST_SSL.ST_SSL import STSSLModel
from model.STGNRDE.Make_model import make_model as make_nrde_model
from model.STAWnet.STAWnet import STAWnet
def model_selector(model): def model_selector(model):
match model['type']: match model['type']:
@ -41,4 +45,8 @@ def model_selector(model):
case 'STID': return STID(model) case 'STID': return STID(model)
case 'STAEFormer': return STAEformer(model) case 'STAEFormer': return STAEformer(model)
case 'EXP': return EXP(model) case 'EXP': return EXP(model)
case 'MegaCRN': return MegaCRNModel(model)
case 'ST_SSL': return STSSLModel(model)
case 'STGNRDE': return make_nrde_model(model)
case 'STAWnet': return STAWnet(model)
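A hypothetical dispatch example; model_selector only inspects the 'type' key and passes the whole dict through to the chosen constructor (the remaining keys come from the model's YAML config):
```
# Hypothetical usage; the real dict is parsed from the corresponding YAML config file.
model_cfg = {"type": "STAWnet"}      # plus whatever keys STAWnet's constructor expects
net = model_selector(model_cfg)
```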

View File

@ -7,6 +7,7 @@ from tqdm import tqdm
import torch import torch
from lib.logger import get_logger from lib.logger import get_logger
from lib.loss_function import all_metrics from lib.loss_function import all_metrics
from lib.training_stats import TrainingStats
class Trainer: class Trainer:
@ -34,6 +35,8 @@ class Trainer:
os.makedirs(args['log_dir'], exist_ok=True) os.makedirs(args['log_dir'], exist_ok=True)
self.logger = get_logger(args['log_dir'], name=self.model.__class__.__name__, debug=args['debug']) self.logger = get_logger(args['log_dir'], name=self.model.__class__.__name__, debug=args['debug'])
self.logger.info(f"Experiment log path in: {args['log_dir']}") self.logger.info(f"Experiment log path in: {args['log_dir']}")
# Stats tracker
self.stats = TrainingStats(device=args['device'])
def _run_epoch(self, epoch, dataloader, mode): def _run_epoch(self, epoch, dataloader, mode):
if mode == 'train': if mode == 'train':
@ -49,12 +52,13 @@ class Trainer:
with torch.set_grad_enabled(optimizer_step): with torch.set_grad_enabled(optimizer_step):
with tqdm(total=len(dataloader), desc=f'{mode.capitalize()} Epoch {epoch}') as pbar: with tqdm(total=len(dataloader), desc=f'{mode.capitalize()} Epoch {epoch}') as pbar:
for batch_idx, (data, target) in enumerate(dataloader): for batch_idx, (data, target) in enumerate(dataloader):
start_time = time.time()
label = target[..., :self.args['output_dim']] label = target[..., :self.args['output_dim']]
# label = target[..., :self.args['output_dim']]
output = self.model(data, labels=label.clone()).to(self.args['device']) output = self.model(data, labels=label.clone()).to(self.args['device'])
if self.args['real_value']: if self.args['real_value']:
output = self.scaler.inverse_transform(output) output = self.scaler.inverse_transform(output)
label = self.scaler.inverse_transform(label)
loss = self.loss(output, label) loss = self.loss(output, label)
if optimizer_step and self.optimizer is not None: if optimizer_step and self.optimizer is not None:
@ -65,6 +69,8 @@ class Trainer:
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm']) torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm'])
self.optimizer.step() self.optimizer.step()
step_time = time.time() - start_time
self.stats.record_step_time(step_time, mode)
total_loss += loss.item() total_loss += loss.item()
if mode == 'train' and (batch_idx + 1) % self.args['log_step'] == 0: if mode == 'train' and (batch_idx + 1) % self.args['log_step'] == 0:
@ -78,6 +84,8 @@ class Trainer:
avg_loss = total_loss / len(dataloader) avg_loss = total_loss / len(dataloader)
self.logger.info( self.logger.info(
f'{mode.capitalize()} Epoch {epoch}: average Loss: {avg_loss:.6f}, time: {time.time() - epoch_time:.2f} s') f'{mode.capitalize()} Epoch {epoch}: average Loss: {avg_loss:.6f}, time: {time.time() - epoch_time:.2f} s')
# record memory usage
self.stats.record_memory_usage()
return avg_loss return avg_loss
def train_epoch(self, epoch): def train_epoch(self, epoch):
@ -94,6 +102,7 @@ class Trainer:
best_loss, best_test_loss = float('inf'), float('inf') best_loss, best_test_loss = float('inf'), float('inf')
not_improved_count = 0 not_improved_count = 0
self.stats.start_training()
self.logger.info("Training process started") self.logger.info("Training process started")
for epoch in range(1, self.args['epochs'] + 1): for epoch in range(1, self.args['epochs'] + 1):
train_epoch_loss = self.train_epoch(epoch) train_epoch_loss = self.train_epoch(epoch)
@ -126,6 +135,14 @@ class Trainer:
torch.save(best_test_model, self.best_test_path) torch.save(best_test_model, self.best_test_path)
self.logger.info(f"Best models saved at {self.best_path} and {self.best_test_path}") self.logger.info(f"Best models saved at {self.best_path} and {self.best_test_path}")
# report statistics and parameter count
self.stats.end_training()
self.stats.report(self.logger)
try:
total_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
self.logger.info(f"Trainable params: {total_params}")
except Exception:
pass
self._finalize_training(best_model, best_test_model) self._finalize_training(best_model, best_test_model)
def _finalize_training(self, best_model, best_test_model): def _finalize_training(self, best_model, best_test_model):
@ -154,11 +171,11 @@ class Trainer:
y_pred.append(output) y_pred.append(output)
y_true.append(label) y_true.append(label)
    y_pred = torch.cat(y_pred, dim=0)
    y_true = torch.cat(y_true, dim=0)
    if args['real_value']:
        y_pred = scaler.inverse_transform(y_pred)
        y_true = scaler.inverse_transform(y_true)
for t in range(y_true.shape[1]): for t in range(y_true.shape[1]):
mae, rmse, mape = all_metrics(y_pred[:, t, ...], y_true[:, t, ...], mae, rmse, mape = all_metrics(y_pred[:, t, ...], y_true[:, t, ...],

View File

@ -7,6 +7,7 @@ from tqdm import tqdm
import torch import torch
from lib.logger import get_logger from lib.logger import get_logger
from lib.loss_function import all_metrics from lib.loss_function import all_metrics
from lib.training_stats import TrainingStats
class Trainer: class Trainer:
@ -34,6 +35,8 @@ class Trainer:
os.makedirs(args['log_dir'], exist_ok=True) os.makedirs(args['log_dir'], exist_ok=True)
self.logger = get_logger(args['log_dir'], name=self.model.__class__.__name__, debug=args['debug']) self.logger = get_logger(args['log_dir'], name=self.model.__class__.__name__, debug=args['debug'])
self.logger.info(f"Experiment log path in: {args['log_dir']}") self.logger.info(f"Experiment log path in: {args['log_dir']}")
# Stats tracker
self.stats = TrainingStats(device=args['device'])
def _run_epoch(self, epoch, dataloader, mode): def _run_epoch(self, epoch, dataloader, mode):
is_train = (mode == 'train') is_train = (mode == 'train')
@ -45,6 +48,7 @@ class Trainer:
tqdm(total=len(dataloader), desc=f'{mode.capitalize()} Epoch {epoch}') as pbar: tqdm(total=len(dataloader), desc=f'{mode.capitalize()} Epoch {epoch}') as pbar:
for batch_idx, batch in enumerate(dataloader): for batch_idx, batch in enumerate(dataloader):
start_time = time.time()
# unpack the new cycle_index # unpack the new cycle_index
data, target, cycle_index = batch data, target, cycle_index = batch
data = data.to(self.args['device']) data = data.to(self.args['device'])
@ -72,6 +76,8 @@ class Trainer:
self.args['max_grad_norm']) self.args['max_grad_norm'])
self.optimizer.step() self.optimizer.step()
step_time = time.time() - start_time
self.stats.record_step_time(step_time, mode)
total_loss += loss.item() total_loss += loss.item()
# logging # logging
@ -86,6 +92,8 @@ class Trainer:
avg_loss = total_loss / len(dataloader) avg_loss = total_loss / len(dataloader)
self.logger.info( self.logger.info(
f'{mode.capitalize()} Epoch {epoch}: average Loss: {avg_loss:.6f}, time: {time.time() - epoch_time:.2f} s') f'{mode.capitalize()} Epoch {epoch}: average Loss: {avg_loss:.6f}, time: {time.time() - epoch_time:.2f} s')
# record memory usage
self.stats.record_memory_usage()
return avg_loss return avg_loss
def train_epoch(self, epoch): def train_epoch(self, epoch):
@ -102,6 +110,7 @@ class Trainer:
best_loss, best_test_loss = float('inf'), float('inf') best_loss, best_test_loss = float('inf'), float('inf')
not_improved_count = 0 not_improved_count = 0
self.stats.start_training()
self.logger.info("Training process started") self.logger.info("Training process started")
for epoch in range(1, self.args['epochs'] + 1): for epoch in range(1, self.args['epochs'] + 1):
train_epoch_loss = self.train_epoch(epoch) train_epoch_loss = self.train_epoch(epoch)
@ -134,6 +143,14 @@ class Trainer:
torch.save(best_test_model, self.best_test_path) torch.save(best_test_model, self.best_test_path)
self.logger.info(f"Best models saved at {self.best_path} and {self.best_test_path}") self.logger.info(f"Best models saved at {self.best_path} and {self.best_test_path}")
# report statistics and parameter count
self.stats.end_training()
self.stats.report(self.logger)
try:
total_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
self.logger.info(f"Trainable params: {total_params}")
except Exception:
pass
self._finalize_training(best_model, best_test_model) self._finalize_training(best_model, best_test_model)
def _finalize_training(self, best_model, best_test_model): def _finalize_training(self, best_model, best_test_model):

View File

@ -7,6 +7,7 @@ from tqdm import tqdm
import torch import torch
from lib.logger import get_logger from lib.logger import get_logger
from lib.loss_function import all_metrics from lib.loss_function import all_metrics
from lib.training_stats import TrainingStats
class Trainer: class Trainer:
@ -34,6 +35,8 @@ class Trainer:
os.makedirs(args['log_dir'], exist_ok=True) os.makedirs(args['log_dir'], exist_ok=True)
self.logger = get_logger(args['log_dir'], name=self.model.__class__.__name__, debug=args['debug']) self.logger = get_logger(args['log_dir'], name=self.model.__class__.__name__, debug=args['debug'])
self.logger.info(f"Experiment log path in: {args['log_dir']}") self.logger.info(f"Experiment log path in: {args['log_dir']}")
# Stats tracker
self.stats = TrainingStats(device=args['device'])
def _run_epoch(self, epoch, dataloader, mode): def _run_epoch(self, epoch, dataloader, mode):
if mode == 'train': if mode == 'train':
@ -49,6 +52,7 @@ class Trainer:
with torch.set_grad_enabled(optimizer_step): with torch.set_grad_enabled(optimizer_step):
with tqdm(total=len(dataloader), desc=f'{mode.capitalize()} Epoch {epoch}') as pbar: with tqdm(total=len(dataloader), desc=f'{mode.capitalize()} Epoch {epoch}') as pbar:
for batch_idx, (data, target) in enumerate(dataloader): for batch_idx, (data, target) in enumerate(dataloader):
start_time = time.time()
label = target[..., :self.args['output_dim']] label = target[..., :self.args['output_dim']]
output = self.model(data).to(self.args['device']) output = self.model(data).to(self.args['device'])
@ -64,6 +68,9 @@ class Trainer:
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm']) torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm'])
self.optimizer.step() self.optimizer.step()
step_time = time.time() - start_time
self.stats.record_step_time(step_time, mode)
total_loss += loss.item() total_loss += loss.item()
if mode == 'train' and (batch_idx + 1) % self.args['log_step'] == 0: if mode == 'train' and (batch_idx + 1) % self.args['log_step'] == 0:
@ -77,6 +84,8 @@ class Trainer:
avg_loss = total_loss / len(dataloader) avg_loss = total_loss / len(dataloader)
self.logger.info( self.logger.info(
f'{mode.capitalize()} Epoch {epoch}: average Loss: {avg_loss:.6f}, time: {time.time() - epoch_time:.2f} s') f'{mode.capitalize()} Epoch {epoch}: average Loss: {avg_loss:.6f}, time: {time.time() - epoch_time:.2f} s')
# record memory usage
self.stats.record_memory_usage()
return avg_loss return avg_loss
def train_epoch(self, epoch): def train_epoch(self, epoch):
@ -93,6 +102,7 @@ class Trainer:
best_loss, best_test_loss = float('inf'), float('inf') best_loss, best_test_loss = float('inf'), float('inf')
not_improved_count = 0 not_improved_count = 0
self.stats.start_training()
self.logger.info("Training process started") self.logger.info("Training process started")
for epoch in range(1, self.args['epochs'] + 1): for epoch in range(1, self.args['epochs'] + 1):
train_epoch_loss = self.train_epoch(epoch) train_epoch_loss = self.train_epoch(epoch)
@ -124,7 +134,14 @@ class Trainer:
torch.save(best_model, self.best_path) torch.save(best_model, self.best_path)
torch.save(best_test_model, self.best_test_path) torch.save(best_test_model, self.best_test_path)
self.logger.info(f"Best models saved at {self.best_path} and {self.best_test_path}") self.logger.info(f"Best models saved at {self.best_path} and {self.best_test_path}")
# report statistics and parameter count
self.stats.end_training()
self.stats.report(self.logger)
try:
total_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
self.logger.info(f"Trainable params: {total_params}")
except Exception:
pass
self._finalize_training(best_model, best_test_model) self._finalize_training(best_model, best_test_model)
def _finalize_training(self, best_model, best_test_model): def _finalize_training(self, best_model, best_test_model):

View File

@ -7,6 +7,7 @@ from tqdm import tqdm
import torch import torch
from lib.logger import get_logger from lib.logger import get_logger
from lib.loss_function import all_metrics from lib.loss_function import all_metrics
from lib.training_stats import TrainingStats
class Trainer: class Trainer:
@ -35,6 +36,8 @@ class Trainer:
os.makedirs(args['log_dir'], exist_ok=True) os.makedirs(args['log_dir'], exist_ok=True)
self.logger = get_logger(args['log_dir'], name=self.model.__class__.__name__, debug=args['debug']) self.logger = get_logger(args['log_dir'], name=self.model.__class__.__name__, debug=args['debug'])
self.logger.info(f"Experiment log path in: {args['log_dir']}") self.logger.info(f"Experiment log path in: {args['log_dir']}")
# Stats tracker
self.stats = TrainingStats(device=args['device'])
def _run_epoch(self, epoch, dataloader, mode): def _run_epoch(self, epoch, dataloader, mode):
if mode == 'train': if mode == 'train':
@ -50,6 +53,7 @@ class Trainer:
with torch.set_grad_enabled(optimizer_step): with torch.set_grad_enabled(optimizer_step):
with tqdm(total=len(dataloader), desc=f'{mode.capitalize()} Epoch {epoch}') as pbar: with tqdm(total=len(dataloader), desc=f'{mode.capitalize()} Epoch {epoch}') as pbar:
for batch_idx, (data, target) in enumerate(dataloader): for batch_idx, (data, target) in enumerate(dataloader):
start_time = time.time()
self.batches_seen += 1 self.batches_seen += 1
label = target[..., :self.args['output_dim']].clone() label = target[..., :self.args['output_dim']].clone()
output = self.model(data, target, self.batches_seen).to(self.args['device']) output = self.model(data, target, self.batches_seen).to(self.args['device'])
@ -66,6 +70,9 @@ class Trainer:
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm']) torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm'])
self.optimizer.step() self.optimizer.step()
# record step time
step_time = time.time() - start_time
self.stats.record_step_time(step_time, mode)
total_loss += loss.item() total_loss += loss.item()
if mode == 'train' and (batch_idx + 1) % self.args['log_step'] == 0: if mode == 'train' and (batch_idx + 1) % self.args['log_step'] == 0:
@ -79,6 +86,8 @@ class Trainer:
avg_loss = total_loss / len(dataloader) avg_loss = total_loss / len(dataloader)
self.logger.info( self.logger.info(
f'{mode.capitalize()} Epoch {epoch}: average Loss: {avg_loss:.6f}, time: {time.time() - epoch_time:.2f} s') f'{mode.capitalize()} Epoch {epoch}: average Loss: {avg_loss:.6f}, time: {time.time() - epoch_time:.2f} s')
# record memory usage
self.stats.record_memory_usage()
return avg_loss return avg_loss
def train_epoch(self, epoch): def train_epoch(self, epoch):
@ -95,6 +104,7 @@ class Trainer:
best_loss, best_test_loss = float('inf'), float('inf') best_loss, best_test_loss = float('inf'), float('inf')
not_improved_count = 0 not_improved_count = 0
self.stats.start_training()
self.logger.info("Training process started") self.logger.info("Training process started")
for epoch in range(1, self.args['epochs'] + 1): for epoch in range(1, self.args['epochs'] + 1):
train_epoch_loss = self.train_epoch(epoch) train_epoch_loss = self.train_epoch(epoch)
@ -127,6 +137,14 @@ class Trainer:
torch.save(best_test_model, self.best_test_path) torch.save(best_test_model, self.best_test_path)
self.logger.info(f"Best models saved at {self.best_path} and {self.best_test_path}") self.logger.info(f"Best models saved at {self.best_path} and {self.best_test_path}")
# report statistics and parameter count
self.stats.end_training()
self.stats.report(self.logger)
try:
total_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
self.logger.info(f"Trainable params: {total_params}")
except Exception:
pass
self._finalize_training(best_model, best_test_model) self._finalize_training(best_model, best_test_model)
def _finalize_training(self, best_model, best_test_model): def _finalize_training(self, best_model, best_test_model):

View File

@ -11,6 +11,7 @@ from tqdm import tqdm
from lib.logger import get_logger from lib.logger import get_logger
from lib.loss_function import all_metrics from lib.loss_function import all_metrics
from model.STMLP.STMLP import STMLP from model.STMLP.STMLP import STMLP
from lib.training_stats import TrainingStats
class Trainer: class Trainer:
@ -43,6 +44,8 @@ class Trainer:
os.makedirs(self.pretrain_dir, exist_ok=True) os.makedirs(self.pretrain_dir, exist_ok=True)
self.logger = get_logger(self.args['log_dir'], name=self.model.__class__.__name__, debug=self.args['debug']) self.logger = get_logger(self.args['log_dir'], name=self.model.__class__.__name__, debug=self.args['debug'])
self.logger.info(f"Experiment log path in: {self.args['log_dir']}") self.logger.info(f"Experiment log path in: {self.args['log_dir']}")
# Stats tracker
self.stats = TrainingStats(device=args['device'])
if self.args['teacher_stu']: if self.args['teacher_stu']:
self.tmodel = self.loadTeacher(args) self.tmodel = self.loadTeacher(args)
@ -67,6 +70,7 @@ class Trainer:
with torch.set_grad_enabled(optimizer_step): with torch.set_grad_enabled(optimizer_step):
with tqdm(total=len(dataloader), desc=f'{mode.capitalize()} Epoch {epoch}') as pbar: with tqdm(total=len(dataloader), desc=f'{mode.capitalize()} Epoch {epoch}') as pbar:
for batch_idx, (data, target) in enumerate(dataloader): for batch_idx, (data, target) in enumerate(dataloader):
start_time = time.time()
if self.args['teacher_stu']: if self.args['teacher_stu']:
label = target[..., :self.args['output_dim']] label = target[..., :self.args['output_dim']]
output, out_, _ = self.model(data) output, out_, _ = self.model(data)
@ -100,6 +104,8 @@ class Trainer:
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm']) torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm'])
self.optimizer.step() self.optimizer.step()
step_time = time.time() - start_time
self.stats.record_step_time(step_time, mode)
total_loss += loss.item() total_loss += loss.item()
if mode == 'train' and (batch_idx + 1) % self.args['log_step'] == 0: if mode == 'train' and (batch_idx + 1) % self.args['log_step'] == 0:
@ -113,6 +119,8 @@ class Trainer:
avg_loss = total_loss / len(dataloader) avg_loss = total_loss / len(dataloader)
self.logger.info( self.logger.info(
f'{mode.capitalize()} Epoch {epoch}: average Loss: {avg_loss:.6f}, time: {time.time() - epoch_time:.2f} s') f'{mode.capitalize()} Epoch {epoch}: average Loss: {avg_loss:.6f}, time: {time.time() - epoch_time:.2f} s')
# record memory usage
self.stats.record_memory_usage()
return avg_loss return avg_loss
def train_epoch(self, epoch): def train_epoch(self, epoch):
@ -129,6 +137,7 @@ class Trainer:
best_loss, best_test_loss = float('inf'), float('inf') best_loss, best_test_loss = float('inf'), float('inf')
not_improved_count = 0 not_improved_count = 0
self.stats.start_training()
self.logger.info("Training process started") self.logger.info("Training process started")
for epoch in range(1, self.args['epochs'] + 1): for epoch in range(1, self.args['epochs'] + 1):
train_epoch_loss = self.train_epoch(epoch) train_epoch_loss = self.train_epoch(epoch)
@ -165,6 +174,14 @@ class Trainer:
torch.save(best_test_model, self.best_test_path) torch.save(best_test_model, self.best_test_path)
self.logger.info(f"Best models saved at {self.best_path} and {self.best_test_path}") self.logger.info(f"Best models saved at {self.best_path} and {self.best_test_path}")
# report statistics and parameter count
self.stats.end_training()
self.stats.report(self.logger)
try:
total_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
self.logger.info(f"Trainable params: {total_params}")
except Exception:
pass
self._finalize_training(best_model, best_test_model) self._finalize_training(best_model, best_test_model)
def _finalize_training(self, best_model, best_test_model): def _finalize_training(self, best_model, best_test_model):

View File

@ -211,6 +211,13 @@ class Trainer:
self._finalize_training(best_model, best_test_model) self._finalize_training(best_model, best_test_model)
# log the trainable parameter count
try:
total_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
self.logger.info(f"Trainable params: {total_params}")
except Exception:
pass
def _finalize_training(self, best_model, best_test_model): def _finalize_training(self, best_model, best_test_model):
self.model.load_state_dict(best_model) self.model.load_state_dict(best_model)
self.logger.info("Testing on best validation model") self.logger.info("Testing on best validation model")

View File

@ -7,6 +7,7 @@ from tqdm import tqdm
import torch import torch
from lib.logger import get_logger from lib.logger import get_logger
from lib.loss_function import all_metrics from lib.loss_function import all_metrics
from lib.training_stats import TrainingStats
class Trainer: class Trainer:
@ -34,6 +35,8 @@ class Trainer:
os.makedirs(args['log_dir'], exist_ok=True) os.makedirs(args['log_dir'], exist_ok=True)
self.logger = get_logger(args['log_dir'], name=self.model.__class__.__name__, debug=args['debug']) self.logger = get_logger(args['log_dir'], name=self.model.__class__.__name__, debug=args['debug'])
self.logger.info(f"Experiment log path in: {args['log_dir']}") self.logger.info(f"Experiment log path in: {args['log_dir']}")
# Stats tracker
self.stats = TrainingStats(device=args['device'])
def _run_epoch(self, epoch, dataloader, mode): def _run_epoch(self, epoch, dataloader, mode):
if mode == 'train': if mode == 'train':
@ -49,6 +52,7 @@ class Trainer:
with torch.set_grad_enabled(optimizer_step): with torch.set_grad_enabled(optimizer_step):
with tqdm(total=len(dataloader), desc=f'{mode.capitalize()} Epoch {epoch}') as pbar: with tqdm(total=len(dataloader), desc=f'{mode.capitalize()} Epoch {epoch}') as pbar:
for batch_idx, (data, target) in enumerate(dataloader): for batch_idx, (data, target) in enumerate(dataloader):
start_time = time.time()
label = target[..., :self.args['output_dim']] label = target[..., :self.args['output_dim']]
output = self.model(data).to(self.args['device']) output = self.model(data).to(self.args['device'])
@ -64,6 +68,8 @@ class Trainer:
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm']) torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm'])
self.optimizer.step() self.optimizer.step()
step_time = time.time() - start_time
self.stats.record_step_time(step_time, mode)
total_loss += loss.item() total_loss += loss.item()
if mode == 'train' and (batch_idx + 1) % self.args['log_step'] == 0: if mode == 'train' and (batch_idx + 1) % self.args['log_step'] == 0:
@ -77,6 +83,8 @@ class Trainer:
avg_loss = total_loss / len(dataloader) avg_loss = total_loss / len(dataloader)
self.logger.info( self.logger.info(
f'{mode.capitalize()} Epoch {epoch}: average Loss: {avg_loss:.6f}, time: {time.time() - epoch_time:.2f} s') f'{mode.capitalize()} Epoch {epoch}: average Loss: {avg_loss:.6f}, time: {time.time() - epoch_time:.2f} s')
# record memory usage
self.stats.record_memory_usage()
return avg_loss return avg_loss
def train_epoch(self, epoch): def train_epoch(self, epoch):
@ -93,6 +101,7 @@ class Trainer:
best_loss, best_test_loss = float('inf'), float('inf') best_loss, best_test_loss = float('inf'), float('inf')
not_improved_count = 0 not_improved_count = 0
self.stats.start_training()
self.logger.info("Training process started") self.logger.info("Training process started")
for epoch in range(1, self.args['epochs'] + 1): for epoch in range(1, self.args['epochs'] + 1):
train_epoch_loss = self.train_epoch(epoch) train_epoch_loss = self.train_epoch(epoch)
@ -125,6 +134,14 @@ class Trainer:
torch.save(best_test_model, self.best_test_path) torch.save(best_test_model, self.best_test_path)
self.logger.info(f"Best models saved at {self.best_path} and {self.best_test_path}") self.logger.info(f"Best models saved at {self.best_path} and {self.best_test_path}")
# report statistics and parameter count
self.stats.end_training()
self.stats.report(self.logger)
try:
total_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
self.logger.info(f"Trainable params: {total_params}")
except Exception:
pass
self._finalize_training(best_model, best_test_model) self._finalize_training(best_model, best_test_model)
def _finalize_training(self, best_model, best_test_model): def _finalize_training(self, best_model, best_test_model):

View File

@ -7,6 +7,7 @@ from tqdm import tqdm
import torch import torch
from lib.logger import get_logger from lib.logger import get_logger
from lib.loss_function import all_metrics from lib.loss_function import all_metrics
from lib.training_stats import TrainingStats
class Trainer: class Trainer:
@ -35,6 +36,8 @@ class Trainer:
os.makedirs(args['log_dir'], exist_ok=True) os.makedirs(args['log_dir'], exist_ok=True)
self.logger = get_logger(args['log_dir'], name=self.model.__class__.__name__, debug=args['debug']) self.logger = get_logger(args['log_dir'], name=self.model.__class__.__name__, debug=args['debug'])
self.logger.info(f"Experiment log path in: {args['log_dir']}") self.logger.info(f"Experiment log path in: {args['log_dir']}")
# Stats tracker
self.stats = TrainingStats(device=args['device'])
self.times = times.to(self.device, dtype=torch.float) self.times = times.to(self.device, dtype=torch.float)
self.w = w self.w = w
@ -52,6 +55,7 @@ class Trainer:
with torch.set_grad_enabled(optimizer_step): with torch.set_grad_enabled(optimizer_step):
with tqdm(total=len(dataloader), desc=f'{mode.capitalize()} Epoch {epoch}') as pbar: with tqdm(total=len(dataloader), desc=f'{mode.capitalize()} Epoch {epoch}') as pbar:
for batch_idx, batch in enumerate(dataloader): for batch_idx, batch in enumerate(dataloader):
start_time = time.time()
batch = tuple(b.to(self.device, dtype=torch.float) for b in batch) batch = tuple(b.to(self.device, dtype=torch.float) for b in batch)
*train_coeffs, target = batch *train_coeffs, target = batch
label = target[..., :self.args['output_dim']] label = target[..., :self.args['output_dim']]
@ -69,6 +73,8 @@ class Trainer:
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm']) torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm'])
self.optimizer.step() self.optimizer.step()
step_time = time.time() - start_time
self.stats.record_step_time(step_time, mode)
total_loss += loss.item() total_loss += loss.item()
if mode == 'train' and (batch_idx + 1) % self.args['log_step'] == 0: if mode == 'train' and (batch_idx + 1) % self.args['log_step'] == 0:
@ -82,6 +88,8 @@ class Trainer:
avg_loss = total_loss / len(dataloader) avg_loss = total_loss / len(dataloader)
self.logger.info( self.logger.info(
f'{mode.capitalize()} Epoch {epoch}: average Loss: {avg_loss:.6f}, time: {time.time() - epoch_time:.2f} s') f'{mode.capitalize()} Epoch {epoch}: average Loss: {avg_loss:.6f}, time: {time.time() - epoch_time:.2f} s')
# record memory usage
self.stats.record_memory_usage()
return avg_loss return avg_loss
def train_epoch(self, epoch): def train_epoch(self, epoch):
return self._run_epoch(epoch, self.train_loader, 'train') return self._run_epoch(epoch, self.train_loader, 'train')
@ -97,6 +105,7 @@ class Trainer:
best_loss, best_test_loss = float('inf'), float('inf') best_loss, best_test_loss = float('inf'), float('inf')
not_improved_count = 0 not_improved_count = 0
self.stats.start_training()
self.logger.info("Training process started") self.logger.info("Training process started")
for epoch in range(1, self.args['epochs'] + 1): for epoch in range(1, self.args['epochs'] + 1):
train_epoch_loss = self.train_epoch(epoch) train_epoch_loss = self.train_epoch(epoch)
@ -129,6 +138,14 @@ class Trainer:
torch.save(best_test_model, self.best_test_path) torch.save(best_test_model, self.best_test_path)
self.logger.info(f"Best models saved at {self.best_path} and {self.best_test_path}") self.logger.info(f"Best models saved at {self.best_path} and {self.best_test_path}")
# report statistics and parameter count
self.stats.end_training()
self.stats.report(self.logger)
try:
total_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
self.logger.info(f"Trainable params: {total_params}")
except Exception:
pass
self._finalize_training(best_model, best_test_model) self._finalize_training(best_model, best_test_model)
def _finalize_training(self, best_model, best_test_model): def _finalize_training(self, best_model, best_test_model):

View File

@ -11,6 +11,8 @@ def select_trainer(model, loss, optimizer, train_loader, val_loader, test_loader
match args['model']['type']: match args['model']['type']:
case "STGNCDE": return cdeTrainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args['train'], case "STGNCDE": return cdeTrainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args['train'],
lr_scheduler, kwargs[0], None) lr_scheduler, kwargs[0], None)
case "STGNRDE": return cdeTrainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args['train'],
lr_scheduler, kwargs[0], None)
case 'DCRNN': return DCRNN_Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args['train'], case 'DCRNN': return DCRNN_Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args['train'],
lr_scheduler) lr_scheduler)
case 'PDG2SEQ': return PDG2SEQ_Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args['train'], case 'PDG2SEQ': return PDG2SEQ_Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args['train'],