diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000..f383630 --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,16 @@ +{ + // 使用 IntelliSense 了解相关属性。 + // 悬停以查看现有属性的描述。 + // 欲了解更多信息,请访问: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "name": "EXP_PEMSD8", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/DDGCRN/PEMSD8.yaml" + } + ] +} \ No newline at end of file diff --git a/config/DDGCRN/PEMSD8.yaml b/config/DDGCRN/PEMSD8.yaml index 410845b..fb62006 100755 --- a/config/DDGCRN/PEMSD8.yaml +++ b/config/DDGCRN/PEMSD8.yaml @@ -1,3 +1,8 @@ +basic: + dataset: "PEMSD8" + mode : "train" + device : "cuda:0" + model: "DDGCRN" data: add_day_in_week: true add_time_in_day: true @@ -12,9 +17,10 @@ data: test_ratio: 0.2 tod: false val_ratio: 0.2 -log: - log_step: 2000 - plot: false + sample: 1 + input_dim: 1 + batch_size: 64 + model: cheb_order: 2 embed_dim: 5 @@ -24,9 +30,8 @@ model: rnn_units: 64 use_day: true use_week: true -test: - mae_thresh: None - mape_thresh: 0.001 + num_nodes: 170 + horizon: 12 train: batch_size: 64 early_stop: true @@ -42,3 +47,9 @@ train: real_value: true seed: 12 weight_decay: 0 + debug: false + output_dim: 1 + log_step: 2000 + plot: false + mae_thresh: None + mape_thresh: 0.001 diff --git a/config/EXP/PEMSD3.yaml b/config/EXP/PEMSD3.yaml index ca688e3..40fadfa 100755 --- a/config/EXP/PEMSD3.yaml +++ b/config/EXP/PEMSD3.yaml @@ -1,3 +1,5 @@ + + data: num_nodes: 358 lag: 12 diff --git a/config/EXP/PEMSD8.yaml b/config/EXP/PEMSD8.yaml index 67c3bd5..ffda65a 100755 --- a/config/EXP/PEMSD8.yaml +++ b/config/EXP/PEMSD8.yaml @@ -1,4 +1,11 @@ +basic: + dataset: "PEMSD8" + mode : "train" + device : "cuda:0" + model: "EXP" + data: + type: "PEMSD8" num_nodes: 170 lag: 12 horizon: 12 @@ -10,15 +17,25 @@ data: default_graph: True add_time_in_day: True add_day_in_week: True + input_dim: 1 + batch_size: 64 steps_per_day: 288 days_per_week: 7 + sample: 1 + cycle: 288 + model: + horizon: 12 + hidden_dim: 64 + num_nodes: 170 + time_slots: 288 + embed_dim: 16 batch_size: 64 input_dim: 1 output_dim: 1 in_len: 12 - + cycle_len: 288 train: loss_func: mae @@ -35,11 +52,12 @@ train: grad_norm: False max_grad_norm: 5 real_value: True + debug: True + output_dim: 1 + log_step: 2000 + plot: False test: mae_thresh: null mape_thresh: 0.0 -log: - log_step: 200 - plot: False diff --git a/config/args_parser.py b/config/args_parser.py index a028671..14ece82 100755 --- a/config/args_parser.py +++ b/config/args_parser.py @@ -3,43 +3,38 @@ import yaml def parse_args(): parser = argparse.ArgumentParser(description='Model Training and Testing') - parser.add_argument('--dataset', default='PEMSD8', type=str) - parser.add_argument('--mode', default='train', type=str) - parser.add_argument('--device', default='cuda:0', type=str, help='Indices of GPUs') - parser.add_argument('--debug', default=False, type=eval) - parser.add_argument('--model', default='GWN', type=str) - parser.add_argument('--cuda', default=True, type=bool) - parser.add_argument('--sample', default=1, type=int) - parser.add_argument('--emb', default=12, type=int) - parser.add_argument('--rnn', default=64, type=int) - - + parser.add_argument('--config', type=str, required=True, help='Path to the configuration file') args = parser.parse_args() # Load YAML configuration - config_file = f'./config/{args.model}/{args.dataset}.yaml' - with open(config_file, 'r') as file: - config = yaml.safe_load(file) + if args.config: + with open(args.config, 'r') as file: + config = yaml.safe_load(file) + else: + raise ValueError("Configuration file path must be provided using --config") + + # Update configuration with command-line arguments + # Merge 'basic' configuration into the root dictionary + # config.update(config.get('basic', {})) + + # Add adaptive configuration based on external commands + if 'data' in config and 'type' in config['data']: + config['data']['type'] = config['basic'].get('dataset', config['data']['type']) + if 'model' in config and 'type' in config['model']: + config['model']['type'] = config['basic'].get('model', config['model']['type']) + if 'model' in config and 'rnn_units' in config['model']: + config['model']['rnn_units'] = config['basic'].get('rnn', config['model']['rnn_units']) + if 'model' in config and 'embed_dim' in config['model']: + config['model']['embed_dim'] = config['basic'].get('emb', config['model']['embed_dim']) + if 'data' in config and 'sample' in config['data']: + config['data']['sample'] = config['basic'].get('sample', config['data']['sample']) + if 'train' in config and 'device' in config['train']: + config['train']['device'] = config['basic'].get('device', config['train']['device']) + if 'train' in config and 'debug' in config['train']: + config['train']['debug'] = config['basic'].get('debug', config['train']['debug']) + if 'cuda' in config: + config['cuda'] = config['basic'].get('cuda', config['cuda']) + if 'mode' in config: + config['mode'] = config['basic'].get('mode', config['mode']) - config['data']['type'] = args.dataset - config['model']['type'] = args.model - config['model']['rnn_units'] = args.rnn - config['model']['embed_dim'] = args.emb - config['data']['sample'] = args.sample - config['data']['input_dim'] = config['model']['input_dim'] - config['data']['output_dim'] = config['model']['output_dim'] - config['data']['batch_size'] = config['train']['batch_size'] - config['model']['num_nodes'] = config['data']['num_nodes'] - config['model']['horizon'] = config['data']['horizon'] - config['model']['default_graph'] = config['data']['default_graph'] - config['train']['device'] = args.device - config['train']['debug'] = args.debug - config['train']['log_step'] = config['log']['log_step'] - config['train']['output_dim'] = config['model']['output_dim'] - config['train']['mae_thresh'] = config['test']['mae_thresh'] - config['train']['mape_thresh'] = config['test']['mape_thresh'] - config['cuda'] = args.cuda - config['mode'] = args.mode - config['device'] = args.device - config['model']['device'] = config['device'] return config diff --git a/dataloader/PeMSDdataloader.py b/dataloader/PeMSDdataloader.py index ad6d8e5..22e0371 100755 --- a/dataloader/PeMSDdataloader.py +++ b/dataloader/PeMSDdataloader.py @@ -8,7 +8,8 @@ import h5py def get_dataloader(args, normalizer='std', single=True): - data = load_st_dataset(args['type'], args['sample']) # 加载数据 + data = load_st_dataset(args) # 加载数据 + args = args['data'] L, N, F = data.shape # 数据形状 # Step 1: data -> x,y @@ -90,7 +91,9 @@ def get_dataloader(args, normalizer='std', single=True): return train_dataloader, val_dataloader, test_dataloader, scaler -def load_st_dataset(dataset, sample): +def load_st_dataset(config): + dataset = config["basic"]["dataset"] + sample = config["data"]["sample"] # output B, N, D match dataset: case 'PEMSD3': diff --git a/dataloader/loader_selector.py b/dataloader/loader_selector.py index b7f697c..bfee98a 100755 --- a/dataloader/loader_selector.py +++ b/dataloader/loader_selector.py @@ -5,11 +5,12 @@ from dataloader.EXPdataloader import get_dataloader as EXP_loader from dataloader.cde_loader.cdeDataloader import get_dataloader as nrde_loader def get_dataloader(config, normalizer, single): - match config['model']['type']: - case 'STGNCDE': return cde_loader(config['data'], normalizer, single) - case 'STGNRDE': return nrde_loader(config['data'], normalizer, single) - case 'DCRNN': return DCRNN_loader(config['data'], normalizer, single) - case 'EXP': return EXP_loader(config['data'], normalizer, single) - case _: return normal_loader(config['data'], normalizer, single) + model_name = config["basic"]["model"] + match model_name: + case 'STGNCDE': return cde_loader(config, normalizer, single) + case 'STGNRDE': return nrde_loader(config, normalizer, single) + case 'DCRNN': return DCRNN_loader(config, normalizer, single) + case 'EXP': return EXP_loader(config, normalizer, single) + case _: return normal_loader(config, normalizer, single) diff --git a/lib/initializer.py b/lib/initializer.py index 2e3d321..ebbacd9 100755 --- a/lib/initializer.py +++ b/lib/initializer.py @@ -1,19 +1,23 @@ import torch import torch.nn as nn from model.model_selector import model_selector +from lib.loss_function import masked_mae_loss import random import numpy as np +from datetime import datetime +import os +import yaml -def init_model(args, device): +def init_model(args): + device = args["device"] model = model_selector(args).to(device) - # Initialize model parameters for p in model.parameters(): if p.dim() > 1: nn.init.xavier_uniform_(p) else: nn.init.uniform_(p) total_params = sum(p.numel() for p in model.parameters()) - print(f"Model has {total_params} parameters") + print(f"模型参数量: {total_params} ") return model def init_optimizer(model, args): @@ -38,12 +42,59 @@ def init_optimizer(model, args): return optimizer, lr_scheduler def init_seed(seed): - ''' - Disable cudnn to maximize reproducibility - ''' + """初始化种子,保证结果可复现""" torch.cuda.cudnn_enabled = False torch.backends.cudnn.deterministic = True random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) - torch.cuda.manual_seed(seed) \ No newline at end of file + torch.cuda.manual_seed(seed) + +def init_device(args): + device_name = args['basic']['device'] + if 'model' not in args or not isinstance(args['model'], dict): + args['model'] = {} # Ensure args['model'] is a dictionary + match device_name: + case 'mps': + if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): + args['device'] = 'mps' + else: + args['device'] = 'cpu' + case device if device.startswith('cuda'): + if torch.cuda.is_available(): + torch.cuda.set_device(int(device.split(':')[1])) + args['device'] = device + else: + args['device'] = 'cpu' + case _: + args['device'] = 'cpu' + args['model']['device'] = args['device'] + return args + +def init_loss(args, scaler): + device = args['basic']['device'] + args = args['train'] + match args['loss_func']: + case 'mask_mae': + return masked_mae_loss(scaler, mask_value=None) + case 'mae': + return torch.nn.L1Loss().to(device) + case 'mse': + return torch.nn.MSELoss().to(device) + case 'Huber': + return torch.nn.HuberLoss().to(device) + case _: + raise ValueError(f"Unsupported loss function: {args['loss_func']}") + +def create_logs(args): + current_time = datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + current_dir = os.path.dirname(os.path.realpath(__file__)) + + args['train']['log_dir'] = os.path.join(current_dir, 'experiments', args['basic']['dataset'], current_time) + config_filename = f"{args['basic']['dataset']}.yaml" + os.makedirs(args['train']['log_dir'], exist_ok=True) + config_content = yaml.safe_dump(args, default_flow_style=False) + destination_path = os.path.join(args['train']['log_dir'], config_filename) + # 将 args 保存为 YAML 文件 + with open(destination_path, 'w') as f: + f.write(config_content) \ No newline at end of file diff --git a/model/model_selector.py b/model/model_selector.py index 4fba6bc..01e01f2 100755 --- a/model/model_selector.py +++ b/model/model_selector.py @@ -23,30 +23,32 @@ from model.ST_SSL.ST_SSL import STSSLModel from model.STGNRDE.Make_model import make_model as make_nrde_model from model.STAWnet.STAWnet import STAWnet -def model_selector(model): - match model['type']: - case 'DDGCRN': return DDGCRN(model) - case 'TWDGCN': return TWDGCN(model) - case 'AGCRN': return AGCRN(model) - case 'NLT': return HierAttnLstm(model) - case 'STGNCDE': return make_model(model) - case 'DSANET': return DSANet(model) - case 'STGCN': return STGCNChebGraphConv(model) - case 'DCRNN': return DCRNNModel(model) - case 'ARIMA': return ARIMA(model) - case 'TCN': return TemporalConvNet(model) - case 'GWN': return gwnet(model) - case 'STFGNN': return STFGNN(model) - case 'STSGCN': return STSGCN(model) - case 'STGODE': return ODEGCN(model) - case 'PDG2SEQ': return PDG2Seq(model) - case 'STMLP': return STMLP(model) - case 'STIDGCN': return STIDGCN(model) - case 'STID': return STID(model) - case 'STAEFormer': return STAEformer(model) - case 'EXP': return EXP(model) - case 'MegaCRN': return MegaCRNModel(model) - case 'ST_SSL': return STSSLModel(model) - case 'STGNRDE': return make_nrde_model(model) - case 'STAWnet': return STAWnet(model) +def model_selector(config): + model_name = config["basic"]["model"] + model_config = config["model"] + match model_name: + case 'DDGCRN': return DDGCRN(model_config) + case 'TWDGCN': return TWDGCN(model_config) + case 'AGCRN': return AGCRN(model_config) + case 'NLT': return HierAttnLstm(model_config) + case 'STGNCDE': return make_model(model_config) + case 'DSANET': return DSANet(model_config) + case 'STGCN': return STGCNChebGraphConv(model_config) + case 'DCRNN': return DCRNNModel(model_config) + case 'ARIMA': return ARIMA(model_config) + case 'TCN': return TemporalConvNet(model_config) + case 'GWN': return gwnet(model_config) + case 'STFGNN': return STFGNN(model_config) + case 'STSGCN': return STSGCN(model_config) + case 'STGODE': return ODEGCN(model_config) + case 'PDG2SEQ': return PDG2Seq(model_config) + case 'STMLP': return STMLP(model_config) + case 'STIDGCN': return STIDGCN(model_config) + case 'STID': return STID(model_config) + case 'STAEFormer': return STAEformer(model_config) + case 'EXP': return EXP(model_config) + case 'MegaCRN': return MegaCRNModel(model_config) + case 'ST_SSL': return STSSLModel(model_config) + case 'STGNRDE': return make_nrde_model(model_config) + case 'STAWnet': return STAWnet(model_config) diff --git a/run.py b/run.py index da48f85..a4fa1b3 100755 --- a/run.py +++ b/run.py @@ -1,64 +1,21 @@ import os -# 检查数据集完整性 -from lib.Download_data import check_and_download_data - -data_complete = check_and_download_data() -assert data_complete is not None, "数据集下载失败,请重试!" - import torch from datetime import datetime + # import time from config.args_parser import parse_args -from lib.initializer import init_model, init_optimizer, init_seed -from lib.loss_function import get_loss_function - +import lib.initializer as init from dataloader.loader_selector import get_dataloader from trainer.trainer_selector import select_trainer -import yaml - def main(): args = parse_args() - - # Set device (prefer MPS on macOS, then CUDA, else CPU) - if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available() and args['device'] != 'cpu': - args['device'] = 'mps' - args['model']['device'] = args['device'] - elif torch.cuda.is_available() and args['device'] != 'cpu': - torch.cuda.set_device(int(args['device'].split(':')[1])) - args['model']['device'] = args['device'] - else: - args['device'] = 'cpu' - args['model']['device'] = args['device'] - init_seed(args['train']['seed']) - # Initialize model - model = init_model(args['model'], device=args['device']) - - - - if args['mode'] == "benchmark": - # 支持计算消耗分析,设置 mode为 benchmark - import torch.profiler as profiler - dummy_input = torch.randn((64, 12, args['model']['num_nodes'], 3), device=args['device']) - min_val = dummy_input.min(dim=-1, keepdim=True)[0] - max_val = dummy_input.max(dim=-1, keepdim=True)[0] - - dummy_input = (dummy_input - min_val) / (max_val - min_val + 1e-6) - with profiler.profile( - activities=[ - profiler.ProfilerActivity.CPU, - profiler.ProfilerActivity.CUDA - ], - with_stack=True, - profile_memory=True, - record_shapes=True - ) as prof: - out = model(dummy_input) - print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) - return 0 + args = init.init_device(args) + init.init_seed(args['train']['seed']) + model = init.init_model(args) # Load dataset train_loader, val_loader, test_loader, scaler, *extra_data = get_dataloader( @@ -67,49 +24,28 @@ def main(): single=False ) - # Initialize loss function - loss = get_loss_function(args['train'], scaler) - - # Initialize optimizer and learning rate scheduler - optimizer, lr_scheduler = init_optimizer(model, args['train']) - - # Configure log path - current_time = datetime.now().strftime('%Y-%m-%d_%H-%M-%S') - current_dir = os.path.dirname(os.path.realpath(__file__)) - args['train']['log_dir'] = os.path.join(current_dir, 'experiments', args['data']['type'], current_time) - - # 配置文件路径 - config_filename = f"{args['data']['type']}.yaml" - # config_path = os.path.join(current_dir, 'config', args['model']['type'], config_filename) - # 确保日志目录存在 - os.makedirs(args['train']['log_dir'], exist_ok=True) - - # 生成配置文件内容(将 args 转换为 YAML 格式) - config_content = yaml.safe_dump(args, default_flow_style=False) - - # 生成新的 YAML 文件名(例如:config.auto.yaml 或其他名称) - destination_path = os.path.join(args['train']['log_dir'], config_filename) - - # 将 args 保存为 YAML 文件 - with open(destination_path, 'w') as f: - f.write(config_content) + loss = init.init_loss(args, scaler) + optimizer, lr_scheduler = init.init_optimizer(model, args['train']) + init.create_logs(args) # Start training or testing trainer = select_trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args, lr_scheduler, extra_data) - match args['mode']: + match args['basic']['mode']: case 'train': trainer.train() case 'test': model.load_state_dict(torch.load( - f"./pre-trained/{args['model']['type']}/{args['data']['type']}.pth", + f"./pre-trained/{args['basic']['model']}/{args['basic']['dataset']}.pth", map_location=args['device'], weights_only=True)) - # print(f"Loaded saved model on {args['device']}") - trainer.test(model.to(args['device']), trainer.args, test_loader, scaler, trainer.logger) + trainer.test(model.to(args['basic']['device']), trainer.args, test_loader, scaler, trainer.logger) case _: - raise ValueError(f"Unsupported mode: {args['mode']}") + raise ValueError(f"Unsupported mode: {args['basic']['mode']}") if __name__ == '__main__': + from lib.Download_data import check_and_download_data + data_complete = check_and_download_data() + assert data_complete is not None, "数据集下载失败,请重试!" main() diff --git a/trainer/E32Trainer.py b/trainer/E32Trainer.py index b1bce7c..4154ba4 100644 --- a/trainer/E32Trainer.py +++ b/trainer/E32Trainer.py @@ -12,7 +12,10 @@ from lib.training_stats import TrainingStats class Trainer: def __init__(self, model, loss, optimizer, train_loader, val_loader, test_loader, - scaler, args, lr_scheduler=None): + scaler, global_config, lr_scheduler=None): + + self.device = global_config['basic']['device'] + train_config = global_config['train'] self.model = model self.loss = loss self.optimizer = optimizer @@ -20,23 +23,23 @@ class Trainer: self.val_loader = val_loader self.test_loader = test_loader self.scaler = scaler - self.args = args + self.args = train_config self.lr_scheduler = lr_scheduler self.train_per_epoch = len(train_loader) self.val_per_epoch = len(val_loader) if val_loader else 0 # Paths for saving models and logs - self.best_path = os.path.join(args['log_dir'], 'best_model.pth') - self.best_test_path = os.path.join(args['log_dir'], 'best_test_model.pth') - self.loss_figure_path = os.path.join(args['log_dir'], 'loss.png') + self.best_path = os.path.join(train_config['log_dir'], 'best_model.pth') + self.best_test_path = os.path.join(train_config['log_dir'], 'best_test_model.pth') + self.loss_figure_path = os.path.join(train_config['log_dir'], 'loss.png') # Initialize logger - if not os.path.isdir(args['log_dir']) and not args['debug']: - os.makedirs(args['log_dir'], exist_ok=True) - self.logger = get_logger(args['log_dir'], name=self.model.__class__.__name__, debug=args['debug']) - self.logger.info(f"Experiment log path in: {args['log_dir']}") + if not os.path.isdir(train_config['log_dir']) and not train_config['debug']: + os.makedirs(train_config['log_dir'], exist_ok=True) + self.logger = get_logger(train_config['log_dir'], name=self.model.__class__.__name__, debug=train_config['debug']) + self.logger.info(f"Experiment log path in: {train_config['log_dir']}") # Stats tracker - self.stats = TrainingStats(device=args['device']) + self.stats = TrainingStats(device=self.device) def _run_epoch(self, epoch, dataloader, mode): is_train = (mode == 'train') @@ -51,9 +54,9 @@ class Trainer: start_time = time.time() # unpack the new cycle_index data, target, cycle_index = batch - data = data.to(self.args['device']) - target = target.to(self.args['device']) - cycle_index = cycle_index.to(self.args['device']).long() + data = data.to(self.device) + target = target.to(self.device) + cycle_index = cycle_index.to(self.device).long() # forward if is_train: @@ -164,10 +167,13 @@ class Trainer: @staticmethod def test(model, args, data_loader, scaler, logger, path=None): + global_config = args + device = global_config['basic']['device'] + args = global_config['train'] if path: checkpoint = torch.load(path) model.load_state_dict(checkpoint['state_dict']) - model.to(args['device']) + model.to(device) model.eval() y_pred, y_true = [], [] diff --git a/trainer/Trainer.py b/trainer/Trainer.py index 013d852..b11035d 100755 --- a/trainer/Trainer.py +++ b/trainer/Trainer.py @@ -80,6 +80,8 @@ class TrainingStats: class Trainer: def __init__(self, model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args, lr_scheduler=None): + self.device = args['basic']['device'] + args = args['train'] self.model = model self.loss = loss self.optimizer = optimizer @@ -104,7 +106,7 @@ class Trainer: self.logger.info(f"Experiment log path in: {args['log_dir']}") # Stats tracker - self.stats = TrainingStats(device=args['device']) + self.stats = TrainingStats(device=self.device) def _run_epoch(self, epoch, dataloader, mode): if mode == 'train': @@ -123,7 +125,7 @@ class Trainer: start_time = time.time() label = target[..., :self.args['output_dim']] - output = self.model(data).to(self.args['device']) + output = self.model(data).to(self.device) if self.args['real_value']: output = self.scaler.inverse_transform(output) @@ -153,6 +155,12 @@ class Trainer: self.logger.info( f'{mode.capitalize()} Epoch {epoch}: average Loss: {avg_loss:.6f}, time: {time.time() - epoch_time:.2f} s') + # 输出指标 + mae, rmse, mape = all_metrics(output, label, self.args['mae_thresh'], self.args['mape_thresh']) + self.logger.info( + f"Epoch #{epoch:02d}: {mode.capitalize():<5} MAE:{mae:5.2f} | RMSE:{rmse:5.2f} | MAPE:{mape:5.2f}" + ) + # 记录内存 self.stats.record_memory_usage() @@ -232,19 +240,19 @@ class Trainer: if path: checkpoint = torch.load(path) model.load_state_dict(checkpoint['state_dict']) - model.to(args['device']) + model.to(args['basic']['device']) model.eval() y_pred, y_true = [], [] with torch.no_grad(): for data, target in data_loader: - label = target[..., :args['output_dim']] + label = target[..., :args['train']['output_dim']] output = model(data) y_pred.append(output) y_true.append(label) - if args['real_value']: + if args['train']['real_value']: y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0)) else: y_pred = torch.cat(y_pred, dim=0) @@ -252,10 +260,10 @@ class Trainer: for t in range(y_true.shape[1]): mae, rmse, mape = all_metrics(y_pred[:, t, ...], y_true[:, t, ...], - args['mae_thresh'], args['mape_thresh']) + args['train']['mae_thresh'], args['train']['mape_thresh']) logger.info(f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}") - mae, rmse, mape = all_metrics(y_pred, y_true, args['mae_thresh'], args['mape_thresh']) + mae, rmse, mape = all_metrics(y_pred, y_true, args['train']['mae_thresh'], args['train']['mape_thresh']) logger.info(f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}") @staticmethod diff --git a/trainer/trainer_selector.py b/trainer/trainer_selector.py index 5f66185..5b4b501 100755 --- a/trainer/trainer_selector.py +++ b/trainer/trainer_selector.py @@ -8,18 +8,19 @@ from trainer.E32Trainer import Trainer as EXP_Trainer def select_trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args, lr_scheduler, kwargs): - match args['model']['type']: - case "STGNCDE": return cdeTrainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args['train'], + model_name = args['basic']['model'] + match model_name: + case "STGNCDE": return cdeTrainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args, lr_scheduler, kwargs[0], None) - case "STGNRDE": return cdeTrainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args['train'], + case "STGNRDE": return cdeTrainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args, lr_scheduler, kwargs[0], None) - case 'DCRNN': return DCRNN_Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args['train'], + case 'DCRNN': return DCRNN_Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args, lr_scheduler) - case 'PDG2SEQ': return PDG2SEQ_Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args['train'], + case 'PDG2SEQ': return PDG2SEQ_Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args, lr_scheduler) case 'STMLP': return STMLP_Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args, lr_scheduler) - case 'EXP': return EXP_Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args['train'], + case 'EXP': return EXP_Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args, lr_scheduler) - case _: return Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args['train'], + case _: return Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args, lr_scheduler)