Prepare upgrade to 2.0

parent b26e5a823c
commit 9f9abd9d1b
@@ -0,0 +1,16 @@
{
    // Use IntelliSense to learn about possible attributes.
    // Hover to view descriptions of existing attributes.
    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
    "version": "0.2.0",
    "configurations": [
        {
            "name": "EXP_PEMSD8",
            "type": "debugpy",
            "request": "launch",
            "program": "run.py",
            "console": "integratedTerminal",
            "args": "--config ./config/DDGCRN/PEMSD8.yaml"
        }
    ]
}
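Note: this VS Code debug profile is equivalent to launching the entry point directly, assuming the working directory is the repository root:

python run.py --config ./config/DDGCRN/PEMSD8.yaml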
@@ -1,3 +1,8 @@
basic:
  dataset: "PEMSD8"
  mode: "train"
  device: "cuda:0"
  model: "DDGCRN"
data:
  add_day_in_week: true
  add_time_in_day: true
@@ -12,9 +17,10 @@ data:
  test_ratio: 0.2
  tod: false
  val_ratio: 0.2
log:
  log_step: 2000
  plot: false
  sample: 1
  input_dim: 1
  batch_size: 64

model:
  cheb_order: 2
  embed_dim: 5
@@ -24,9 +30,8 @@ model:
  rnn_units: 64
  use_day: true
  use_week: true
test:
  mae_thresh: None
  mape_thresh: 0.001
  num_nodes: 170
  horizon: 12
train:
  batch_size: 64
  early_stop: true
@@ -42,3 +47,9 @@ train:
  real_value: true
  seed: 12
  weight_decay: 0
  debug: false
  output_dim: 1
  log_step: 2000
  plot: false
  mae_thresh: None
  mape_thresh: 0.001
@@ -1,3 +1,5 @@


data:
  num_nodes: 358
  lag: 12
@@ -1,4 +1,11 @@
basic:
  dataset: "PEMSD8"
  mode: "train"
  device: "cuda:0"
  model: "EXP"

data:
  type: "PEMSD8"
  num_nodes: 170
  lag: 12
  horizon: 12
@@ -10,15 +17,25 @@ data:
  default_graph: True
  add_time_in_day: True
  add_day_in_week: True
  input_dim: 1
  batch_size: 64
  steps_per_day: 288
  days_per_week: 7
  sample: 1
  cycle: 288


model:
  horizon: 12
  hidden_dim: 64
  num_nodes: 170
  time_slots: 288
  embed_dim: 16
  batch_size: 64
  input_dim: 1
  output_dim: 1
  in_len: 12

  cycle_len: 288

train:
  loss_func: mae
@@ -35,11 +52,12 @@ train:
  grad_norm: False
  max_grad_norm: 5
  real_value: True
  debug: True
  output_dim: 1
  log_step: 2000
  plot: False

test:
  mae_thresh: null
  mape_thresh: 0.0

log:
  log_step: 200
  plot: False
@@ -3,43 +3,38 @@ import yaml

def parse_args():
    parser = argparse.ArgumentParser(description='Model Training and Testing')
    parser.add_argument('--dataset', default='PEMSD8', type=str)
    parser.add_argument('--mode', default='train', type=str)
    parser.add_argument('--device', default='cuda:0', type=str, help='Indices of GPUs')
    parser.add_argument('--debug', default=False, type=eval)
    parser.add_argument('--model', default='GWN', type=str)
    parser.add_argument('--cuda', default=True, type=bool)
    parser.add_argument('--sample', default=1, type=int)
    parser.add_argument('--emb', default=12, type=int)
    parser.add_argument('--rnn', default=64, type=int)


    parser.add_argument('--config', type=str, required=True, help='Path to the configuration file')
    args = parser.parse_args()

    # Load YAML configuration
    config_file = f'./config/{args.model}/{args.dataset}.yaml'
    with open(config_file, 'r') as file:
    if args.config:
        with open(args.config, 'r') as file:
            config = yaml.safe_load(file)
    else:
        raise ValueError("Configuration file path must be provided using --config")

    # Update configuration with command-line arguments
    # Merge 'basic' configuration into the root dictionary
    # config.update(config.get('basic', {}))

    # Add adaptive configuration based on external commands
    if 'data' in config and 'type' in config['data']:
        config['data']['type'] = config['basic'].get('dataset', config['data']['type'])
    if 'model' in config and 'type' in config['model']:
        config['model']['type'] = config['basic'].get('model', config['model']['type'])
    if 'model' in config and 'rnn_units' in config['model']:
        config['model']['rnn_units'] = config['basic'].get('rnn', config['model']['rnn_units'])
    if 'model' in config and 'embed_dim' in config['model']:
        config['model']['embed_dim'] = config['basic'].get('emb', config['model']['embed_dim'])
    if 'data' in config and 'sample' in config['data']:
        config['data']['sample'] = config['basic'].get('sample', config['data']['sample'])
    if 'train' in config and 'device' in config['train']:
        config['train']['device'] = config['basic'].get('device', config['train']['device'])
    if 'train' in config and 'debug' in config['train']:
        config['train']['debug'] = config['basic'].get('debug', config['train']['debug'])
    if 'cuda' in config:
        config['cuda'] = config['basic'].get('cuda', config['cuda'])
    if 'mode' in config:
        config['mode'] = config['basic'].get('mode', config['mode'])

    config['data']['type'] = args.dataset
    config['model']['type'] = args.model
    config['model']['rnn_units'] = args.rnn
    config['model']['embed_dim'] = args.emb
    config['data']['sample'] = args.sample
    config['data']['input_dim'] = config['model']['input_dim']
    config['data']['output_dim'] = config['model']['output_dim']
    config['data']['batch_size'] = config['train']['batch_size']
    config['model']['num_nodes'] = config['data']['num_nodes']
    config['model']['horizon'] = config['data']['horizon']
    config['model']['default_graph'] = config['data']['default_graph']
    config['train']['device'] = args.device
    config['train']['debug'] = args.debug
    config['train']['log_step'] = config['log']['log_step']
    config['train']['output_dim'] = config['model']['output_dim']
    config['train']['mae_thresh'] = config['test']['mae_thresh']
    config['train']['mape_thresh'] = config['test']['mape_thresh']
    config['cuda'] = args.cuda
    config['mode'] = args.mode
    config['device'] = args.device
    config['model']['device'] = config['device']
    return config
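The new precedence is: the YAML named by --config is the single source of truth, and values in its 'basic' section override the matching keys duplicated in other sections. The explicit if-chains above could be collapsed into a table-driven loop; a minimal sketch of that pattern, not the committed code (the top-level 'cuda'/'mode' overrides are elided):

import yaml

# Maps (section, key) in the config to the 'basic' key that overrides it.
OVERRIDES = {
    ('data', 'type'): 'dataset',
    ('model', 'type'): 'model',
    ('model', 'rnn_units'): 'rnn',
    ('model', 'embed_dim'): 'emb',
    ('data', 'sample'): 'sample',
    ('train', 'device'): 'device',
    ('train', 'debug'): 'debug',
}

def apply_basic_overrides(config):
    basic = config.get('basic', {})
    for (section, key), basic_key in OVERRIDES.items():
        # Only override keys that exist on both sides, like the if-chains above.
        if section in config and key in config[section] and basic_key in basic:
            config[section][key] = basic[basic_key]
    return config

with open('./config/DDGCRN/PEMSD8.yaml') as f:
    config = apply_basic_overrides(yaml.safe_load(f))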
@@ -8,7 +8,8 @@ import h5py

def get_dataloader(args, normalizer='std', single=True):
    data = load_st_dataset(args['type'], args['sample'])  # load the data
    data = load_st_dataset(args)  # load the data
    args = args['data']
    L, N, F = data.shape  # data shape

    # Step 1: data -> x,y
@@ -90,7 +91,9 @@ def get_dataloader(args, normalizer='std', single=True):

    return train_dataloader, val_dataloader, test_dataloader, scaler

def load_st_dataset(dataset, sample):
def load_st_dataset(config):
    dataset = config["basic"]["dataset"]
    sample = config["data"]["sample"]
    # output B, N, D
    match dataset:
        case 'PEMSD3':
@@ -5,11 +5,12 @@ from dataloader.EXPdataloader import get_dataloader as EXP_loader
from dataloader.cde_loader.cdeDataloader import get_dataloader as nrde_loader

def get_dataloader(config, normalizer, single):
    match config['model']['type']:
        case 'STGNCDE': return cde_loader(config['data'], normalizer, single)
        case 'STGNRDE': return nrde_loader(config['data'], normalizer, single)
        case 'DCRNN': return DCRNN_loader(config['data'], normalizer, single)
        case 'EXP': return EXP_loader(config['data'], normalizer, single)
        case _: return normal_loader(config['data'], normalizer, single)
    model_name = config["basic"]["model"]
    match model_name:
        case 'STGNCDE': return cde_loader(config, normalizer, single)
        case 'STGNRDE': return nrde_loader(config, normalizer, single)
        case 'DCRNN': return DCRNN_loader(config, normalizer, single)
        case 'EXP': return EXP_loader(config, normalizer, single)
        case _: return normal_loader(config, normalizer, single)
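Call sites now hand the selector the entire config instead of pre-slicing config['data']; each concrete loader pulls the sections it needs itself. A sketch of the call, where the normalizer value is an assumption (elsewhere in the repo the default is 'std'):

from dataloader.loader_selector import get_dataloader

# config is the full dict returned by parse_args(); *extra_data captures
# anything a specific loader returns beyond the scaler.
train_loader, val_loader, test_loader, scaler, *extra_data = get_dataloader(
    config, normalizer='std', single=False)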
@@ -1,19 +1,23 @@
import torch
import torch.nn as nn
from model.model_selector import model_selector
from lib.loss_function import masked_mae_loss
import random
import numpy as np
from datetime import datetime
import os
import yaml

def init_model(args, device):
def init_model(args):
    device = args["device"]
    model = model_selector(args).to(device)
    # Initialize model parameters
    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
        else:
            nn.init.uniform_(p)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Model has {total_params} parameters")
    print(f"Model parameter count: {total_params}")
    return model

def init_optimizer(model, args):
@@ -38,12 +42,59 @@ def init_optimizer(model, args):
    return optimizer, lr_scheduler

def init_seed(seed):
    '''
    Disable cudnn to maximize reproducibility
    '''
    """Initialize random seeds so that results are reproducible."""
    torch.cuda.cudnn_enabled = False
    torch.backends.cudnn.deterministic = True
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

def init_device(args):
    device_name = args['basic']['device']
    if 'model' not in args or not isinstance(args['model'], dict):
        args['model'] = {}  # Ensure args['model'] is a dictionary
    match device_name:
        case 'mps':
            if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
                args['device'] = 'mps'
            else:
                args['device'] = 'cpu'
        case device if device.startswith('cuda'):
            if torch.cuda.is_available():
                torch.cuda.set_device(int(device.split(':')[1]))
                args['device'] = device
            else:
                args['device'] = 'cpu'
        case _:
            args['device'] = 'cpu'
    args['model']['device'] = args['device']
    return args

def init_loss(args, scaler):
    device = args['basic']['device']
    args = args['train']
    match args['loss_func']:
        case 'mask_mae':
            return masked_mae_loss(scaler, mask_value=None)
        case 'mae':
            return torch.nn.L1Loss().to(device)
        case 'mse':
            return torch.nn.MSELoss().to(device)
        case 'Huber':
            return torch.nn.HuberLoss().to(device)
        case _:
            raise ValueError(f"Unsupported loss function: {args['loss_func']}")

def create_logs(args):
    current_time = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    current_dir = os.path.dirname(os.path.realpath(__file__))

    args['train']['log_dir'] = os.path.join(current_dir, 'experiments', args['basic']['dataset'], current_time)
    config_filename = f"{args['basic']['dataset']}.yaml"
    os.makedirs(args['train']['log_dir'], exist_ok=True)
    config_content = yaml.safe_dump(args, default_flow_style=False)
    destination_path = os.path.join(args['train']['log_dir'], config_filename)
    # Save args as a YAML file
    with open(destination_path, 'w') as f:
        f.write(config_content)
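Taken together, the initializer now owns the whole setup sequence. A minimal usage sketch, assuming a parsed config dict and a scaler from the dataloader call shown earlier:

import lib.initializer as init
from config.args_parser import parse_args

config = parse_args()                    # full dict with 'basic', 'data', 'model', 'train', ...
config = init.init_device(config)        # resolves 'cuda:N' / 'mps' / 'cpu' into config['device']
init.init_seed(config['train']['seed'])  # reproducible runs
model = init.init_model(config)          # builds the model and moves it to config['device']
# 'scaler' comes from get_dataloader(...); init_loss reads 'basic' for the
# device and 'train' for the loss function name.
loss = init.init_loss(config, scaler)
optimizer, lr_scheduler = init.init_optimizer(model, config['train'])
init.create_logs(config)                 # writes the resolved config into the run's log dir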
@@ -23,30 +23,32 @@ from model.ST_SSL.ST_SSL import STSSLModel
from model.STGNRDE.Make_model import make_model as make_nrde_model
from model.STAWnet.STAWnet import STAWnet

def model_selector(model):
    match model['type']:
        case 'DDGCRN': return DDGCRN(model)
        case 'TWDGCN': return TWDGCN(model)
        case 'AGCRN': return AGCRN(model)
        case 'NLT': return HierAttnLstm(model)
        case 'STGNCDE': return make_model(model)
        case 'DSANET': return DSANet(model)
        case 'STGCN': return STGCNChebGraphConv(model)
        case 'DCRNN': return DCRNNModel(model)
        case 'ARIMA': return ARIMA(model)
        case 'TCN': return TemporalConvNet(model)
        case 'GWN': return gwnet(model)
        case 'STFGNN': return STFGNN(model)
        case 'STSGCN': return STSGCN(model)
        case 'STGODE': return ODEGCN(model)
        case 'PDG2SEQ': return PDG2Seq(model)
        case 'STMLP': return STMLP(model)
        case 'STIDGCN': return STIDGCN(model)
        case 'STID': return STID(model)
        case 'STAEFormer': return STAEformer(model)
        case 'EXP': return EXP(model)
        case 'MegaCRN': return MegaCRNModel(model)
        case 'ST_SSL': return STSSLModel(model)
        case 'STGNRDE': return make_nrde_model(model)
        case 'STAWnet': return STAWnet(model)
def model_selector(config):
    model_name = config["basic"]["model"]
    model_config = config["model"]
    match model_name:
        case 'DDGCRN': return DDGCRN(model_config)
        case 'TWDGCN': return TWDGCN(model_config)
        case 'AGCRN': return AGCRN(model_config)
        case 'NLT': return HierAttnLstm(model_config)
        case 'STGNCDE': return make_model(model_config)
        case 'DSANET': return DSANet(model_config)
        case 'STGCN': return STGCNChebGraphConv(model_config)
        case 'DCRNN': return DCRNNModel(model_config)
        case 'ARIMA': return ARIMA(model_config)
        case 'TCN': return TemporalConvNet(model_config)
        case 'GWN': return gwnet(model_config)
        case 'STFGNN': return STFGNN(model_config)
        case 'STSGCN': return STSGCN(model_config)
        case 'STGODE': return ODEGCN(model_config)
        case 'PDG2SEQ': return PDG2Seq(model_config)
        case 'STMLP': return STMLP(model_config)
        case 'STIDGCN': return STIDGCN(model_config)
        case 'STID': return STID(model_config)
        case 'STAEFormer': return STAEformer(model_config)
        case 'EXP': return EXP(model_config)
        case 'MegaCRN': return MegaCRNModel(model_config)
        case 'ST_SSL': return STSSLModel(model_config)
        case 'STGNRDE': return make_nrde_model(model_config)
        case 'STAWnet': return STAWnet(model_config)
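Design note: since every arm of the match now just forwards model_config, the same dispatch could be a plain name-to-constructor registry. A hypothetical sketch, not the committed code; unlike the match above, it raises on an unknown name instead of silently returning None:

# Registry equivalent of model_selector; entries elided, using the
# classes already imported at the top of this file.
MODEL_REGISTRY = {
    'DDGCRN': DDGCRN,
    'AGCRN': AGCRN,
    'GWN': gwnet,
    # ... remaining model classes as imported above
}

def model_selector(config):
    model_name = config["basic"]["model"]
    try:
        return MODEL_REGISTRY[model_name](config["model"])
    except KeyError:
        raise ValueError(f"Unknown model: {model_name}") from None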
run.py (94 changed lines)
@@ -1,64 +1,21 @@
import os

# Check dataset integrity
from lib.Download_data import check_and_download_data

data_complete = check_and_download_data()
assert data_complete is not None, "Dataset download failed, please retry!"

import torch
from datetime import datetime

# import time
from config.args_parser import parse_args
from lib.initializer import init_model, init_optimizer, init_seed
from lib.loss_function import get_loss_function

import lib.initializer as init
from dataloader.loader_selector import get_dataloader
from trainer.trainer_selector import select_trainer
import yaml


def main():
    args = parse_args()

    # Set device (prefer MPS on macOS, then CUDA, else CPU)
    if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available() and args['device'] != 'cpu':
        args['device'] = 'mps'
        args['model']['device'] = args['device']
    elif torch.cuda.is_available() and args['device'] != 'cpu':
        torch.cuda.set_device(int(args['device'].split(':')[1]))
        args['model']['device'] = args['device']
    else:
        args['device'] = 'cpu'
        args['model']['device'] = args['device']
    init_seed(args['train']['seed'])
    # Initialize model
    model = init_model(args['model'], device=args['device'])

    if args['mode'] == "benchmark":
        # To support compute-cost profiling, set mode to benchmark
        import torch.profiler as profiler
        dummy_input = torch.randn((64, 12, args['model']['num_nodes'], 3), device=args['device'])
        min_val = dummy_input.min(dim=-1, keepdim=True)[0]
        max_val = dummy_input.max(dim=-1, keepdim=True)[0]

        dummy_input = (dummy_input - min_val) / (max_val - min_val + 1e-6)
        with profiler.profile(
            activities=[
                profiler.ProfilerActivity.CPU,
                profiler.ProfilerActivity.CUDA
            ],
            with_stack=True,
            profile_memory=True,
            record_shapes=True
        ) as prof:
            out = model(dummy_input)
        print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
        return 0
    args = init.init_device(args)
    init.init_seed(args['train']['seed'])
    model = init.init_model(args)

    # Load dataset
    train_loader, val_loader, test_loader, scaler, *extra_data = get_dataloader(
@@ -67,49 +24,28 @@ def main():
        single=False
    )

    # Initialize loss function
    loss = get_loss_function(args['train'], scaler)

    # Initialize optimizer and learning rate scheduler
    optimizer, lr_scheduler = init_optimizer(model, args['train'])

    # Configure log path
    current_time = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    current_dir = os.path.dirname(os.path.realpath(__file__))
    args['train']['log_dir'] = os.path.join(current_dir, 'experiments', args['data']['type'], current_time)

    # Config file path
    config_filename = f"{args['data']['type']}.yaml"
    # config_path = os.path.join(current_dir, 'config', args['model']['type'], config_filename)
    # Ensure the log directory exists
    os.makedirs(args['train']['log_dir'], exist_ok=True)

    # Generate the config file content (convert args to YAML)
    config_content = yaml.safe_dump(args, default_flow_style=False)

    # Build the new YAML filename (e.g., config.auto.yaml or similar)
    destination_path = os.path.join(args['train']['log_dir'], config_filename)

    # Save args as a YAML file
    with open(destination_path, 'w') as f:
        f.write(config_content)
    loss = init.init_loss(args, scaler)
    optimizer, lr_scheduler = init.init_optimizer(model, args['train'])
    init.create_logs(args)

    # Start training or testing
    trainer = select_trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args,
                             lr_scheduler, extra_data)

    match args['mode']:
    match args['basic']['mode']:
        case 'train':
            trainer.train()
        case 'test':
            model.load_state_dict(torch.load(
                f"./pre-trained/{args['model']['type']}/{args['data']['type']}.pth",
                f"./pre-trained/{args['basic']['model']}/{args['basic']['dataset']}.pth",
                map_location=args['device'], weights_only=True))
            # print(f"Loaded saved model on {args['device']}")
            trainer.test(model.to(args['device']), trainer.args, test_loader, scaler, trainer.logger)
            trainer.test(model.to(args['basic']['device']), trainer.args, test_loader, scaler, trainer.logger)
        case _:
            raise ValueError(f"Unsupported mode: {args['mode']}")
            raise ValueError(f"Unsupported mode: {args['basic']['mode']}")


if __name__ == '__main__':
    from lib.Download_data import check_and_download_data
    data_complete = check_and_download_data()
    assert data_complete is not None, "Dataset download failed, please retry!"
    main()
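In test mode, the checkpoint path is now derived entirely from the 'basic' section. A sketch of the convention with illustrative values (the dict literal stands in for the parsed config):

import torch

basic = {'model': 'DDGCRN', 'dataset': 'PEMSD8', 'device': 'cpu'}
ckpt_path = f"./pre-trained/{basic['model']}/{basic['dataset']}.pth"
# -> ./pre-trained/DDGCRN/PEMSD8.pth
state_dict = torch.load(ckpt_path, map_location=basic['device'], weights_only=True)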
@@ -12,7 +12,10 @@ from lib.training_stats import TrainingStats

class Trainer:
    def __init__(self, model, loss, optimizer, train_loader, val_loader, test_loader,
                 scaler, args, lr_scheduler=None):
                 scaler, global_config, lr_scheduler=None):

        self.device = global_config['basic']['device']
        train_config = global_config['train']
        self.model = model
        self.loss = loss
        self.optimizer = optimizer
@@ -20,23 +23,23 @@ class Trainer:
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.scaler = scaler
        self.args = args
        self.args = train_config
        self.lr_scheduler = lr_scheduler
        self.train_per_epoch = len(train_loader)
        self.val_per_epoch = len(val_loader) if val_loader else 0

        # Paths for saving models and logs
        self.best_path = os.path.join(args['log_dir'], 'best_model.pth')
        self.best_test_path = os.path.join(args['log_dir'], 'best_test_model.pth')
        self.loss_figure_path = os.path.join(args['log_dir'], 'loss.png')
        self.best_path = os.path.join(train_config['log_dir'], 'best_model.pth')
        self.best_test_path = os.path.join(train_config['log_dir'], 'best_test_model.pth')
        self.loss_figure_path = os.path.join(train_config['log_dir'], 'loss.png')

        # Initialize logger
        if not os.path.isdir(args['log_dir']) and not args['debug']:
            os.makedirs(args['log_dir'], exist_ok=True)
        self.logger = get_logger(args['log_dir'], name=self.model.__class__.__name__, debug=args['debug'])
        self.logger.info(f"Experiment log path in: {args['log_dir']}")
        if not os.path.isdir(train_config['log_dir']) and not train_config['debug']:
            os.makedirs(train_config['log_dir'], exist_ok=True)
        self.logger = get_logger(train_config['log_dir'], name=self.model.__class__.__name__, debug=train_config['debug'])
        self.logger.info(f"Experiment log path in: {train_config['log_dir']}")
        # Stats tracker
        self.stats = TrainingStats(device=args['device'])
        self.stats = TrainingStats(device=self.device)

    def _run_epoch(self, epoch, dataloader, mode):
        is_train = (mode == 'train')
@@ -51,9 +54,9 @@ class Trainer:
            start_time = time.time()
            # unpack the new cycle_index
            data, target, cycle_index = batch
            data = data.to(self.args['device'])
            target = target.to(self.args['device'])
            cycle_index = cycle_index.to(self.args['device']).long()
            data = data.to(self.device)
            target = target.to(self.device)
            cycle_index = cycle_index.to(self.device).long()

            # forward
            if is_train:
@@ -164,10 +167,13 @@ class Trainer:

    @staticmethod
    def test(model, args, data_loader, scaler, logger, path=None):
        global_config = args
        device = global_config['basic']['device']
        args = global_config['train']
        if path:
            checkpoint = torch.load(path)
            model.load_state_dict(checkpoint['state_dict'])
            model.to(args['device'])
            model.to(device)

        model.eval()
        y_pred, y_true = [], []
@@ -80,6 +80,8 @@ class TrainingStats:
class Trainer:
    def __init__(self, model, loss, optimizer, train_loader, val_loader, test_loader,
                 scaler, args, lr_scheduler=None):
        self.device = args['basic']['device']
        args = args['train']
        self.model = model
        self.loss = loss
        self.optimizer = optimizer
@@ -104,7 +106,7 @@ class Trainer:
        self.logger.info(f"Experiment log path in: {args['log_dir']}")

        # Stats tracker
        self.stats = TrainingStats(device=args['device'])
        self.stats = TrainingStats(device=self.device)

    def _run_epoch(self, epoch, dataloader, mode):
        if mode == 'train':
@@ -123,7 +125,7 @@ class Trainer:
            start_time = time.time()

            label = target[..., :self.args['output_dim']]
            output = self.model(data).to(self.args['device'])
            output = self.model(data).to(self.device)

            if self.args['real_value']:
                output = self.scaler.inverse_transform(output)
@@ -153,6 +155,12 @@ class Trainer:
            self.logger.info(
                f'{mode.capitalize()} Epoch {epoch}: average Loss: {avg_loss:.6f}, time: {time.time() - epoch_time:.2f} s')

            # Log metrics
            mae, rmse, mape = all_metrics(output, label, self.args['mae_thresh'], self.args['mape_thresh'])
            self.logger.info(
                f"Epoch #{epoch:02d}: {mode.capitalize():<5} MAE:{mae:5.2f} | RMSE:{rmse:5.2f} | MAPE:{mape:5.2f}"
            )

            # Record memory usage
            self.stats.record_memory_usage()

@@ -232,19 +240,19 @@ class Trainer:
        if path:
            checkpoint = torch.load(path)
            model.load_state_dict(checkpoint['state_dict'])
            model.to(args['device'])
            model.to(args['basic']['device'])

        model.eval()
        y_pred, y_true = [], []

        with torch.no_grad():
            for data, target in data_loader:
                label = target[..., :args['output_dim']]
                label = target[..., :args['train']['output_dim']]
                output = model(data)
                y_pred.append(output)
                y_true.append(label)

        if args['real_value']:
        if args['train']['real_value']:
            y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0))
        else:
            y_pred = torch.cat(y_pred, dim=0)
@@ -252,10 +260,10 @@ class Trainer:

        for t in range(y_true.shape[1]):
            mae, rmse, mape = all_metrics(y_pred[:, t, ...], y_true[:, t, ...],
                                          args['mae_thresh'], args['mape_thresh'])
                                          args['train']['mae_thresh'], args['train']['mape_thresh'])
            logger.info(f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}")

        mae, rmse, mape = all_metrics(y_pred, y_true, args['mae_thresh'], args['mape_thresh'])
        mae, rmse, mape = all_metrics(y_pred, y_true, args['train']['mae_thresh'], args['train']['mape_thresh'])
        logger.info(f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}")

    @staticmethod
@@ -8,18 +8,19 @@ from trainer.E32Trainer import Trainer as EXP_Trainer

def select_trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args,
                   lr_scheduler, kwargs):
    match args['model']['type']:
        case "STGNCDE": return cdeTrainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args['train'],
    model_name = args['basic']['model']
    match model_name:
        case "STGNCDE": return cdeTrainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args,
                                          lr_scheduler, kwargs[0], None)
        case "STGNRDE": return cdeTrainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args['train'],
        case "STGNRDE": return cdeTrainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args,
                                          lr_scheduler, kwargs[0], None)
        case 'DCRNN': return DCRNN_Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args['train'],
        case 'DCRNN': return DCRNN_Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args,
                                           lr_scheduler)
        case 'PDG2SEQ': return PDG2SEQ_Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args['train'],
        case 'PDG2SEQ': return PDG2SEQ_Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args,
                                               lr_scheduler)
        case 'STMLP': return STMLP_Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args,
                                           lr_scheduler)
        case 'EXP': return EXP_Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args['train'],
        case 'EXP': return EXP_Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args,
                                       lr_scheduler)
        case _: return Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args['train'],
        case _: return Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args,
                               lr_scheduler)
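After this change every trainer family receives the full parsed config; the Trainer constructors above slice out the 'basic' and 'train' sections themselves. The call site matches the one in run.py:

# 'args' is the full parsed config and 'extra_data' is whatever the
# dataloader returned beyond the scaler.
trainer = select_trainer(model, loss, optimizer, train_loader, val_loader,
                         test_loader, scaler, args, lr_scheduler, extra_data)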