Compare commits

...

2 Commits

7 changed files with 573 additions and 2 deletions


@@ -12,12 +12,14 @@ data:
   add_day_in_week: True
   steps_per_day: 288
   days_per_week: 7
+  cycle: 288
 model:
   batch_size: 64
   input_dim: 1
   output_dim: 1
   in_len: 12
+  cycle_len: 288
 train:
@@ -36,6 +38,7 @@ train:
   max_grad_norm: 5
   real_value: True
 test:
   mae_thresh: null
   mape_thresh: 0.0
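The two new keys work in tandem: `cycle` tells the loader how to stamp each window with its position in the daily period, and `cycle_len` sizes the model's learnable cycle buffer; with 5-minute data both match `steps_per_day` (288). A minimal sketch of the index computation the loader performs, with toy sizes assumed:

import numpy as np

L, lag, horizon, cycle = 1000, 12, 12, 288  # toy sizes (assumed)
cycle_arr = np.arange(L) % cycle            # position of each step within the cycle
M = L - lag - horizon + 1                   # number of sliding windows
cycle_win = cycle_arr[lag:lag + M]          # cycle position of each window's first forecast step
print(cycle_win[:3])                        # [12 13 14]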

dataloader/EXPdataloader.py Executable file

@@ -0,0 +1,213 @@
import numpy as np
import gc
import os
import torch
import h5py
from lib.normalization import normalize_dataset
def get_dataloader(args, normalizer='std', single=True):
# args should now include 'cycle'
data = load_st_dataset(args['type'], args['sample']) # [T, N, F]
L, N, F = data.shape
# compute cycle index
cycle_arr = np.arange(L) % args['cycle'] # length-L array
# Step 1: sliding windows for X and Y
x = add_window_x(data, args['lag'], args['horizon'], single)
y = add_window_y(data, args['lag'], args['horizon'], single)
# window count = M = L - lag - horizon + 1
M = x.shape[0]
# Step 2: time features
time_in_day = np.tile(
np.array([i % args['steps_per_day'] / args['steps_per_day'] for i in range(L)]),
(N, 1)
).T.reshape(L, N, 1)
day_in_week = np.tile(
np.array([(i // args['steps_per_day']) % args['days_per_week'] for i in range(L)]),
(N, 1)
).T.reshape(L, N, 1)
x_day = add_window_x(time_in_day, args['lag'], args['horizon'], single)
x_week = add_window_x(day_in_week, args['lag'], args['horizon'], single)
x = np.concatenate([x, x_day, x_week], axis=-1)
# del x_day, x_week
# gc.collect()
    # Step 3: cycle index per window: position of the window's first forecast step
    cycle_win = np.array([cycle_arr[i + args['lag']] for i in range(M)])  # shape [M]
    # Step 4: split X and the cycle index into train/val/test
    # (Y is split below, after its time features are appended)
    if args['test_ratio'] > 1:
        x_train, x_val, x_test = split_data_by_days(x, args['val_ratio'], args['test_ratio'])
        c_train, c_val, c_test = split_data_by_days(cycle_win, args['val_ratio'], args['test_ratio'])
    else:
        x_train, x_val, x_test = split_data_by_ratio(x, args['val_ratio'], args['test_ratio'])
        c_train, c_val, c_test = split_data_by_ratio(cycle_win, args['val_ratio'], args['test_ratio'])
    # del x, cycle_win
    # gc.collect()
# Step 5: normalization on X only
scaler = normalize_dataset(x_train[..., :args['input_dim']], normalizer, args['column_wise'])
x_train[..., :args['input_dim']] = scaler.transform(x_train[..., :args['input_dim']])
x_val[..., :args['input_dim']] = scaler.transform(x_val[..., :args['input_dim']])
x_test[..., :args['input_dim']] = scaler.transform(x_test[..., :args['input_dim']])
# add time features to Y
y_day = add_window_y(time_in_day, args['lag'], args['horizon'], single)
y_week = add_window_y(day_in_week, args['lag'], args['horizon'], single)
y = np.concatenate([y, y_day, y_week], axis=-1)
# del y_day, y_week, time_in_day, day_in_week
# gc.collect()
    # split the time-augmented Y with the same scheme as X
if args['test_ratio'] > 1:
y_train, y_val, y_test = split_data_by_days(y, args['val_ratio'], args['test_ratio'])
else:
y_train, y_val, y_test = split_data_by_ratio(y, args['val_ratio'], args['test_ratio'])
# del y
# Step 6: create dataloaders including cycle index
train_loader = data_loader_with_cycle(x_train, y_train, c_train, args['batch_size'], shuffle=True, drop_last=True)
val_loader = data_loader_with_cycle(x_val, y_val, c_val, args['batch_size'], shuffle=False, drop_last=True)
test_loader = data_loader_with_cycle(x_test, y_test, c_test, args['batch_size'], shuffle=False, drop_last=False)
return train_loader, val_loader, test_loader, scaler
def data_loader_with_cycle(X, Y, C, batch_size, shuffle=True, drop_last=True):
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
X_t = torch.tensor(X, dtype=torch.float32, device=device)
Y_t = torch.tensor(Y, dtype=torch.float32, device=device)
    C_t = torch.tensor(C, dtype=torch.long, device=device).unsqueeze(-1)  # [M, 1] -> batches of [B, 1]
dataset = torch.utils.data.TensorDataset(X_t, Y_t, C_t)
loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last)
return loader
# Rest of the helper functions (load_st_dataset, split_data..., add_window_x/y) unchanged
def load_st_dataset(dataset, sample):
    # output: [T, N, D]
match dataset:
case 'PEMSD3':
data_path = os.path.join('./data/PEMS03/PEMS03.npz')
data = np.load(data_path)['data'][:, :, 0] # only the first dimension, traffic flow data
case 'PEMSD4':
data_path = os.path.join('./data/PEMS04/PEMS04.npz')
data = np.load(data_path)['data'][:, :, 0] # only the first dimension, traffic flow data
case 'PEMSD7':
data_path = os.path.join('./data/PEMS07/PEMS07.npz')
data = np.load(data_path)['data'][:, :, 0] # only the first dimension, traffic flow data
case 'PEMSD8':
data_path = os.path.join('./data/PEMS08/PEMS08.npz')
data = np.load(data_path)['data'][:, :, 0] # only the first dimension, traffic flow data
case 'PEMSD7(L)':
data_path = os.path.join('./data/PEMS07(L)/PEMS07L.npz')
data = np.load(data_path)['data'][:, :, 0] # only the first dimension, traffic flow data
case 'PEMSD7(M)':
data_path = os.path.join('./data/PEMS07(M)/V_228.csv')
data = np.genfromtxt(data_path, delimiter=',') # Read CSV directly with numpy
case 'METR-LA':
data_path = os.path.join('./data/METR-LA/METR.h5')
with h5py.File(data_path, 'r') as f: # Use h5py to handle HDF5 files without pandas
data = np.array(f['data'])
case 'BJ':
data_path = os.path.join('./data/BJ/BJ500.csv')
data = np.genfromtxt(data_path, delimiter=',', skip_header=1) # Skip header if present
case 'Hainan':
data_path = os.path.join('./data/Hainan/Hainan.npz')
data = np.load(data_path)['data'][:, :, 0]
case 'SD':
data_path = os.path.join('./data/SD/data.npz')
data = np.load(data_path)["data"][:, :, 0].astype(np.float32)
case _:
raise ValueError(f"Unsupported dataset: {dataset}")
# Ensure data shape compatibility
if len(data.shape) == 2:
data = np.expand_dims(data, axis=-1)
    print('Loading %s dataset...' % dataset)
return data[::sample]
def split_data_by_days(data, val_days, test_days, interval=30):
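    # interval: assumed minutes per time step (30 -> 48 steps/day); adjust to the dataset's resolution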
t = int((24 * 60) / interval)
test_data = data[-t * int(test_days):]
val_data = data[-t * int(test_days + val_days):-t * int(test_days)]
train_data = data[:-t * int(test_days + val_days)]
return train_data, val_data, test_data
def split_data_by_ratio(data, val_ratio, test_ratio):
data_len = data.shape[0]
test_data = data[-int(data_len * test_ratio):]
val_data = data[-int(data_len * (test_ratio + val_ratio)):-int(data_len * test_ratio)]
train_data = data[:-int(data_len * (test_ratio + val_ratio))]
return train_data, val_data, test_data
def data_loader(X, Y, batch_size, shuffle=True, drop_last=True):
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
X = torch.tensor(X, dtype=torch.float32, device=device)
Y = torch.tensor(Y, dtype=torch.float32, device=device)
data = torch.utils.data.TensorDataset(X, Y)
dataloader = torch.utils.data.DataLoader(data, batch_size=batch_size,
shuffle=shuffle, drop_last=drop_last)
return dataloader
def add_window_x(data, window=3, horizon=1, single=False):
"""
Generate windowed X values from the input data.
:param data: Input data, shape [B, ...]
:param window: Size of the sliding window
:param horizon: Horizon size
:param single: If True, generate single-step windows, else multi-step
:return: X with shape [B, W, ...]
"""
length = len(data)
end_index = length - horizon - window + 1
x = [] # Sliding windows
index = 0
while index < end_index:
x.append(data[index:index + window])
index += 1
return np.array(x)
def add_window_y(data, window=3, horizon=1, single=False):
"""
Generate windowed Y values from the input data.
:param data: Input data, shape [B, ...]
:param window: Size of the sliding window
:param horizon: Horizon size
:param single: If True, generate single-step windows, else multi-step
:return: Y with shape [B, H, ...]
"""
length = len(data)
end_index = length - horizon - window + 1
y = [] # Horizon values
index = 0
while index < end_index:
if single:
y.append(data[index + window + horizon - 1:index + window + horizon])
else:
y.append(data[index + window:index + window + horizon])
index += 1
return np.array(y)
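# Illustrative alignment check (assumed toy data): with T=5 steps [0..4],
# window=3, horizon=1:
#   add_window_x -> [[0,1,2], [1,2,3]]   (M = 5 - 3 - 1 + 1 = 2 windows)
#   add_window_y -> [[3], [4]]           (the step right after each window)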
if __name__ == '__main__':
res = load_st_dataset('SD', 1)
    print(res.shape)  # quick smoke test
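A hypothetical invocation of the loader above; the dict keys mirror the accesses made inside get_dataloader, and the values shown are assumptions:

args = {
    'type': 'PEMSD8', 'sample': 1, 'lag': 12, 'horizon': 12,
    'cycle': 288, 'steps_per_day': 288, 'days_per_week': 7,
    'val_ratio': 0.2, 'test_ratio': 0.2, 'input_dim': 1,
    'column_wise': False, 'batch_size': 64,
}
train_loader, val_loader, test_loader, scaler = get_dataloader(args, normalizer='std', single=False)
for x, y, c in train_loader:
    # x: [B, lag, N, 3], y: [B, horizon, N, 3], c: [B, 1]
    break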


@@ -1,11 +1,13 @@
 from dataloader.cde_loader.cdeDataloader import get_dataloader as cde_loader
 from dataloader.PeMSDdataloader import get_dataloader as normal_loader
 from dataloader.DCRNNdataloader import get_dataloader as DCRNN_loader
+from dataloader.EXPdataloader import get_dataloader as EXP_loader
 def get_dataloader(config, normalizer, single):
     match config['model']['type']:
         case 'STGNCDE': return cde_loader(config['data'], normalizer, single)
         case 'DCRNN': return DCRNN_loader(config['data'], normalizer, single)
+        case 'EXP': return EXP_loader(config['data'], normalizer, single)
         case _: return normal_loader(config['data'], normalizer, single)

model/EXP/EXP32.py Normal file

@@ -0,0 +1,168 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
# ------------------------- CycleNet Component -------------------------
class RecurrentCycle(nn.Module):
"""Efficient cyclic data removal/addition."""
def __init__(self, cycle_len, channel_size):
super().__init__()
self.cycle_len = cycle_len
self.channel_size = channel_size
        # learnable cycle buffer of shape (cycle_len, channel_size)
self.data = nn.Parameter(torch.zeros(cycle_len, channel_size))
    def forward(self, index, length):
        # index: (B,) or (B, 1) start positions within the cycle
        # length: seq_len or pred_len
        B = index.size(0)
        # offsets [0, 1, ..., length-1], shape (1, length)
        arange = torch.arange(length, device=index.device).unsqueeze(0)
        # add the offsets to each start index, wrapping around the cycle
        idx = (index.unsqueeze(1) + arange) % self.cycle_len  # (B, length) or (B, 1, length)
        # look up the cycle values: (B[, 1], length, channel_size)
        return self.data[idx]
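# Illustrative use (assumed sizes): with cycle_len=288 and channel_size=N,
#   q = RecurrentCycle(288, N)
#   q(torch.tensor([[286]]), 4)   # gathers rows 286, 287, 0, 1 -> (1, 1, 4, N)
# i.e. the modulo above makes the lookup wrap around the cycle.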
# ------------------------- Core Blocks -------------------------
class DynamicGraphConstructor(nn.Module):
def __init__(self, node_num, embed_dim):
super().__init__()
self.nodevec1 = nn.Parameter(torch.randn(node_num, embed_dim))
self.nodevec2 = nn.Parameter(torch.randn(node_num, embed_dim))
def forward(self):
adj = F.relu(torch.matmul(self.nodevec1, self.nodevec2.T))
return F.softmax(adj, dim=-1)
class GraphConvBlock(nn.Module):
def __init__(self, input_dim, output_dim):
super().__init__()
self.theta = nn.Linear(input_dim, output_dim)
self.residual = (input_dim == output_dim)
if not self.residual:
self.res_proj = nn.Linear(input_dim, output_dim)
def forward(self, x, adj):
res = x
x = torch.matmul(adj, x)
x = self.theta(x)
if not self.residual:
res = self.res_proj(res)
return F.relu(x + res)
class MANBA_Block(nn.Module):
def __init__(self, input_dim, hidden_dim):
super().__init__()
self.attn = nn.MultiheadAttention(embed_dim=input_dim, num_heads=4, batch_first=True)
self.ffn = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, input_dim)
)
self.norm1 = nn.LayerNorm(input_dim)
self.norm2 = nn.LayerNorm(input_dim)
def forward(self, x):
res = x
x_attn, _ = self.attn(x, x, x)
x = self.norm1(res + x_attn)
res2 = x
x_ffn = self.ffn(x)
return self.norm2(res2 + x_ffn)
class SandwichBlock(nn.Module):
def __init__(self, num_nodes, embed_dim, hidden_dim):
super().__init__()
self.manba1 = MANBA_Block(hidden_dim, hidden_dim * 2)
self.graph_constructor = DynamicGraphConstructor(num_nodes, embed_dim)
self.gc = GraphConvBlock(hidden_dim, hidden_dim)
self.manba2 = MANBA_Block(hidden_dim, hidden_dim * 2)
def forward(self, h):
h1 = self.manba1(h)
adj = self.graph_constructor()
h2 = self.gc(h1, adj)
return self.manba2(h2)
class MLP(nn.Module):
def __init__(self, in_dim, hidden_dims, out_dim, activation=nn.ReLU):
super().__init__()
dims = [in_dim] + hidden_dims + [out_dim]
layers = []
for i in range(len(dims) - 2):
layers += [nn.Linear(dims[i], dims[i+1]), activation()]
layers.append(nn.Linear(dims[-2], dims[-1]))
self.net = nn.Sequential(*layers)
def forward(self, x):
return self.net(x)
# ------------------------- EXP with CycleNet -------------------------
class EXP(nn.Module):
def __init__(self, args):
super().__init__()
        self.horizon = args['horizon']              # prediction horizon
        self.output_dim = args['output_dim']        # output dimension (usually 1)
        self.seq_len = args.get('in_len', 12)       # input sequence length
        self.hidden_dim = args.get('hidden_dim', 64)
        self.num_nodes = args['num_nodes']
        self.embed_dim = args.get('embed_dim', 16)
        # time-of-day / day-of-week embeddings
        self.time_slots = args.get('time_slots', 288)
        self.time_embedding = nn.Embedding(self.time_slots, self.hidden_dim)
        self.day_embedding = nn.Embedding(7, self.hidden_dim)
        # CycleNet queue
        self.cycleQueue = RecurrentCycle(cycle_len=args['cycle_len'], channel_size=self.num_nodes)
        # input projection (sequence length -> hidden dim)
        self.input_proj = MLP(self.seq_len, [self.hidden_dim], self.hidden_dim)
        # two Sandwich blocks
        self.sandwich1 = SandwichBlock(self.num_nodes, self.embed_dim, self.hidden_dim)
        self.sandwich2 = SandwichBlock(self.num_nodes, self.embed_dim, self.hidden_dim)
        # output projection
        self.out_proj = MLP(self.hidden_dim, [2 * self.hidden_dim], self.horizon * self.output_dim)
    def forward(self, x, cycle_index):
        # x: (B, T, N, D>=3); cycle_index: (B, 1) from the loader
        # 1) split flow and time features; indexing with 0 (not 0:1) drops the channel dim
        x_flow = x[..., 0]   # (B, T, N)
        x_time = x[..., 1]   # (B, T, N)
        x_day = x[..., 2]    # (B, T, N)
        B, T, N = x_flow.shape
        # 2) remove the cyclic component
        cyc = self.cycleQueue(cycle_index, T).squeeze(1)  # (B, 1, T, N) -> (B, T, N)
        x_flow = x_flow - cyc
        # 3) project each node's sequence to the hidden dim
        h0 = x_flow.permute(0, 2, 1).reshape(B * N, T)    # -> (B*N, T)
        h0 = self.input_proj(h0).view(B, N, self.hidden_dim)
        # 4) add time embeddings taken at the last input step;
        #    x_time stores slot/steps_per_day, so multiply and round to recover the slot
        t_idx = (x_time[:, -1] * self.time_slots).round().long() % self.time_slots  # (B, N)
        d_idx = x_day[:, -1].long()  # (B, N)
        h0 = h0 + self.time_embedding(t_idx) + self.day_embedding(d_idx)
        # 5) Sandwich blocks
        h1 = self.sandwich1(h0) + h0
        h2 = self.sandwich2(h1)
        # 6) output projection and reshape
        out = self.out_proj(h2)                              # (B, N, H*O)
        out = out.view(B, N, self.horizon, self.output_dim)  # (B, N, H, O)
        out = out.permute(0, 2, 1, 3)                        # (B, H, N, O)
        # add the cyclic component back over the forecast window
        idx_out = (cycle_index + self.seq_len) % self.cycleQueue.cycle_len
        cyc_out = self.cycleQueue(idx_out, self.horizon)     # (B, 1, H, N)
        cyc_out = cyc_out.squeeze(1).unsqueeze(-1)           # (B, H, N, 1)
        return out + cyc_out
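For reference, a quick shape check of the model above (a sketch; the arg keys mirror the constructor's reads and the values are assumed):

args = {'horizon': 12, 'output_dim': 1, 'in_len': 12, 'hidden_dim': 64,
        'num_nodes': 170, 'embed_dim': 16, 'time_slots': 288, 'cycle_len': 288}
model = EXP(args)
B, T, N = 4, args['in_len'], args['num_nodes']
x = torch.rand(B, T, N, 3)                   # [flow, time_in_day, day_in_week]
cycle_index = torch.randint(0, 288, (B, 1))  # matches the loader's [B, 1] tensors
out = model(x, cycle_index)
assert out.shape == (B, args['horizon'], N, args['output_dim'])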


@@ -15,7 +15,7 @@ from model.STGODE.STGODE import ODEGCN
 from model.PDG2SEQ.PDG2Seqb import PDG2Seq
 from model.STID.STID import STID
 from model.STAEFormer.STAEFormer import STAEformer
-from model.EXP.EXP31 import EXP as EXP
+from model.EXP.EXP32 import EXP as EXP
 def model_selector(model):
     match model['type']:
trainer/E32Trainer.py Normal file

@@ -0,0 +1,185 @@
import math
import os
import time
import copy
from tqdm import tqdm
import torch
from lib.logger import get_logger
from lib.loss_function import all_metrics
class Trainer:
def __init__(self, model, loss, optimizer, train_loader, val_loader, test_loader,
scaler, args, lr_scheduler=None):
self.model = model
self.loss = loss
self.optimizer = optimizer
self.train_loader = train_loader
self.val_loader = val_loader
self.test_loader = test_loader
self.scaler = scaler
self.args = args
self.lr_scheduler = lr_scheduler
self.train_per_epoch = len(train_loader)
self.val_per_epoch = len(val_loader) if val_loader else 0
# Paths for saving models and logs
self.best_path = os.path.join(args['log_dir'], 'best_model.pth')
self.best_test_path = os.path.join(args['log_dir'], 'best_test_model.pth')
self.loss_figure_path = os.path.join(args['log_dir'], 'loss.png')
# Initialize logger
if not os.path.isdir(args['log_dir']) and not args['debug']:
os.makedirs(args['log_dir'], exist_ok=True)
self.logger = get_logger(args['log_dir'], name=self.model.__class__.__name__, debug=args['debug'])
self.logger.info(f"Experiment log path in: {args['log_dir']}")
def _run_epoch(self, epoch, dataloader, mode):
is_train = (mode == 'train')
self.model.train() if is_train else self.model.eval()
total_loss = 0.0
epoch_time = time.time()
with torch.set_grad_enabled(is_train), \
tqdm(total=len(dataloader), desc=f'{mode.capitalize()} Epoch {epoch}') as pbar:
for batch_idx, batch in enumerate(dataloader):
# unpack the new cycle_index
data, target, cycle_index = batch
data = data.to(self.args['device'])
target = target.to(self.args['device'])
cycle_index = cycle_index.to(self.args['device']).long()
                # forward (zero grads first when training)
                if is_train:
                    self.optimizer.zero_grad()
                output = self.model(data, cycle_index)
# compute loss
label = target[..., :self.args['output_dim']]
if self.args['real_value']:
output = self.scaler.inverse_transform(output)
loss = self.loss(output, label)
# backward / step
if is_train:
loss.backward()
if self.args['grad_norm']:
torch.nn.utils.clip_grad_norm_(self.model.parameters(),
self.args['max_grad_norm'])
self.optimizer.step()
total_loss += loss.item()
# logging
if is_train and (batch_idx + 1) % self.args['log_step'] == 0:
self.logger.info(
f'Train Epoch {epoch}: {batch_idx+1}/{len(dataloader)} Loss: {loss.item():.6f}'
)
pbar.update(1)
pbar.set_postfix(loss=loss.item())
avg_loss = total_loss / len(dataloader)
self.logger.info(
f'{mode.capitalize()} Epoch {epoch}: average Loss: {avg_loss:.6f}, time: {time.time() - epoch_time:.2f} s')
return avg_loss
def train_epoch(self, epoch):
return self._run_epoch(epoch, self.train_loader, 'train')
def val_epoch(self, epoch):
return self._run_epoch(epoch, self.val_loader or self.test_loader, 'val')
def test_epoch(self, epoch):
return self._run_epoch(epoch, self.test_loader, 'test')
def train(self):
best_model, best_test_model = None, None
best_loss, best_test_loss = float('inf'), float('inf')
not_improved_count = 0
self.logger.info("Training process started")
for epoch in range(1, self.args['epochs'] + 1):
train_epoch_loss = self.train_epoch(epoch)
val_epoch_loss = self.val_epoch(epoch)
test_epoch_loss = self.test_epoch(epoch)
if train_epoch_loss > 1e6:
self.logger.warning('Gradient explosion detected. Ending...')
break
if val_epoch_loss < best_loss:
best_loss = val_epoch_loss
not_improved_count = 0
best_model = copy.deepcopy(self.model.state_dict())
self.logger.info('Best validation model saved!')
else:
not_improved_count += 1
if self.args['early_stop'] and not_improved_count == self.args['early_stop_patience']:
self.logger.info(
f"Validation performance didn't improve for {self.args['early_stop_patience']} epochs. Training stops.")
break
if test_epoch_loss < best_test_loss:
best_test_loss = test_epoch_loss
best_test_model = copy.deepcopy(self.model.state_dict())
if not self.args['debug']:
torch.save(best_model, self.best_path)
torch.save(best_test_model, self.best_test_path)
self.logger.info(f"Best models saved at {self.best_path} and {self.best_test_path}")
self._finalize_training(best_model, best_test_model)
def _finalize_training(self, best_model, best_test_model):
self.model.load_state_dict(best_model)
self.logger.info("Testing on best validation model")
self.test(self.model, self.args, self.test_loader, self.scaler, self.logger)
self.model.load_state_dict(best_test_model)
self.logger.info("Testing on best test model")
self.test(self.model, self.args, self.test_loader, self.scaler, self.logger)
@staticmethod
def test(model, args, data_loader, scaler, logger, path=None):
if path:
checkpoint = torch.load(path)
model.load_state_dict(checkpoint['state_dict'])
model.to(args['device'])
model.eval()
y_pred, y_true = [], []
with torch.no_grad():
for data, target, cycle_index in data_loader:
label = target[..., :args['output_dim']]
output = model(data, cycle_index)
y_pred.append(output)
y_true.append(label)
if args['real_value']:
y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0))
else:
y_pred = torch.cat(y_pred, dim=0)
y_true = torch.cat(y_true, dim=0)
            # Save y_pred and y_true here if needed, e.g.:
            # torch.save(y_pred, "./test/PEMS07/y_pred_D.pt")  # [3566, 12, 170, 1]
            # torch.save(y_true, "./test/PEMS08/y_true.pt")    # [3566, 12, 170, 1]
for t in range(y_true.shape[1]):
mae, rmse, mape = all_metrics(y_pred[:, t, ...], y_true[:, t, ...],
args['mae_thresh'], args['mape_thresh'])
logger.info(f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}")
mae, rmse, mape = all_metrics(y_pred, y_true, args['mae_thresh'], args['mape_thresh'])
logger.info(f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}")
@staticmethod
def _compute_sampling_threshold(global_step, k):
return k / (k + math.exp(global_step / k))
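For orientation, a hypothetical wiring of this trainer; the train_args keys listed are exactly the ones the class reads, while the paths and hyperparameter values are assumptions:

import torch
from dataloader.EXPdataloader import get_dataloader
from model.EXP.EXP32 import EXP
from trainer.E32Trainer import Trainer

train_args = {
    'log_dir': './runs/exp32', 'debug': False, 'device': 'cuda',
    'epochs': 100, 'log_step': 20, 'grad_norm': True, 'max_grad_norm': 5,
    'early_stop': True, 'early_stop_patience': 15,
    'real_value': True, 'output_dim': 1, 'mae_thresh': None, 'mape_thresh': 0.0,
}
# data_args / model_args as in the sketches after EXPdataloader.py and EXP32.py above
train_loader, val_loader, test_loader, scaler = get_dataloader(data_args, 'std', False)
model = EXP(model_args).to(train_args['device'])
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
trainer = Trainer(model, torch.nn.L1Loss(), optimizer,
                  train_loader, val_loader, test_loader, scaler, train_args)
trainer.train()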


@@ -2,7 +2,7 @@ from trainer.Trainer import Trainer
 from trainer.cdeTrainer.cdetrainer import Trainer as cdeTrainer
 from trainer.DCRNN_Trainer import Trainer as DCRNN_Trainer
 from trainer.PDG2SEQ_Trainer import Trainer as PDG2SEQ_Trainer
-from trainer.EXP_trainer import Trainer as EXP_Trainer
+from trainer.E32Trainer import Trainer as EXP_Trainer
 def select_trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args,