diff --git a/README.md b/README.md index 1a4ea28..4c4d154 100755 --- a/README.md +++ b/README.md @@ -32,7 +32,7 @@ pip install pyyaml tqdm statsmodels h5py kagglehub torch torchvision torchaudio python run.py --model {model_name} --dataset {dataset_name} --mode {train, test} --device {cuda:0} ``` -- model_name: currently supported: DSANET, STGCN, DCRNN, GWN (GraphWaveNet), STSGCN, AGCRN, STFGNN, STGODE, STGNCDE, DDGCRN, TWDGCN +- model_name: currently supported: DSANET, STGCN, DCRNN, GWN (GraphWaveNet), STSGCN, AGCRN, STFGNN, STGODE, STGNCDE, DDGCRN, TWDGCN, STAWnet - dataset_name: currently supported: PEMSD3, PEMSD4, PEMSD7, PEMSD8 - mode: train trains a model; test evaluates a trained model and requires the model's .pth checkpoint in the pre-train folder. - device: supports 'cpu', 'cuda:0', 'cuda:1' ... depending on the number of GPUs on the machine diff --git a/config/DCRNN/PEMSD4.yaml b/config/DCRNN/PEMSD4.yaml index 60ba0e5..8758786 100755 --- a/config/DCRNN/PEMSD4.yaml +++ b/config/DCRNN/PEMSD4.yaml @@ -6,7 +6,7 @@ data: test_ratio: 0.2 tod: False normalizer: std - column_wise: False + column_wise: True default_graph: True add_time_in_day: True add_day_in_week: True @@ -21,26 +21,26 @@ model: max_diffusion_step: 2 cl_decay_steps: 1000 filter_type: dual_random_walk - num_rnn_layers: 1 + num_rnn_layers: 2 rnn_units: 64 seq_len: 12 use_curriculum_learning: True train: - loss_func: mae + loss_func: mask_mae seed: 10 batch_size: 64 epochs: 300 - lr_init: 0.003 - weight_decay: 0 - lr_decay: False - lr_decay_rate: 0.3 - lr_decay_step: "5,20,40,70" + lr_init: 0.001 + weight_decay: 0.0001 + lr_decay: True + lr_decay_rate: 0.1 + lr_decay_step: "10,20,40,80" early_stop: True - early_stop_patience: 15 - grad_norm: False + early_stop_patience: 25 + grad_norm: True max_grad_norm: 5 - real_value: True + real_value: False test: mae_thresh: null diff --git a/dataloader/loader_selector.py b/dataloader/loader_selector.py index eea8d60..b7f697c 100755 --- a/dataloader/loader_selector.py +++ b/dataloader/loader_selector.py @@ -2,10 +2,12 @@ from dataloader.cde_loader.cdeDataloader import get_dataloader as cde_loader from dataloader.PeMSDdataloader import get_dataloader as normal_loader from dataloader.DCRNNdataloader import get_dataloader as DCRNN_loader from dataloader.EXPdataloader import get_dataloader as EXP_loader +from dataloader.cde_loader.cdeDataloader import get_dataloader as nrde_loader def get_dataloader(config, normalizer, single): match config['model']['type']: case 'STGNCDE': return cde_loader(config['data'], normalizer, single) + case 'STGNRDE': return nrde_loader(config['data'], normalizer, single) case 'DCRNN': return DCRNN_loader(config['data'], normalizer, single) case 'EXP': return EXP_loader(config['data'], normalizer, single) case _: return normal_loader(config['data'], normalizer, single) diff --git a/lib/loss_function.py b/lib/loss_function.py index 6645e11..ad8caf4 100755 --- a/lib/loss_function.py +++ b/lib/loss_function.py @@ -1,9 +1,9 @@ def masked_mae_loss(scaler, mask_value): def loss(preds, labels): + # Only inverse-transform predictions; labels keep their original scale in the data pipeline if scaler: preds = scaler.inverse_transform(preds) - labels = scaler.inverse_transform(labels) return mae_torch(pred=preds, true=labels, mask_value=mask_value) return loss @@ -11,7 +11,8 @@ def masked_mae_loss(scaler, mask_value): def get_loss_function(args, scaler): if args['loss_func'] == 'mask_mae': - return masked_mae_loss(scaler, mask_value=0.0).to(args['device']) + # Return callable loss (no .to for function closures); disable masking by default + return masked_mae_loss(scaler, mask_value=None) elif args['loss_func'] == 'mae': return torch.nn.L1Loss().to(args['device']) elif args['loss_func'] == 'mse': diff --git 
a/model/DCRNN/dcrnn_cell.py b/model/DCRNN/dcrnn_cell.py index c425069..fabe400 100755 --- a/model/DCRNN/dcrnn_cell.py +++ b/model/DCRNN/dcrnn_cell.py @@ -34,7 +34,7 @@ class LayerParams: class DCGRUCell(torch.nn.Module): - def __init__(self, num_units, adj_mx, max_diffusion_step, num_nodes, nonlinearity='tanh', + def __init__(self, num_units, adj_mx, max_diffusion_step, num_nodes, input_dim=None, nonlinearity='tanh', filter_type="laplacian", use_gc_for_ru=True): """ @@ -55,6 +55,7 @@ class DCGRUCell(torch.nn.Module): self._max_diffusion_step = max_diffusion_step self._supports = [] self._use_gc_for_ru = use_gc_for_ru + self._input_dim = input_dim # optional; if None, will be inferred at first forward supports = [] if filter_type == "laplacian": supports.append(utils.calculate_scaled_laplacian(adj_mx, lambda_max=None)) @@ -71,6 +72,22 @@ class DCGRUCell(torch.nn.Module): self._fc_params = LayerParams(self, 'fc') self._gconv_params = LayerParams(self, 'gconv') + # Pre-create parameters if input_dim is known + if self._input_dim is not None: + num_matrices = len(self._supports) * self._max_diffusion_step + 1 + input_size = self._input_dim + self._num_units + # FC weights/biases for RU gates (2 * num_units) + self._fc_params.get_weights((input_size, 2 * self._num_units)) + self._fc_params.get_biases(2 * self._num_units, bias_start=1.0) + # Optionally for candidate (num_units) if FC path is used + self._fc_params.get_weights((input_size, self._num_units)) + self._fc_params.get_biases(self._num_units, bias_start=0.0) + # GConv weights/biases for RU gates and candidate + self._gconv_params.get_weights((input_size * num_matrices, 2 * self._num_units)) + self._gconv_params.get_biases(2 * self._num_units, bias_start=1.0) + self._gconv_params.get_weights((input_size * num_matrices, self._num_units)) + self._gconv_params.get_biases(self._num_units, bias_start=0.0) + @staticmethod def _build_sparse_matrix(L): L = L.tocoo() diff --git a/model/DCRNN/dcrnn_model.py b/model/DCRNN/dcrnn_model.py index eb5803a..f7e7648 100755 --- a/model/DCRNN/dcrnn_model.py +++ b/model/DCRNN/dcrnn_model.py @@ -15,6 +15,10 @@ class Seq2SeqAttrs: self.num_nodes = args.get('num_nodes', 1) self.num_rnn_layers = args.get('num_rnn_layers', 1) self.rnn_units = args.get('rnn_units') + self.input_dim = args.get('input_dim', 1) + self.output_dim = args.get('output_dim', 1) + self.horizon = args.get('horizon', 12) + self.seq_len = args.get('seq_len', 12) self.hidden_state_size = self.num_nodes * self.rnn_units @@ -26,7 +30,7 @@ class EncoderModel(nn.Module, Seq2SeqAttrs): self.seq_len = args.get('seq_len') # for the encoder self.dcgru_layers = nn.ModuleList( [DCGRUCell(self.rnn_units, adj_mx, self.max_diffusion_step, self.num_nodes, - filter_type=self.filter_type) for _ in range(self.num_rnn_layers)]) + input_dim=self.input_dim, filter_type=self.filter_type) for _ in range(self.num_rnn_layers)]) def forward(self, inputs, hidden_state=None): """ @@ -63,7 +67,7 @@ class DecoderModel(nn.Module, Seq2SeqAttrs): self.projection_layer = nn.Linear(self.rnn_units, self.output_dim) self.dcgru_layers = nn.ModuleList( [DCGRUCell(self.rnn_units, adj_mx, self.max_diffusion_step, self.num_nodes, - filter_type=self.filter_type) for _ in range(self.num_rnn_layers)]) + input_dim=self.output_dim, filter_type=self.filter_type) for _ in range(self.num_rnn_layers)]) def forward(self, inputs, hidden_state=None): """ @@ -146,17 +150,19 @@ class DCRNNModel(nn.Module, Seq2SeqAttrs): def forward(self, inputs, labels=None): """ - seq2seq forward pass 64 12 307 3 
- :param inputs: shape (seq_len, batch_size, num_sensor * input_dim) 12 64 307 * 1 - :param labels: shape (horizon, batch_size, num_sensor * output) 12 64 307 1 - :param batches_seen: batches seen till now - :return: output: (self.horizon, batch_size, self.num_nodes * self.output_dim) + seq2seq forward pass. inputs: [B, T, N, C] """ - inputs = inputs[..., 0].permute(1, 0, 2) - labels = labels[..., 0].permute(1, 0, 2) - encoder_hidden_state = self.encoder(inputs) - outputs = self.decoder(encoder_hidden_state, labels, batches_seen=self.batch_seen) + x = inputs[..., :self.input_dim] + x = x.permute(1, 0, 2, 3).contiguous().view(self.seq_len, -1, self.num_nodes * self.input_dim) + + y = None + if labels is not None: + y = labels[..., :self.output_dim] + y = y.permute(1, 0, 2, 3).contiguous().view(self.horizon, -1, self.num_nodes * self.output_dim) + + encoder_hidden_state = self.encoder(x) + outputs = self.decoder(encoder_hidden_state, y, batches_seen=self.batch_seen) self.batch_seen += 1 - outputs = outputs.unsqueeze(dim=-1) # [12,64,307,1] - outputs = outputs.permute(1, 0, 2, 3) # [64,12,307,1] + outputs = outputs.view(self.horizon, -1, self.num_nodes, self.output_dim) + outputs = outputs.permute(1, 0, 2, 3).contiguous() return outputs diff --git a/model/STGNCDE/BasicTrainer_cde.py b/model/STGNCDE/BasicTrainer_cde.py index ef0c4f3..569bcd5 100755 --- a/model/STGNCDE/BasicTrainer_cde.py +++ b/model/STGNCDE/BasicTrainer_cde.py @@ -7,6 +7,7 @@ import numpy as np from lib.logger import get_logger from lib.metrics import All_Metrics from lib.TrainInits import print_model_parameters +from lib.training_stats import TrainingStats class Trainer(object): def __init__(self, model, vector_field_f, vector_field_g, loss, optimizer, train_loader, val_loader, test_loader, @@ -42,6 +43,8 @@ class Trainer(object): self.device = device self.times = times.to(self.device, dtype=torch.float) self.w = w + # Stats tracker + self.stats = TrainingStats(device=device) def val_epoch(self, epoch, val_dataloader): self.model.eval() @@ -49,6 +52,7 @@ class Trainer(object): with torch.no_grad(): for batch_idx, batch in enumerate(self.val_loader): + start_time = time.time() # for iter, batch in enumerate(val_dataloader): batch = tuple(b.to(self.device, dtype=torch.float) for b in batch) *valid_coeffs, target = batch @@ -61,8 +65,11 @@ class Trainer(object): #a whole batch of Metr_LA is filtered if not torch.isnan(loss): total_val_loss += loss.item() + step_time = time.time() - start_time + self.stats.record_step_time(step_time, 'val') val_loss = total_val_loss / len(val_dataloader) self.logger.info('**********Val Epoch {}: average Loss: {:.6f}'.format(epoch, val_loss)) + self.stats.record_memory_usage() if self.args.tensorboard: self.w.add_scalar(f'valid/loss', val_loss, epoch) return val_loss @@ -73,6 +80,7 @@ class Trainer(object): # for batch_idx, (data, target) in enumerate(self.train_loader): # for batch_idx, (data, target) in enumerate(self.train_loader): for batch_idx, batch in enumerate(self.train_loader): + start_time = time.time() batch = tuple(b.to(self.device, dtype=torch.float) for b in batch) *train_coeffs, target = batch # data = data[..., :self.args.input_dim] @@ -101,6 +109,8 @@ class Trainer(object): if self.args.grad_norm: torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm) self.optimizer.step() + step_time = time.time() - start_time + self.stats.record_step_time(step_time, 'train') total_loss += loss.item() #log information @@ -109,6 +119,7 @@ class Trainer(object): epoch, 
batch_idx, self.train_per_epoch, loss.item())) train_epoch_loss = total_loss/self.train_per_epoch self.logger.info('**********Train Epoch {}: averaged Loss: {:.6f}'.format(epoch, train_epoch_loss)) + self.stats.record_memory_usage() if self.args.tensorboard: self.w.add_scalar(f'train/loss', train_epoch_loss, epoch) @@ -123,6 +134,7 @@ class Trainer(object): not_improved_count = 0 train_loss_list = [] val_loss_list = [] + self.stats.start_training() start_time = time.time() for epoch in range(1, self.args.epochs + 1): #epoch_time = time.time() @@ -168,6 +180,13 @@ class Trainer(object): training_time = time.time() - start_time self.logger.info("Total training time: {:.4f}min, best loss: {:.6f}".format((training_time / 60), best_loss)) + self.stats.end_training() + self.stats.report(self.logger) + try: + total_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad) + self.logger.info(f"Trainable params: {total_params}") + except Exception: + pass #save the best model to file if not self.args.debug: diff --git a/model/model_selector.py b/model/model_selector.py index 796c814..4fba6bc 100755 --- a/model/model_selector.py +++ b/model/model_selector.py @@ -18,6 +18,10 @@ from model.STIDGCN.STIDGCN import STIDGCN from model.STID.STID import STID from model.STAEFormer.STAEFormer import STAEformer from model.EXP.EXP32 import EXP as EXP +from model.MegaCRN.MegaCRNModel import MegaCRNModel +from model.ST_SSL.ST_SSL import STSSLModel +from model.STGNRDE.Make_model import make_model as make_nrde_model +from model.STAWnet.STAWnet import STAWnet def model_selector(model): match model['type']: @@ -41,4 +45,8 @@ def model_selector(model): case 'STID': return STID(model) case 'STAEFormer': return STAEformer(model) case 'EXP': return EXP(model) + case 'MegaCRN': return MegaCRNModel(model) + case 'ST_SSL': return STSSLModel(model) + case 'STGNRDE': return make_nrde_model(model) + case 'STAWnet': return STAWnet(model) diff --git a/trainer/DCRNN_Trainer.py b/trainer/DCRNN_Trainer.py index 97a8290..3e9aa56 100755 --- a/trainer/DCRNN_Trainer.py +++ b/trainer/DCRNN_Trainer.py @@ -7,6 +7,7 @@ from tqdm import tqdm import torch from lib.logger import get_logger from lib.loss_function import all_metrics +from lib.training_stats import TrainingStats class Trainer: @@ -34,6 +35,8 @@ class Trainer: os.makedirs(args['log_dir'], exist_ok=True) self.logger = get_logger(args['log_dir'], name=self.model.__class__.__name__, debug=args['debug']) self.logger.info(f"Experiment log path in: {args['log_dir']}") + # Stats tracker + self.stats = TrainingStats(device=args['device']) def _run_epoch(self, epoch, dataloader, mode): if mode == 'train': @@ -49,12 +52,13 @@ class Trainer: with torch.set_grad_enabled(optimizer_step): with tqdm(total=len(dataloader), desc=f'{mode.capitalize()} Epoch {epoch}') as pbar: for batch_idx, (data, target) in enumerate(dataloader): + start_time = time.time() label = target[..., :self.args['output_dim']] - # label = target[..., :self.args['output_dim']] output = self.model(data, labels=label.clone()).to(self.args['device']) if self.args['real_value']: output = self.scaler.inverse_transform(output) + label = self.scaler.inverse_transform(label) loss = self.loss(output, label) if optimizer_step and self.optimizer is not None: @@ -65,6 +69,8 @@ class Trainer: torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm']) self.optimizer.step() + step_time = time.time() - start_time + self.stats.record_step_time(step_time, mode) total_loss += loss.item() if mode == 
'train' and (batch_idx + 1) % self.args['log_step'] == 0: @@ -78,6 +84,8 @@ class Trainer: avg_loss = total_loss / len(dataloader) self.logger.info( f'{mode.capitalize()} Epoch {epoch}: average Loss: {avg_loss:.6f}, time: {time.time() - epoch_time:.2f} s') + # 记录内存 + self.stats.record_memory_usage() return avg_loss def train_epoch(self, epoch): @@ -94,6 +102,7 @@ class Trainer: best_loss, best_test_loss = float('inf'), float('inf') not_improved_count = 0 + self.stats.start_training() self.logger.info("Training process started") for epoch in range(1, self.args['epochs'] + 1): train_epoch_loss = self.train_epoch(epoch) @@ -126,6 +135,14 @@ class Trainer: torch.save(best_test_model, self.best_test_path) self.logger.info(f"Best models saved at {self.best_path} and {self.best_test_path}") + # 输出统计与参数 + self.stats.end_training() + self.stats.report(self.logger) + try: + total_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad) + self.logger.info(f"Trainable params: {total_params}") + except Exception: + pass self._finalize_training(best_model, best_test_model) def _finalize_training(self, best_model, best_test_model): @@ -154,11 +171,11 @@ class Trainer: y_pred.append(output) y_true.append(label) - if args['real_value']: - y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0)) - else: - y_pred = torch.cat(y_pred, dim=0) + y_pred = torch.cat(y_pred, dim=0) y_true = torch.cat(y_true, dim=0) + if args['real_value']: + y_pred = scaler.inverse_transform(y_pred) + y_true = scaler.inverse_transform(y_true) for t in range(y_true.shape[1]): mae, rmse, mape = all_metrics(y_pred[:, t, ...], y_true[:, t, ...], diff --git a/trainer/E32Trainer.py b/trainer/E32Trainer.py index 1a8d062..b1bce7c 100644 --- a/trainer/E32Trainer.py +++ b/trainer/E32Trainer.py @@ -7,6 +7,7 @@ from tqdm import tqdm import torch from lib.logger import get_logger from lib.loss_function import all_metrics +from lib.training_stats import TrainingStats class Trainer: @@ -34,6 +35,8 @@ class Trainer: os.makedirs(args['log_dir'], exist_ok=True) self.logger = get_logger(args['log_dir'], name=self.model.__class__.__name__, debug=args['debug']) self.logger.info(f"Experiment log path in: {args['log_dir']}") + # Stats tracker + self.stats = TrainingStats(device=args['device']) def _run_epoch(self, epoch, dataloader, mode): is_train = (mode == 'train') @@ -45,6 +48,7 @@ class Trainer: tqdm(total=len(dataloader), desc=f'{mode.capitalize()} Epoch {epoch}') as pbar: for batch_idx, batch in enumerate(dataloader): + start_time = time.time() # unpack the new cycle_index data, target, cycle_index = batch data = data.to(self.args['device']) @@ -72,6 +76,8 @@ class Trainer: self.args['max_grad_norm']) self.optimizer.step() + step_time = time.time() - start_time + self.stats.record_step_time(step_time, mode) total_loss += loss.item() # logging @@ -86,6 +92,8 @@ class Trainer: avg_loss = total_loss / len(dataloader) self.logger.info( f'{mode.capitalize()} Epoch {epoch}: average Loss: {avg_loss:.6f}, time: {time.time() - epoch_time:.2f} s') + # 记录内存 + self.stats.record_memory_usage() return avg_loss def train_epoch(self, epoch): @@ -102,6 +110,7 @@ class Trainer: best_loss, best_test_loss = float('inf'), float('inf') not_improved_count = 0 + self.stats.start_training() self.logger.info("Training process started") for epoch in range(1, self.args['epochs'] + 1): train_epoch_loss = self.train_epoch(epoch) @@ -134,6 +143,14 @@ class Trainer: torch.save(best_test_model, self.best_test_path) self.logger.info(f"Best models saved at 
{self.best_path} and {self.best_test_path}") + # 输出统计与参数 + self.stats.end_training() + self.stats.report(self.logger) + try: + total_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad) + self.logger.info(f"Trainable params: {total_params}") + except Exception: + pass self._finalize_training(best_model, best_test_model) def _finalize_training(self, best_model, best_test_model): diff --git a/trainer/EXP_trainer.py b/trainer/EXP_trainer.py index 5613870..80dc6c7 100755 --- a/trainer/EXP_trainer.py +++ b/trainer/EXP_trainer.py @@ -7,6 +7,7 @@ from tqdm import tqdm import torch from lib.logger import get_logger from lib.loss_function import all_metrics +from lib.training_stats import TrainingStats class Trainer: @@ -34,6 +35,8 @@ class Trainer: os.makedirs(args['log_dir'], exist_ok=True) self.logger = get_logger(args['log_dir'], name=self.model.__class__.__name__, debug=args['debug']) self.logger.info(f"Experiment log path in: {args['log_dir']}") + # Stats tracker + self.stats = TrainingStats(device=args['device']) def _run_epoch(self, epoch, dataloader, mode): if mode == 'train': @@ -49,6 +52,7 @@ class Trainer: with torch.set_grad_enabled(optimizer_step): with tqdm(total=len(dataloader), desc=f'{mode.capitalize()} Epoch {epoch}') as pbar: for batch_idx, (data, target) in enumerate(dataloader): + start_time = time.time() label = target[..., :self.args['output_dim']] output = self.model(data).to(self.args['device']) @@ -64,6 +68,9 @@ class Trainer: torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm']) self.optimizer.step() + step_time = time.time() - start_time + self.stats.record_step_time(step_time, mode) + total_loss += loss.item() if mode == 'train' and (batch_idx + 1) % self.args['log_step'] == 0: @@ -77,6 +84,8 @@ class Trainer: avg_loss = total_loss / len(dataloader) self.logger.info( f'{mode.capitalize()} Epoch {epoch}: average Loss: {avg_loss:.6f}, time: {time.time() - epoch_time:.2f} s') + # 记录内存 + self.stats.record_memory_usage() return avg_loss def train_epoch(self, epoch): @@ -93,6 +102,7 @@ class Trainer: best_loss, best_test_loss = float('inf'), float('inf') not_improved_count = 0 + self.stats.start_training() self.logger.info("Training process started") for epoch in range(1, self.args['epochs'] + 1): train_epoch_loss = self.train_epoch(epoch) @@ -124,7 +134,14 @@ class Trainer: torch.save(best_model, self.best_path) torch.save(best_test_model, self.best_test_path) self.logger.info(f"Best models saved at {self.best_path} and {self.best_test_path}") - + # 输出统计与参数 + self.stats.end_training() + self.stats.report(self.logger) + try: + total_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad) + self.logger.info(f"Trainable params: {total_params}") + except Exception: + pass self._finalize_training(best_model, best_test_model) def _finalize_training(self, best_model, best_test_model): diff --git a/trainer/PDG2SEQ_Trainer.py b/trainer/PDG2SEQ_Trainer.py index bde4801..a6dc908 100755 --- a/trainer/PDG2SEQ_Trainer.py +++ b/trainer/PDG2SEQ_Trainer.py @@ -7,6 +7,7 @@ from tqdm import tqdm import torch from lib.logger import get_logger from lib.loss_function import all_metrics +from lib.training_stats import TrainingStats class Trainer: @@ -35,6 +36,8 @@ class Trainer: os.makedirs(args['log_dir'], exist_ok=True) self.logger = get_logger(args['log_dir'], name=self.model.__class__.__name__, debug=args['debug']) self.logger.info(f"Experiment log path in: {args['log_dir']}") + # Stats tracker + self.stats = 
TrainingStats(device=args['device']) def _run_epoch(self, epoch, dataloader, mode): if mode == 'train': @@ -50,6 +53,7 @@ class Trainer: with torch.set_grad_enabled(optimizer_step): with tqdm(total=len(dataloader), desc=f'{mode.capitalize()} Epoch {epoch}') as pbar: for batch_idx, (data, target) in enumerate(dataloader): + start_time = time.time() self.batches_seen += 1 label = target[..., :self.args['output_dim']].clone() output = self.model(data, target, self.batches_seen).to(self.args['device']) @@ -66,6 +70,9 @@ class Trainer: torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm']) self.optimizer.step() + # record step time + step_time = time.time() - start_time + self.stats.record_step_time(step_time, mode) total_loss += loss.item() if mode == 'train' and (batch_idx + 1) % self.args['log_step'] == 0: @@ -79,6 +86,8 @@ class Trainer: avg_loss = total_loss / len(dataloader) self.logger.info( f'{mode.capitalize()} Epoch {epoch}: average Loss: {avg_loss:.6f}, time: {time.time() - epoch_time:.2f} s') + # 记录内存 + self.stats.record_memory_usage() return avg_loss def train_epoch(self, epoch): @@ -95,6 +104,7 @@ class Trainer: best_loss, best_test_loss = float('inf'), float('inf') not_improved_count = 0 + self.stats.start_training() self.logger.info("Training process started") for epoch in range(1, self.args['epochs'] + 1): train_epoch_loss = self.train_epoch(epoch) @@ -127,6 +137,14 @@ class Trainer: torch.save(best_test_model, self.best_test_path) self.logger.info(f"Best models saved at {self.best_path} and {self.best_test_path}") + # 输出统计与参数 + self.stats.end_training() + self.stats.report(self.logger) + try: + total_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad) + self.logger.info(f"Trainable params: {total_params}") + except Exception: + pass self._finalize_training(best_model, best_test_model) def _finalize_training(self, best_model, best_test_model): diff --git a/trainer/STMLP_Trainer.py b/trainer/STMLP_Trainer.py index 6489221..6e416e8 100644 --- a/trainer/STMLP_Trainer.py +++ b/trainer/STMLP_Trainer.py @@ -11,6 +11,7 @@ from tqdm import tqdm from lib.logger import get_logger from lib.loss_function import all_metrics from model.STMLP.STMLP import STMLP +from lib.training_stats import TrainingStats class Trainer: @@ -43,6 +44,8 @@ class Trainer: os.makedirs(self.pretrain_dir, exist_ok=True) self.logger = get_logger(self.args['log_dir'], name=self.model.__class__.__name__, debug=self.args['debug']) self.logger.info(f"Experiment log path in: {self.args['log_dir']}") + # Stats tracker + self.stats = TrainingStats(device=args['device']) if self.args['teacher_stu']: self.tmodel = self.loadTeacher(args) @@ -67,6 +70,7 @@ class Trainer: with torch.set_grad_enabled(optimizer_step): with tqdm(total=len(dataloader), desc=f'{mode.capitalize()} Epoch {epoch}') as pbar: for batch_idx, (data, target) in enumerate(dataloader): + start_time = time.time() if self.args['teacher_stu']: label = target[..., :self.args['output_dim']] output, out_, _ = self.model(data) @@ -100,6 +104,8 @@ class Trainer: torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm']) self.optimizer.step() + step_time = time.time() - start_time + self.stats.record_step_time(step_time, mode) total_loss += loss.item() if mode == 'train' and (batch_idx + 1) % self.args['log_step'] == 0: @@ -113,6 +119,8 @@ class Trainer: avg_loss = total_loss / len(dataloader) self.logger.info( f'{mode.capitalize()} Epoch {epoch}: average Loss: {avg_loss:.6f}, time: 
{time.time() - epoch_time:.2f} s') + # 记录内存 + self.stats.record_memory_usage() return avg_loss def train_epoch(self, epoch): @@ -129,6 +137,7 @@ class Trainer: best_loss, best_test_loss = float('inf'), float('inf') not_improved_count = 0 + self.stats.start_training() self.logger.info("Training process started") for epoch in range(1, self.args['epochs'] + 1): train_epoch_loss = self.train_epoch(epoch) @@ -165,6 +174,14 @@ class Trainer: torch.save(best_test_model, self.best_test_path) self.logger.info(f"Best models saved at {self.best_path} and {self.best_test_path}") + # 输出统计与参数 + self.stats.end_training() + self.stats.report(self.logger) + try: + total_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad) + self.logger.info(f"Trainable params: {total_params}") + except Exception: + pass self._finalize_training(best_model, best_test_model) def _finalize_training(self, best_model, best_test_model): diff --git a/trainer/Trainer.py b/trainer/Trainer.py index 9fbe3a9..013d852 100755 --- a/trainer/Trainer.py +++ b/trainer/Trainer.py @@ -211,6 +211,13 @@ class Trainer: self._finalize_training(best_model, best_test_model) + # 输出参数量 + try: + total_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad) + self.logger.info(f"Trainable params: {total_params}") + except Exception: + pass + def _finalize_training(self, best_model, best_test_model): self.model.load_state_dict(best_model) self.logger.info("Testing on best validation model") diff --git a/trainer/Trainer_old.py b/trainer/Trainer_old.py index 5613870..4004f8a 100755 --- a/trainer/Trainer_old.py +++ b/trainer/Trainer_old.py @@ -7,6 +7,7 @@ from tqdm import tqdm import torch from lib.logger import get_logger from lib.loss_function import all_metrics +from lib.training_stats import TrainingStats class Trainer: @@ -34,6 +35,8 @@ class Trainer: os.makedirs(args['log_dir'], exist_ok=True) self.logger = get_logger(args['log_dir'], name=self.model.__class__.__name__, debug=args['debug']) self.logger.info(f"Experiment log path in: {args['log_dir']}") + # Stats tracker + self.stats = TrainingStats(device=args['device']) def _run_epoch(self, epoch, dataloader, mode): if mode == 'train': @@ -49,6 +52,7 @@ class Trainer: with torch.set_grad_enabled(optimizer_step): with tqdm(total=len(dataloader), desc=f'{mode.capitalize()} Epoch {epoch}') as pbar: for batch_idx, (data, target) in enumerate(dataloader): + start_time = time.time() label = target[..., :self.args['output_dim']] output = self.model(data).to(self.args['device']) @@ -64,6 +68,8 @@ class Trainer: torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm']) self.optimizer.step() + step_time = time.time() - start_time + self.stats.record_step_time(step_time, mode) total_loss += loss.item() if mode == 'train' and (batch_idx + 1) % self.args['log_step'] == 0: @@ -77,6 +83,8 @@ class Trainer: avg_loss = total_loss / len(dataloader) self.logger.info( f'{mode.capitalize()} Epoch {epoch}: average Loss: {avg_loss:.6f}, time: {time.time() - epoch_time:.2f} s') + # 记录内存 + self.stats.record_memory_usage() return avg_loss def train_epoch(self, epoch): @@ -93,6 +101,7 @@ class Trainer: best_loss, best_test_loss = float('inf'), float('inf') not_improved_count = 0 + self.stats.start_training() self.logger.info("Training process started") for epoch in range(1, self.args['epochs'] + 1): train_epoch_loss = self.train_epoch(epoch) @@ -125,6 +134,14 @@ class Trainer: torch.save(best_test_model, self.best_test_path) self.logger.info(f"Best models 
saved at {self.best_path} and {self.best_test_path}") + # 输出统计与参数 + self.stats.end_training() + self.stats.report(self.logger) + try: + total_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad) + self.logger.info(f"Trainable params: {total_params}") + except Exception: + pass self._finalize_training(best_model, best_test_model) def _finalize_training(self, best_model, best_test_model): diff --git a/trainer/cdeTrainer/cdetrainer.py b/trainer/cdeTrainer/cdetrainer.py index 59fefa8..f939d66 100755 --- a/trainer/cdeTrainer/cdetrainer.py +++ b/trainer/cdeTrainer/cdetrainer.py @@ -7,6 +7,7 @@ from tqdm import tqdm import torch from lib.logger import get_logger from lib.loss_function import all_metrics +from lib.training_stats import TrainingStats class Trainer: @@ -35,6 +36,8 @@ class Trainer: os.makedirs(args['log_dir'], exist_ok=True) self.logger = get_logger(args['log_dir'], name=self.model.__class__.__name__, debug=args['debug']) self.logger.info(f"Experiment log path in: {args['log_dir']}") + # Stats tracker + self.stats = TrainingStats(device=args['device']) self.times = times.to(self.device, dtype=torch.float) self.w = w @@ -52,6 +55,7 @@ class Trainer: with torch.set_grad_enabled(optimizer_step): with tqdm(total=len(dataloader), desc=f'{mode.capitalize()} Epoch {epoch}') as pbar: for batch_idx, batch in enumerate(dataloader): + start_time = time.time() batch = tuple(b.to(self.device, dtype=torch.float) for b in batch) *train_coeffs, target = batch label = target[..., :self.args['output_dim']] @@ -69,6 +73,8 @@ class Trainer: torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm']) self.optimizer.step() + step_time = time.time() - start_time + self.stats.record_step_time(step_time, mode) total_loss += loss.item() if mode == 'train' and (batch_idx + 1) % self.args['log_step'] == 0: @@ -82,6 +88,8 @@ class Trainer: avg_loss = total_loss / len(dataloader) self.logger.info( f'{mode.capitalize()} Epoch {epoch}: average Loss: {avg_loss:.6f}, time: {time.time() - epoch_time:.2f} s') + # 记录内存 + self.stats.record_memory_usage() return avg_loss def train_epoch(self, epoch): return self._run_epoch(epoch, self.train_loader, 'train') @@ -97,6 +105,7 @@ class Trainer: best_loss, best_test_loss = float('inf'), float('inf') not_improved_count = 0 + self.stats.start_training() self.logger.info("Training process started") for epoch in range(1, self.args['epochs'] + 1): train_epoch_loss = self.train_epoch(epoch) @@ -129,6 +138,14 @@ class Trainer: torch.save(best_test_model, self.best_test_path) self.logger.info(f"Best models saved at {self.best_path} and {self.best_test_path}") + # 输出统计与参数 + self.stats.end_training() + self.stats.report(self.logger) + try: + total_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad) + self.logger.info(f"Trainable params: {total_params}") + except Exception: + pass self._finalize_training(best_model, best_test_model) def _finalize_training(self, best_model, best_test_model): diff --git a/trainer/trainer_selector.py b/trainer/trainer_selector.py index 57e48d6..5f66185 100755 --- a/trainer/trainer_selector.py +++ b/trainer/trainer_selector.py @@ -11,6 +11,8 @@ def select_trainer(model, loss, optimizer, train_loader, val_loader, test_loader match args['model']['type']: case "STGNCDE": return cdeTrainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args['train'], lr_scheduler, kwargs[0], None) + case "STGNRDE": return cdeTrainer(model, loss, optimizer, train_loader, val_loader, 
test_loader, scaler, args['train'], + lr_scheduler, kwargs[0], None) case 'DCRNN': return DCRNN_Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args['train'], lr_scheduler) case 'PDG2SEQ': return PDG2SEQ_Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args['train'],
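
Note on `lib.training_stats`: the trainers in this diff all construct `TrainingStats(device=...)` and call `start_training()`, `record_step_time(step_time, mode)`, `record_memory_usage()`, `end_training()`, and `report(logger)`, but the module itself is not part of the patch. The sketch below is a hypothetical minimal implementation inferred only from those call sites; the repository's actual `lib/training_stats.py` may differ.

```python
# Hypothetical sketch of lib/training_stats.py, inferred from the call sites in the
# trainers above (start_training/end_training, record_step_time, record_memory_usage,
# report). The real implementation in the repository may track more or different fields.
import time
import torch


class TrainingStats:
    def __init__(self, device='cpu'):
        self.device = device
        self.step_times = {'train': [], 'val': [], 'test': []}
        self.peak_memory_mb = 0.0
        self.start_time = None
        self.total_time = None

    def start_training(self):
        self.start_time = time.time()

    def end_training(self):
        if self.start_time is not None:
            self.total_time = time.time() - self.start_time

    def record_step_time(self, step_time, mode):
        # Per-batch wall-clock time, bucketed by 'train' / 'val' / 'test'.
        self.step_times.setdefault(mode, []).append(step_time)

    def record_memory_usage(self):
        # Track peak GPU memory when running on CUDA; no-op on CPU.
        if str(self.device).startswith('cuda') and torch.cuda.is_available():
            peak_mb = torch.cuda.max_memory_allocated(self.device) / (1024 ** 2)
            self.peak_memory_mb = max(self.peak_memory_mb, peak_mb)

    def report(self, logger):
        # Summarize per-mode step times, total training time, and peak memory.
        for mode, times in self.step_times.items():
            if times:
                logger.info(f'{mode} avg step time: {sum(times) / len(times):.4f}s '
                            f'({len(times)} steps)')
        if self.total_time is not None:
            logger.info(f'Total tracked training time: {self.total_time / 60:.2f} min')
        if self.peak_memory_mb:
            logger.info(f'Peak CUDA memory: {self.peak_memory_mb:.1f} MB')
```

Under that assumption, the calls added in `DCRNN_Trainer`, `EXP_trainer`, `cdetrainer`, and the other trainers work unchanged; `report()` only assumes the logger exposes `.info()`, which the loggers returned by `lib.logger.get_logger` are already used for throughout the diff.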