预备升级为2.0

This commit is contained in:
czzhangheng 2025-11-08 21:05:52 +08:00
parent b26e5a823c
commit 9f9abd9d1b
13 changed files with 243 additions and 193 deletions

16
.vscode/launch.json vendored Normal file
View File

@ -0,0 +1,16 @@
{
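// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387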
"version": "0.2.0",
"configurations": [
{
"name": "EXP_PEMSD8",
"type": "debugpy",
"request": "launch",
"program": "run.py",
"console": "integratedTerminal",
"args": "--config ./config/DDGCRN/PEMSD8.yaml"
}
]
}

View File

@ -1,3 +1,8 @@
basic:
dataset: "PEMSD8"
mode : "train"
device : "cuda:0"
model: "DDGCRN"
data: data:
add_day_in_week: true add_day_in_week: true
add_time_in_day: true add_time_in_day: true
@ -12,9 +17,10 @@ data:
test_ratio: 0.2 test_ratio: 0.2
tod: false tod: false
val_ratio: 0.2 val_ratio: 0.2
log: sample: 1
log_step: 2000 input_dim: 1
plot: false batch_size: 64
model: model:
cheb_order: 2 cheb_order: 2
embed_dim: 5 embed_dim: 5
@ -24,9 +30,8 @@ model:
rnn_units: 64 rnn_units: 64
use_day: true use_day: true
use_week: true use_week: true
test: num_nodes: 170
mae_thresh: None horizon: 12
mape_thresh: 0.001
train: train:
batch_size: 64 batch_size: 64
early_stop: true early_stop: true
@ -42,3 +47,9 @@ train:
real_value: true real_value: true
seed: 12 seed: 12
weight_decay: 0 weight_decay: 0
debug: false
output_dim: 1
log_step: 2000
plot: false
mae_thresh: None
mape_thresh: 0.001

View File

@ -1,3 +1,5 @@
data: data:
num_nodes: 358 num_nodes: 358
lag: 12 lag: 12

View File

@ -1,4 +1,11 @@
basic:
dataset: "PEMSD8"
mode : "train"
device : "cuda:0"
model: "EXP"
data: data:
type: "PEMSD8"
num_nodes: 170 num_nodes: 170
lag: 12 lag: 12
horizon: 12 horizon: 12
@ -10,15 +17,25 @@ data:
default_graph: True default_graph: True
add_time_in_day: True add_time_in_day: True
add_day_in_week: True add_day_in_week: True
input_dim: 1
batch_size: 64
steps_per_day: 288 steps_per_day: 288
days_per_week: 7 days_per_week: 7
sample: 1
cycle: 288
model: model:
horizon: 12
hidden_dim: 64
num_nodes: 170
time_slots: 288
embed_dim: 16
batch_size: 64 batch_size: 64
input_dim: 1 input_dim: 1
output_dim: 1 output_dim: 1
in_len: 12 in_len: 12
cycle_len: 288
train: train:
loss_func: mae loss_func: mae
@ -35,11 +52,12 @@ train:
grad_norm: False grad_norm: False
max_grad_norm: 5 max_grad_norm: 5
real_value: True real_value: True
debug: True
output_dim: 1
log_step: 2000
plot: False
test: test:
mae_thresh: null mae_thresh: null
mape_thresh: 0.0 mape_thresh: 0.0
log:
log_step: 200
plot: False

View File

@ -3,43 +3,38 @@ import yaml
def parse_args(): def parse_args():
parser = argparse.ArgumentParser(description='Model Training and Testing') parser = argparse.ArgumentParser(description='Model Training and Testing')
parser.add_argument('--dataset', default='PEMSD8', type=str) parser.add_argument('--config', type=str, required=True, help='Path to the configuration file')
parser.add_argument('--mode', default='train', type=str)
parser.add_argument('--device', default='cuda:0', type=str, help='Indices of GPUs')
parser.add_argument('--debug', default=False, type=eval)
parser.add_argument('--model', default='GWN', type=str)
parser.add_argument('--cuda', default=True, type=bool)
parser.add_argument('--sample', default=1, type=int)
parser.add_argument('--emb', default=12, type=int)
parser.add_argument('--rnn', default=64, type=int)
args = parser.parse_args() args = parser.parse_args()
# Load YAML configuration # Load YAML configuration
config_file = f'./config/{args.model}/{args.dataset}.yaml' if args.config:
with open(config_file, 'r') as file: with open(args.config, 'r') as file:
config = yaml.safe_load(file) config = yaml.safe_load(file)
else:
raise ValueError("Configuration file path must be provided using --config")
# Update configuration with command-line arguments
# Merge 'basic' configuration into the root dictionary
# config.update(config.get('basic', {}))
# Add adaptive configuration based on external commands
if 'data' in config and 'type' in config['data']:
config['data']['type'] = config['basic'].get('dataset', config['data']['type'])
if 'model' in config and 'type' in config['model']:
config['model']['type'] = config['basic'].get('model', config['model']['type'])
if 'model' in config and 'rnn_units' in config['model']:
config['model']['rnn_units'] = config['basic'].get('rnn', config['model']['rnn_units'])
if 'model' in config and 'embed_dim' in config['model']:
config['model']['embed_dim'] = config['basic'].get('emb', config['model']['embed_dim'])
if 'data' in config and 'sample' in config['data']:
config['data']['sample'] = config['basic'].get('sample', config['data']['sample'])
if 'train' in config and 'device' in config['train']:
config['train']['device'] = config['basic'].get('device', config['train']['device'])
if 'train' in config and 'debug' in config['train']:
config['train']['debug'] = config['basic'].get('debug', config['train']['debug'])
if 'cuda' in config:
config['cuda'] = config['basic'].get('cuda', config['cuda'])
if 'mode' in config:
config['mode'] = config['basic'].get('mode', config['mode'])
config['data']['type'] = args.dataset
config['model']['type'] = args.model
config['model']['rnn_units'] = args.rnn
config['model']['embed_dim'] = args.emb
config['data']['sample'] = args.sample
config['data']['input_dim'] = config['model']['input_dim']
config['data']['output_dim'] = config['model']['output_dim']
config['data']['batch_size'] = config['train']['batch_size']
config['model']['num_nodes'] = config['data']['num_nodes']
config['model']['horizon'] = config['data']['horizon']
config['model']['default_graph'] = config['data']['default_graph']
config['train']['device'] = args.device
config['train']['debug'] = args.debug
config['train']['log_step'] = config['log']['log_step']
config['train']['output_dim'] = config['model']['output_dim']
config['train']['mae_thresh'] = config['test']['mae_thresh']
config['train']['mape_thresh'] = config['test']['mape_thresh']
config['cuda'] = args.cuda
config['mode'] = args.mode
config['device'] = args.device
config['model']['device'] = config['device']
return config return config

View File

@ -8,7 +8,8 @@ import h5py
def get_dataloader(args, normalizer='std', single=True): def get_dataloader(args, normalizer='std', single=True):
data = load_st_dataset(args['type'], args['sample']) # 加载数据 data = load_st_dataset(args) # 加载数据
args = args['data']
L, N, F = data.shape # 数据形状 L, N, F = data.shape # 数据形状
# Step 1: data -> x,y # Step 1: data -> x,y
@ -90,7 +91,9 @@ def get_dataloader(args, normalizer='std', single=True):
return train_dataloader, val_dataloader, test_dataloader, scaler return train_dataloader, val_dataloader, test_dataloader, scaler
def load_st_dataset(dataset, sample): def load_st_dataset(config):
dataset = config["basic"]["dataset"]
sample = config["data"]["sample"]
# output B, N, D # output B, N, D
match dataset: match dataset:
case 'PEMSD3': case 'PEMSD3':

View File

@ -5,11 +5,12 @@ from dataloader.EXPdataloader import get_dataloader as EXP_loader
from dataloader.cde_loader.cdeDataloader import get_dataloader as nrde_loader from dataloader.cde_loader.cdeDataloader import get_dataloader as nrde_loader
def get_dataloader(config, normalizer, single): def get_dataloader(config, normalizer, single):
match config['model']['type']: model_name = config["basic"]["model"]
case 'STGNCDE': return cde_loader(config['data'], normalizer, single) match model_name:
case 'STGNRDE': return nrde_loader(config['data'], normalizer, single) case 'STGNCDE': return cde_loader(config, normalizer, single)
case 'DCRNN': return DCRNN_loader(config['data'], normalizer, single) case 'STGNRDE': return nrde_loader(config, normalizer, single)
case 'EXP': return EXP_loader(config['data'], normalizer, single) case 'DCRNN': return DCRNN_loader(config, normalizer, single)
case _: return normal_loader(config['data'], normalizer, single) case 'EXP': return EXP_loader(config, normalizer, single)
case _: return normal_loader(config, normalizer, single)

View File

@ -1,19 +1,23 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
from model.model_selector import model_selector from model.model_selector import model_selector
from lib.loss_function import masked_mae_loss
import random import random
import numpy as np import numpy as np
from datetime import datetime
import os
import yaml
def init_model(args, device): def init_model(args):
device = args["device"]
model = model_selector(args).to(device) model = model_selector(args).to(device)
# Initialize model parameters
for p in model.parameters(): for p in model.parameters():
if p.dim() > 1: if p.dim() > 1:
nn.init.xavier_uniform_(p) nn.init.xavier_uniform_(p)
else: else:
nn.init.uniform_(p) nn.init.uniform_(p)
total_params = sum(p.numel() for p in model.parameters()) total_params = sum(p.numel() for p in model.parameters())
print(f"Model has {total_params} parameters") print(f"模型参数量: {total_params} ")
return model return model
def init_optimizer(model, args): def init_optimizer(model, args):
@ -38,12 +42,59 @@ def init_optimizer(model, args):
return optimizer, lr_scheduler return optimizer, lr_scheduler
def init_seed(seed): def init_seed(seed):
''' """初始化种子,保证结果可复现"""
Disable cudnn to maximize reproducibility
'''
torch.cuda.cudnn_enabled = False torch.cuda.cudnn_enabled = False
torch.backends.cudnn.deterministic = True torch.backends.cudnn.deterministic = True
random.seed(seed) random.seed(seed)
np.random.seed(seed) np.random.seed(seed)
torch.manual_seed(seed) torch.manual_seed(seed)
torch.cuda.manual_seed(seed) torch.cuda.manual_seed(seed)
def init_device(args):
    """Resolve the requested compute device and record it in the config.

    The requested device is read from ``args['basic']['device']`` and
    resolved as follows:

    * ``'mps'``: used when ``torch.backends.mps`` exists and reports itself
      available, otherwise ``'cpu'``.
    * ``'cuda'`` or ``'cuda:<index>'``: used when ``torch.cuda.is_available()``
      is true, otherwise ``'cpu'``. When an index is given it is also made
      the current CUDA device.
    * anything else (including non-string values): ``'cpu'``.

    Args:
        args: Global configuration dict. Must contain ``args['basic']['device']``.
            ``args['model']`` is replaced by an empty dict when it is missing
            or is not a dict.

    Returns:
        The same ``args`` dict, mutated in place, with the resolved device
        stored under ``args['device']``, ``args['model']['device']`` and
        ``args['basic']['device']``.

    Raises:
        ValueError: If CUDA is available and the part after ``'cuda:'`` is
            not an integer.
    """
    requested = args['basic']['device']

    # Guarantee a dict so the resolved device can be stored for the model.
    if 'model' not in args or not isinstance(args['model'], dict):
        args['model'] = {}

    resolved = 'cpu'
    if isinstance(requested, str):
        if requested == 'mps':
            if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
                resolved = 'mps'
        elif requested == 'cuda' or requested.startswith('cuda:'):
            if torch.cuda.is_available():
                # A bare 'cuda' has no index; indexing the split result would
                # raise IndexError, so only select a device when one is named.
                _, _, index = requested.partition(':')
                if index:
                    torch.cuda.set_device(int(index))
                resolved = requested

    args['device'] = resolved
    args['model']['device'] = resolved
    # Consumers that read args['basic']['device'] directly (e.g. the trainers)
    # must see the fallback too, otherwise tensors are sent to a device the
    # model is not on.
    args['basic']['device'] = resolved
    return args
def init_loss(args, scaler):
device = args['basic']['device']
args = args['train']
match args['loss_func']:
case 'mask_mae':
return masked_mae_loss(scaler, mask_value=None)
case 'mae':
return torch.nn.L1Loss().to(device)
case 'mse':
return torch.nn.MSELoss().to(device)
case 'Huber':
return torch.nn.HuberLoss().to(device)
case _:
raise ValueError(f"Unsupported loss function: {args['loss_func']}")
def create_logs(args):
    """Create a timestamped experiment directory and snapshot the config in it.

    The directory is ``<dir of this module>/experiments/<dataset>/<timestamp>``,
    where ``<dataset>`` is ``args['basic']['dataset']`` and the timestamp is
    formatted as ``%Y-%m-%d_%H-%M-%S``. Note that it is anchored at this
    module's location, not at the current working directory.

    Args:
        args: Global configuration dict. ``args['train']['log_dir']`` is set
            to the new directory path (mutated in place).

    Side effects:
        Creates the directory if needed and writes ``<dataset>.yaml`` into
        it, containing ``args`` serialized with ``yaml.safe_dump``.
    """
    dataset_name = args['basic']['dataset']
    timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    module_dir = os.path.dirname(os.path.realpath(__file__))

    log_dir = os.path.join(module_dir, 'experiments', dataset_name, timestamp)
    args['train']['log_dir'] = log_dir
    os.makedirs(log_dir, exist_ok=True)

    # Serialize only after log_dir is stored so the snapshot includes it.
    snapshot = yaml.safe_dump(args, default_flow_style=False)

    # Save args as a YAML file.
    with open(os.path.join(log_dir, f"{dataset_name}.yaml"), 'w') as f:
        f.write(snapshot)

View File

@ -23,30 +23,32 @@ from model.ST_SSL.ST_SSL import STSSLModel
from model.STGNRDE.Make_model import make_model as make_nrde_model from model.STGNRDE.Make_model import make_model as make_nrde_model
from model.STAWnet.STAWnet import STAWnet from model.STAWnet.STAWnet import STAWnet
def model_selector(model): def model_selector(config):
match model['type']: model_name = config["basic"]["model"]
case 'DDGCRN': return DDGCRN(model) model_config = config["model"]
case 'TWDGCN': return TWDGCN(model) match model_name:
case 'AGCRN': return AGCRN(model) case 'DDGCRN': return DDGCRN(model_config)
case 'NLT': return HierAttnLstm(model) case 'TWDGCN': return TWDGCN(model_config)
case 'STGNCDE': return make_model(model) case 'AGCRN': return AGCRN(model_config)
case 'DSANET': return DSANet(model) case 'NLT': return HierAttnLstm(model_config)
case 'STGCN': return STGCNChebGraphConv(model) case 'STGNCDE': return make_model(model_config)
case 'DCRNN': return DCRNNModel(model) case 'DSANET': return DSANet(model_config)
case 'ARIMA': return ARIMA(model) case 'STGCN': return STGCNChebGraphConv(model_config)
case 'TCN': return TemporalConvNet(model) case 'DCRNN': return DCRNNModel(model_config)
case 'GWN': return gwnet(model) case 'ARIMA': return ARIMA(model_config)
case 'STFGNN': return STFGNN(model) case 'TCN': return TemporalConvNet(model_config)
case 'STSGCN': return STSGCN(model) case 'GWN': return gwnet(model_config)
case 'STGODE': return ODEGCN(model) case 'STFGNN': return STFGNN(model_config)
case 'PDG2SEQ': return PDG2Seq(model) case 'STSGCN': return STSGCN(model_config)
case 'STMLP': return STMLP(model) case 'STGODE': return ODEGCN(model_config)
case 'STIDGCN': return STIDGCN(model) case 'PDG2SEQ': return PDG2Seq(model_config)
case 'STID': return STID(model) case 'STMLP': return STMLP(model_config)
case 'STAEFormer': return STAEformer(model) case 'STIDGCN': return STIDGCN(model_config)
case 'EXP': return EXP(model) case 'STID': return STID(model_config)
case 'MegaCRN': return MegaCRNModel(model) case 'STAEFormer': return STAEformer(model_config)
case 'ST_SSL': return STSSLModel(model) case 'EXP': return EXP(model_config)
case 'STGNRDE': return make_nrde_model(model) case 'MegaCRN': return MegaCRNModel(model_config)
case 'STAWnet': return STAWnet(model) case 'ST_SSL': return STSSLModel(model_config)
case 'STGNRDE': return make_nrde_model(model_config)
case 'STAWnet': return STAWnet(model_config)

94
run.py
View File

@ -1,64 +1,21 @@
import os import os
# 检查数据集完整性
from lib.Download_data import check_and_download_data
data_complete = check_and_download_data()
assert data_complete is not None, "数据集下载失败,请重试!"
import torch import torch
from datetime import datetime from datetime import datetime
# import time # import time
from config.args_parser import parse_args from config.args_parser import parse_args
from lib.initializer import init_model, init_optimizer, init_seed import lib.initializer as init
from lib.loss_function import get_loss_function
from dataloader.loader_selector import get_dataloader from dataloader.loader_selector import get_dataloader
from trainer.trainer_selector import select_trainer from trainer.trainer_selector import select_trainer
import yaml
def main(): def main():
args = parse_args() args = parse_args()
args = init.init_device(args)
# Set device (prefer MPS on macOS, then CUDA, else CPU) init.init_seed(args['train']['seed'])
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available() and args['device'] != 'cpu': model = init.init_model(args)
args['device'] = 'mps'
args['model']['device'] = args['device']
elif torch.cuda.is_available() and args['device'] != 'cpu':
torch.cuda.set_device(int(args['device'].split(':')[1]))
args['model']['device'] = args['device']
else:
args['device'] = 'cpu'
args['model']['device'] = args['device']
init_seed(args['train']['seed'])
# Initialize model
model = init_model(args['model'], device=args['device'])
if args['mode'] == "benchmark":
# 支持计算消耗分析,设置 mode为 benchmark
import torch.profiler as profiler
dummy_input = torch.randn((64, 12, args['model']['num_nodes'], 3), device=args['device'])
min_val = dummy_input.min(dim=-1, keepdim=True)[0]
max_val = dummy_input.max(dim=-1, keepdim=True)[0]
dummy_input = (dummy_input - min_val) / (max_val - min_val + 1e-6)
with profiler.profile(
activities=[
profiler.ProfilerActivity.CPU,
profiler.ProfilerActivity.CUDA
],
with_stack=True,
profile_memory=True,
record_shapes=True
) as prof:
out = model(dummy_input)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
return 0
# Load dataset # Load dataset
train_loader, val_loader, test_loader, scaler, *extra_data = get_dataloader( train_loader, val_loader, test_loader, scaler, *extra_data = get_dataloader(
@ -67,49 +24,28 @@ def main():
single=False single=False
) )
# Initialize loss function loss = init.init_loss(args, scaler)
loss = get_loss_function(args['train'], scaler) optimizer, lr_scheduler = init.init_optimizer(model, args['train'])
init.create_logs(args)
# Initialize optimizer and learning rate scheduler
optimizer, lr_scheduler = init_optimizer(model, args['train'])
# Configure log path
current_time = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
current_dir = os.path.dirname(os.path.realpath(__file__))
args['train']['log_dir'] = os.path.join(current_dir, 'experiments', args['data']['type'], current_time)
# 配置文件路径
config_filename = f"{args['data']['type']}.yaml"
# config_path = os.path.join(current_dir, 'config', args['model']['type'], config_filename)
# 确保日志目录存在
os.makedirs(args['train']['log_dir'], exist_ok=True)
# 生成配置文件内容(将 args 转换为 YAML 格式)
config_content = yaml.safe_dump(args, default_flow_style=False)
# 生成新的 YAML 文件名例如config.auto.yaml 或其他名称)
destination_path = os.path.join(args['train']['log_dir'], config_filename)
# 将 args 保存为 YAML 文件
with open(destination_path, 'w') as f:
f.write(config_content)
# Start training or testing # Start training or testing
trainer = select_trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args, trainer = select_trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args,
lr_scheduler, extra_data) lr_scheduler, extra_data)
match args['mode']: match args['basic']['mode']:
case 'train': case 'train':
trainer.train() trainer.train()
case 'test': case 'test':
model.load_state_dict(torch.load( model.load_state_dict(torch.load(
f"./pre-trained/{args['model']['type']}/{args['data']['type']}.pth", f"./pre-trained/{args['basic']['model']}/{args['basic']['dataset']}.pth",
map_location=args['device'], weights_only=True)) map_location=args['device'], weights_only=True))
# print(f"Loaded saved model on {args['device']}") trainer.test(model.to(args['basic']['device']), trainer.args, test_loader, scaler, trainer.logger)
trainer.test(model.to(args['device']), trainer.args, test_loader, scaler, trainer.logger)
case _: case _:
raise ValueError(f"Unsupported mode: {args['mode']}") raise ValueError(f"Unsupported mode: {args['basic']['mode']}")
if __name__ == '__main__': if __name__ == '__main__':
from lib.Download_data import check_and_download_data
data_complete = check_and_download_data()
assert data_complete is not None, "数据集下载失败,请重试!"
main() main()

View File

@ -12,7 +12,10 @@ from lib.training_stats import TrainingStats
class Trainer: class Trainer:
def __init__(self, model, loss, optimizer, train_loader, val_loader, test_loader, def __init__(self, model, loss, optimizer, train_loader, val_loader, test_loader,
scaler, args, lr_scheduler=None): scaler, global_config, lr_scheduler=None):
self.device = global_config['basic']['device']
train_config = global_config['train']
self.model = model self.model = model
self.loss = loss self.loss = loss
self.optimizer = optimizer self.optimizer = optimizer
@ -20,23 +23,23 @@ class Trainer:
self.val_loader = val_loader self.val_loader = val_loader
self.test_loader = test_loader self.test_loader = test_loader
self.scaler = scaler self.scaler = scaler
self.args = args self.args = train_config
self.lr_scheduler = lr_scheduler self.lr_scheduler = lr_scheduler
self.train_per_epoch = len(train_loader) self.train_per_epoch = len(train_loader)
self.val_per_epoch = len(val_loader) if val_loader else 0 self.val_per_epoch = len(val_loader) if val_loader else 0
# Paths for saving models and logs # Paths for saving models and logs
self.best_path = os.path.join(args['log_dir'], 'best_model.pth') self.best_path = os.path.join(train_config['log_dir'], 'best_model.pth')
self.best_test_path = os.path.join(args['log_dir'], 'best_test_model.pth') self.best_test_path = os.path.join(train_config['log_dir'], 'best_test_model.pth')
self.loss_figure_path = os.path.join(args['log_dir'], 'loss.png') self.loss_figure_path = os.path.join(train_config['log_dir'], 'loss.png')
# Initialize logger # Initialize logger
if not os.path.isdir(args['log_dir']) and not args['debug']: if not os.path.isdir(train_config['log_dir']) and not train_config['debug']:
os.makedirs(args['log_dir'], exist_ok=True) os.makedirs(train_config['log_dir'], exist_ok=True)
self.logger = get_logger(args['log_dir'], name=self.model.__class__.__name__, debug=args['debug']) self.logger = get_logger(train_config['log_dir'], name=self.model.__class__.__name__, debug=train_config['debug'])
self.logger.info(f"Experiment log path in: {args['log_dir']}") self.logger.info(f"Experiment log path in: {train_config['log_dir']}")
# Stats tracker # Stats tracker
self.stats = TrainingStats(device=args['device']) self.stats = TrainingStats(device=self.device)
def _run_epoch(self, epoch, dataloader, mode): def _run_epoch(self, epoch, dataloader, mode):
is_train = (mode == 'train') is_train = (mode == 'train')
@ -51,9 +54,9 @@ class Trainer:
start_time = time.time() start_time = time.time()
# unpack the new cycle_index # unpack the new cycle_index
data, target, cycle_index = batch data, target, cycle_index = batch
data = data.to(self.args['device']) data = data.to(self.device)
target = target.to(self.args['device']) target = target.to(self.device)
cycle_index = cycle_index.to(self.args['device']).long() cycle_index = cycle_index.to(self.device).long()
# forward # forward
if is_train: if is_train:
@ -164,10 +167,13 @@ class Trainer:
@staticmethod @staticmethod
def test(model, args, data_loader, scaler, logger, path=None): def test(model, args, data_loader, scaler, logger, path=None):
global_config = args
device = global_config['basic']['device']
args = global_config['train']
if path: if path:
checkpoint = torch.load(path) checkpoint = torch.load(path)
model.load_state_dict(checkpoint['state_dict']) model.load_state_dict(checkpoint['state_dict'])
model.to(args['device']) model.to(device)
model.eval() model.eval()
y_pred, y_true = [], [] y_pred, y_true = [], []

View File

@ -80,6 +80,8 @@ class TrainingStats:
class Trainer: class Trainer:
def __init__(self, model, loss, optimizer, train_loader, val_loader, test_loader, def __init__(self, model, loss, optimizer, train_loader, val_loader, test_loader,
scaler, args, lr_scheduler=None): scaler, args, lr_scheduler=None):
self.device = args['basic']['device']
args = args['train']
self.model = model self.model = model
self.loss = loss self.loss = loss
self.optimizer = optimizer self.optimizer = optimizer
@ -104,7 +106,7 @@ class Trainer:
self.logger.info(f"Experiment log path in: {args['log_dir']}") self.logger.info(f"Experiment log path in: {args['log_dir']}")
# Stats tracker # Stats tracker
self.stats = TrainingStats(device=args['device']) self.stats = TrainingStats(device=self.device)
def _run_epoch(self, epoch, dataloader, mode): def _run_epoch(self, epoch, dataloader, mode):
if mode == 'train': if mode == 'train':
@ -123,7 +125,7 @@ class Trainer:
start_time = time.time() start_time = time.time()
label = target[..., :self.args['output_dim']] label = target[..., :self.args['output_dim']]
output = self.model(data).to(self.args['device']) output = self.model(data).to(self.device)
if self.args['real_value']: if self.args['real_value']:
output = self.scaler.inverse_transform(output) output = self.scaler.inverse_transform(output)
@ -153,6 +155,12 @@ class Trainer:
self.logger.info( self.logger.info(
f'{mode.capitalize()} Epoch {epoch}: average Loss: {avg_loss:.6f}, time: {time.time() - epoch_time:.2f} s') f'{mode.capitalize()} Epoch {epoch}: average Loss: {avg_loss:.6f}, time: {time.time() - epoch_time:.2f} s')
# 输出指标
mae, rmse, mape = all_metrics(output, label, self.args['mae_thresh'], self.args['mape_thresh'])
self.logger.info(
f"Epoch #{epoch:02d}: {mode.capitalize():<5} MAE:{mae:5.2f} | RMSE:{rmse:5.2f} | MAPE:{mape:5.2f}"
)
# 记录内存 # 记录内存
self.stats.record_memory_usage() self.stats.record_memory_usage()
@ -232,19 +240,19 @@ class Trainer:
if path: if path:
checkpoint = torch.load(path) checkpoint = torch.load(path)
model.load_state_dict(checkpoint['state_dict']) model.load_state_dict(checkpoint['state_dict'])
model.to(args['device']) model.to(args['basic']['device'])
model.eval() model.eval()
y_pred, y_true = [], [] y_pred, y_true = [], []
with torch.no_grad(): with torch.no_grad():
for data, target in data_loader: for data, target in data_loader:
label = target[..., :args['output_dim']] label = target[..., :args['train']['output_dim']]
output = model(data) output = model(data)
y_pred.append(output) y_pred.append(output)
y_true.append(label) y_true.append(label)
if args['real_value']: if args['train']['real_value']:
y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0)) y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0))
else: else:
y_pred = torch.cat(y_pred, dim=0) y_pred = torch.cat(y_pred, dim=0)
@ -252,10 +260,10 @@ class Trainer:
for t in range(y_true.shape[1]): for t in range(y_true.shape[1]):
mae, rmse, mape = all_metrics(y_pred[:, t, ...], y_true[:, t, ...], mae, rmse, mape = all_metrics(y_pred[:, t, ...], y_true[:, t, ...],
args['mae_thresh'], args['mape_thresh']) args['train']['mae_thresh'], args['train']['mape_thresh'])
logger.info(f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}") logger.info(f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}")
mae, rmse, mape = all_metrics(y_pred, y_true, args['mae_thresh'], args['mape_thresh']) mae, rmse, mape = all_metrics(y_pred, y_true, args['train']['mae_thresh'], args['train']['mape_thresh'])
logger.info(f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}") logger.info(f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}")
@staticmethod @staticmethod

View File

@ -8,18 +8,19 @@ from trainer.E32Trainer import Trainer as EXP_Trainer
def select_trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args, def select_trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args,
lr_scheduler, kwargs): lr_scheduler, kwargs):
match args['model']['type']: model_name = args['basic']['model']
case "STGNCDE": return cdeTrainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args['train'], match model_name:
case "STGNCDE": return cdeTrainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args,
lr_scheduler, kwargs[0], None) lr_scheduler, kwargs[0], None)
case "STGNRDE": return cdeTrainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args['train'], case "STGNRDE": return cdeTrainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args,
lr_scheduler, kwargs[0], None) lr_scheduler, kwargs[0], None)
case 'DCRNN': return DCRNN_Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args['train'], case 'DCRNN': return DCRNN_Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args,
lr_scheduler) lr_scheduler)
case 'PDG2SEQ': return PDG2SEQ_Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args['train'], case 'PDG2SEQ': return PDG2SEQ_Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args,
lr_scheduler) lr_scheduler)
case 'STMLP': return STMLP_Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args, case 'STMLP': return STMLP_Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args,
lr_scheduler) lr_scheduler)
case 'EXP': return EXP_Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args['train'], case 'EXP': return EXP_Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args,
lr_scheduler) lr_scheduler)
case _: return Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args['train'], case _: return Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args,
lr_scheduler) lr_scheduler)