新增多个PEMS数据集配置文件,包含PEMSD3、PEMSD4、PEMSD7、PEMSD8及STAWnet、STGNRDE、ST_SSL模型的相关配置,优化模型训练参数设置。
This commit is contained in:
parent
b820b867fb
commit
29fd709c8c
|
|
@ -32,7 +32,7 @@ pip install pyyaml tqdm statsmodels h5py kagglehub torch torchvision torchaudio
|
|||
python run.py --model {model_name} --dataset {dataset_name} --mode {train, test} --device {cuda:0}
|
||||
```
|
||||
|
||||
- model_name: 目前支持:DSANET、STGCN、DCRNN、 GWN(GraphWaveNet)、STSGCN、AGCRN、STFGNN、STGODE、STGNCDE、DDGCRN、TWDGCN
|
||||
- model_name: 目前支持:DSANET、STGCN、DCRNN、 GWN(GraphWaveNet)、STSGCN、AGCRN、STFGNN、STGODE、STGNCDE、DDGCRN、TWDGCN、STAWnet
|
||||
- dataset_name目前支持:PEMSD3,PEMSD4、PEMSD7、PEMSD8
|
||||
- mode:train为训练模型,test为测试模型。测试模型需要在pre-train文件中找到模型的pth存档。
|
||||
- device: 支持'cpu'、'cuda:0'、‘cuda:1’ ... 取决于机器卡数
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ data:
|
|||
test_ratio: 0.2
|
||||
tod: False
|
||||
normalizer: std
|
||||
column_wise: False
|
||||
column_wise: True
|
||||
default_graph: True
|
||||
add_time_in_day: True
|
||||
add_day_in_week: True
|
||||
|
|
@ -21,26 +21,26 @@ model:
|
|||
max_diffusion_step: 2
|
||||
cl_decay_steps: 1000
|
||||
filter_type: dual_random_walk
|
||||
num_rnn_layers: 1
|
||||
num_rnn_layers: 2
|
||||
rnn_units: 64
|
||||
seq_len: 12
|
||||
use_curriculum_learning: True
|
||||
|
||||
train:
|
||||
loss_func: mae
|
||||
loss_func: mask_mae
|
||||
seed: 10
|
||||
batch_size: 64
|
||||
epochs: 300
|
||||
lr_init: 0.003
|
||||
weight_decay: 0
|
||||
lr_decay: False
|
||||
lr_decay_rate: 0.3
|
||||
lr_decay_step: "5,20,40,70"
|
||||
lr_init: 0.001
|
||||
weight_decay: 0.0001
|
||||
lr_decay: True
|
||||
lr_decay_rate: 0.1
|
||||
lr_decay_step: "10,20,40,80"
|
||||
early_stop: True
|
||||
early_stop_patience: 15
|
||||
grad_norm: False
|
||||
early_stop_patience: 25
|
||||
grad_norm: True
|
||||
max_grad_norm: 5
|
||||
real_value: True
|
||||
real_value: False
|
||||
|
||||
test:
|
||||
mae_thresh: null
|
||||
|
|
|
|||
|
|
@ -2,10 +2,12 @@ from dataloader.cde_loader.cdeDataloader import get_dataloader as cde_loader
|
|||
from dataloader.PeMSDdataloader import get_dataloader as normal_loader
|
||||
from dataloader.DCRNNdataloader import get_dataloader as DCRNN_loader
|
||||
from dataloader.EXPdataloader import get_dataloader as EXP_loader
|
||||
from dataloader.cde_loader.cdeDataloader import get_dataloader as nrde_loader
|
||||
|
||||
def get_dataloader(config, normalizer, single):
|
||||
match config['model']['type']:
|
||||
case 'STGNCDE': return cde_loader(config['data'], normalizer, single)
|
||||
case 'STGNRDE': return nrde_loader(config['data'], normalizer, single)
|
||||
case 'DCRNN': return DCRNN_loader(config['data'], normalizer, single)
|
||||
case 'EXP': return EXP_loader(config['data'], normalizer, single)
|
||||
case _: return normal_loader(config['data'], normalizer, single)
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
|
||||
def masked_mae_loss(scaler, mask_value):
|
||||
def loss(preds, labels):
|
||||
# 仅对预测反归一化;标签在数据管道中保持原始量纲
|
||||
if scaler:
|
||||
preds = scaler.inverse_transform(preds)
|
||||
labels = scaler.inverse_transform(labels)
|
||||
return mae_torch(pred=preds, true=labels, mask_value=mask_value)
|
||||
|
||||
return loss
|
||||
|
|
@ -11,7 +11,8 @@ def masked_mae_loss(scaler, mask_value):
|
|||
|
||||
def get_loss_function(args, scaler):
|
||||
if args['loss_func'] == 'mask_mae':
|
||||
return masked_mae_loss(scaler, mask_value=0.0).to(args['device'])
|
||||
# Return callable loss (no .to for function closures); disable masking by default
|
||||
return masked_mae_loss(scaler, mask_value=None)
|
||||
elif args['loss_func'] == 'mae':
|
||||
return torch.nn.L1Loss().to(args['device'])
|
||||
elif args['loss_func'] == 'mse':
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ class LayerParams:
|
|||
|
||||
|
||||
class DCGRUCell(torch.nn.Module):
|
||||
def __init__(self, num_units, adj_mx, max_diffusion_step, num_nodes, nonlinearity='tanh',
|
||||
def __init__(self, num_units, adj_mx, max_diffusion_step, num_nodes, input_dim=None, nonlinearity='tanh',
|
||||
filter_type="laplacian", use_gc_for_ru=True):
|
||||
"""
|
||||
|
||||
|
|
@ -55,6 +55,7 @@ class DCGRUCell(torch.nn.Module):
|
|||
self._max_diffusion_step = max_diffusion_step
|
||||
self._supports = []
|
||||
self._use_gc_for_ru = use_gc_for_ru
|
||||
self._input_dim = input_dim # optional; if None, will be inferred at first forward
|
||||
supports = []
|
||||
if filter_type == "laplacian":
|
||||
supports.append(utils.calculate_scaled_laplacian(adj_mx, lambda_max=None))
|
||||
|
|
@ -71,6 +72,22 @@ class DCGRUCell(torch.nn.Module):
|
|||
self._fc_params = LayerParams(self, 'fc')
|
||||
self._gconv_params = LayerParams(self, 'gconv')
|
||||
|
||||
# Pre-create parameters if input_dim is known
|
||||
if self._input_dim is not None:
|
||||
num_matrices = len(self._supports) * self._max_diffusion_step + 1
|
||||
input_size = self._input_dim + self._num_units
|
||||
# FC weights/biases for RU gates (2 * num_units)
|
||||
self._fc_params.get_weights((input_size, 2 * self._num_units))
|
||||
self._fc_params.get_biases(2 * self._num_units, bias_start=1.0)
|
||||
# Optionally for candidate (num_units) if FC path is used
|
||||
self._fc_params.get_weights((input_size, self._num_units))
|
||||
self._fc_params.get_biases(self._num_units, bias_start=0.0)
|
||||
# GConv weights/biases for RU gates and candidate
|
||||
self._gconv_params.get_weights((input_size * num_matrices, 2 * self._num_units))
|
||||
self._gconv_params.get_biases(2 * self._num_units, bias_start=1.0)
|
||||
self._gconv_params.get_weights((input_size * num_matrices, self._num_units))
|
||||
self._gconv_params.get_biases(self._num_units, bias_start=0.0)
|
||||
|
||||
@staticmethod
|
||||
def _build_sparse_matrix(L):
|
||||
L = L.tocoo()
|
||||
|
|
|
|||
|
|
@ -15,6 +15,10 @@ class Seq2SeqAttrs:
|
|||
self.num_nodes = args.get('num_nodes', 1)
|
||||
self.num_rnn_layers = args.get('num_rnn_layers', 1)
|
||||
self.rnn_units = args.get('rnn_units')
|
||||
self.input_dim = args.get('input_dim', 1)
|
||||
self.output_dim = args.get('output_dim', 1)
|
||||
self.horizon = args.get('horizon', 12)
|
||||
self.seq_len = args.get('seq_len', 12)
|
||||
self.hidden_state_size = self.num_nodes * self.rnn_units
|
||||
|
||||
|
||||
|
|
@ -26,7 +30,7 @@ class EncoderModel(nn.Module, Seq2SeqAttrs):
|
|||
self.seq_len = args.get('seq_len') # for the encoder
|
||||
self.dcgru_layers = nn.ModuleList(
|
||||
[DCGRUCell(self.rnn_units, adj_mx, self.max_diffusion_step, self.num_nodes,
|
||||
filter_type=self.filter_type) for _ in range(self.num_rnn_layers)])
|
||||
input_dim=self.input_dim, filter_type=self.filter_type) for _ in range(self.num_rnn_layers)])
|
||||
|
||||
def forward(self, inputs, hidden_state=None):
|
||||
"""
|
||||
|
|
@ -63,7 +67,7 @@ class DecoderModel(nn.Module, Seq2SeqAttrs):
|
|||
self.projection_layer = nn.Linear(self.rnn_units, self.output_dim)
|
||||
self.dcgru_layers = nn.ModuleList(
|
||||
[DCGRUCell(self.rnn_units, adj_mx, self.max_diffusion_step, self.num_nodes,
|
||||
filter_type=self.filter_type) for _ in range(self.num_rnn_layers)])
|
||||
input_dim=self.output_dim, filter_type=self.filter_type) for _ in range(self.num_rnn_layers)])
|
||||
|
||||
def forward(self, inputs, hidden_state=None):
|
||||
"""
|
||||
|
|
@ -146,17 +150,19 @@ class DCRNNModel(nn.Module, Seq2SeqAttrs):
|
|||
|
||||
def forward(self, inputs, labels=None):
|
||||
"""
|
||||
seq2seq forward pass 64 12 307 3
|
||||
:param inputs: shape (seq_len, batch_size, num_sensor * input_dim) 12 64 307 * 1
|
||||
:param labels: shape (horizon, batch_size, num_sensor * output) 12 64 307 1
|
||||
:param batches_seen: batches seen till now
|
||||
:return: output: (self.horizon, batch_size, self.num_nodes * self.output_dim)
|
||||
seq2seq forward pass. inputs: [B, T, N, C]
|
||||
"""
|
||||
inputs = inputs[..., 0].permute(1, 0, 2)
|
||||
labels = labels[..., 0].permute(1, 0, 2)
|
||||
encoder_hidden_state = self.encoder(inputs)
|
||||
outputs = self.decoder(encoder_hidden_state, labels, batches_seen=self.batch_seen)
|
||||
x = inputs[..., :self.input_dim]
|
||||
x = x.permute(1, 0, 2, 3).contiguous().view(self.seq_len, -1, self.num_nodes * self.input_dim)
|
||||
|
||||
y = None
|
||||
if labels is not None:
|
||||
y = labels[..., :self.output_dim]
|
||||
y = y.permute(1, 0, 2, 3).contiguous().view(self.horizon, -1, self.num_nodes * self.output_dim)
|
||||
|
||||
encoder_hidden_state = self.encoder(x)
|
||||
outputs = self.decoder(encoder_hidden_state, y, batches_seen=self.batch_seen)
|
||||
self.batch_seen += 1
|
||||
outputs = outputs.unsqueeze(dim=-1) # [12,64,307,1]
|
||||
outputs = outputs.permute(1, 0, 2, 3) # [64,12,307,1]
|
||||
outputs = outputs.view(self.horizon, -1, self.num_nodes, self.output_dim)
|
||||
outputs = outputs.permute(1, 0, 2, 3).contiguous()
|
||||
return outputs
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ import numpy as np
|
|||
from lib.logger import get_logger
|
||||
from lib.metrics import All_Metrics
|
||||
from lib.TrainInits import print_model_parameters
|
||||
from lib.training_stats import TrainingStats
|
||||
|
||||
class Trainer(object):
|
||||
def __init__(self, model, vector_field_f, vector_field_g, loss, optimizer, train_loader, val_loader, test_loader,
|
||||
|
|
@ -42,6 +43,8 @@ class Trainer(object):
|
|||
self.device = device
|
||||
self.times = times.to(self.device, dtype=torch.float)
|
||||
self.w = w
|
||||
# Stats tracker
|
||||
self.stats = TrainingStats(device=device)
|
||||
|
||||
def val_epoch(self, epoch, val_dataloader):
|
||||
self.model.eval()
|
||||
|
|
@ -49,6 +52,7 @@ class Trainer(object):
|
|||
|
||||
with torch.no_grad():
|
||||
for batch_idx, batch in enumerate(self.val_loader):
|
||||
start_time = time.time()
|
||||
# for iter, batch in enumerate(val_dataloader):
|
||||
batch = tuple(b.to(self.device, dtype=torch.float) for b in batch)
|
||||
*valid_coeffs, target = batch
|
||||
|
|
@ -61,8 +65,11 @@ class Trainer(object):
|
|||
#a whole batch of Metr_LA is filtered
|
||||
if not torch.isnan(loss):
|
||||
total_val_loss += loss.item()
|
||||
step_time = time.time() - start_time
|
||||
self.stats.record_step_time(step_time, 'val')
|
||||
val_loss = total_val_loss / len(val_dataloader)
|
||||
self.logger.info('**********Val Epoch {}: average Loss: {:.6f}'.format(epoch, val_loss))
|
||||
self.stats.record_memory_usage()
|
||||
if self.args.tensorboard:
|
||||
self.w.add_scalar(f'valid/loss', val_loss, epoch)
|
||||
return val_loss
|
||||
|
|
@ -73,6 +80,7 @@ class Trainer(object):
|
|||
# for batch_idx, (data, target) in enumerate(self.train_loader):
|
||||
# for batch_idx, (data, target) in enumerate(self.train_loader):
|
||||
for batch_idx, batch in enumerate(self.train_loader):
|
||||
start_time = time.time()
|
||||
batch = tuple(b.to(self.device, dtype=torch.float) for b in batch)
|
||||
*train_coeffs, target = batch
|
||||
# data = data[..., :self.args.input_dim]
|
||||
|
|
@ -101,6 +109,8 @@ class Trainer(object):
|
|||
if self.args.grad_norm:
|
||||
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm)
|
||||
self.optimizer.step()
|
||||
step_time = time.time() - start_time
|
||||
self.stats.record_step_time(step_time, 'train')
|
||||
total_loss += loss.item()
|
||||
|
||||
#log information
|
||||
|
|
@ -109,6 +119,7 @@ class Trainer(object):
|
|||
epoch, batch_idx, self.train_per_epoch, loss.item()))
|
||||
train_epoch_loss = total_loss/self.train_per_epoch
|
||||
self.logger.info('**********Train Epoch {}: averaged Loss: {:.6f}'.format(epoch, train_epoch_loss))
|
||||
self.stats.record_memory_usage()
|
||||
if self.args.tensorboard:
|
||||
self.w.add_scalar(f'train/loss', train_epoch_loss, epoch)
|
||||
|
||||
|
|
@ -123,6 +134,7 @@ class Trainer(object):
|
|||
not_improved_count = 0
|
||||
train_loss_list = []
|
||||
val_loss_list = []
|
||||
self.stats.start_training()
|
||||
start_time = time.time()
|
||||
for epoch in range(1, self.args.epochs + 1):
|
||||
#epoch_time = time.time()
|
||||
|
|
@ -168,6 +180,13 @@ class Trainer(object):
|
|||
|
||||
training_time = time.time() - start_time
|
||||
self.logger.info("Total training time: {:.4f}min, best loss: {:.6f}".format((training_time / 60), best_loss))
|
||||
self.stats.end_training()
|
||||
self.stats.report(self.logger)
|
||||
try:
|
||||
total_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
|
||||
self.logger.info(f"Trainable params: {total_params}")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
#save the best model to file
|
||||
if not self.args.debug:
|
||||
|
|
|
|||
|
|
@ -18,6 +18,10 @@ from model.STIDGCN.STIDGCN import STIDGCN
|
|||
from model.STID.STID import STID
|
||||
from model.STAEFormer.STAEFormer import STAEformer
|
||||
from model.EXP.EXP32 import EXP as EXP
|
||||
from model.MegaCRN.MegaCRNModel import MegaCRNModel
|
||||
from model.ST_SSL.ST_SSL import STSSLModel
|
||||
from model.STGNRDE.Make_model import make_model as make_nrde_model
|
||||
from model.STAWnet.STAWnet import STAWnet
|
||||
|
||||
def model_selector(model):
|
||||
match model['type']:
|
||||
|
|
@ -41,4 +45,8 @@ def model_selector(model):
|
|||
case 'STID': return STID(model)
|
||||
case 'STAEFormer': return STAEformer(model)
|
||||
case 'EXP': return EXP(model)
|
||||
case 'MegaCRN': return MegaCRNModel(model)
|
||||
case 'ST_SSL': return STSSLModel(model)
|
||||
case 'STGNRDE': return make_nrde_model(model)
|
||||
case 'STAWnet': return STAWnet(model)
|
||||
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from tqdm import tqdm
|
|||
import torch
|
||||
from lib.logger import get_logger
|
||||
from lib.loss_function import all_metrics
|
||||
from lib.training_stats import TrainingStats
|
||||
|
||||
|
||||
class Trainer:
|
||||
|
|
@ -34,6 +35,8 @@ class Trainer:
|
|||
os.makedirs(args['log_dir'], exist_ok=True)
|
||||
self.logger = get_logger(args['log_dir'], name=self.model.__class__.__name__, debug=args['debug'])
|
||||
self.logger.info(f"Experiment log path in: {args['log_dir']}")
|
||||
# Stats tracker
|
||||
self.stats = TrainingStats(device=args['device'])
|
||||
|
||||
def _run_epoch(self, epoch, dataloader, mode):
|
||||
if mode == 'train':
|
||||
|
|
@ -49,12 +52,13 @@ class Trainer:
|
|||
with torch.set_grad_enabled(optimizer_step):
|
||||
with tqdm(total=len(dataloader), desc=f'{mode.capitalize()} Epoch {epoch}') as pbar:
|
||||
for batch_idx, (data, target) in enumerate(dataloader):
|
||||
start_time = time.time()
|
||||
label = target[..., :self.args['output_dim']]
|
||||
# label = target[..., :self.args['output_dim']]
|
||||
output = self.model(data, labels=label.clone()).to(self.args['device'])
|
||||
|
||||
if self.args['real_value']:
|
||||
output = self.scaler.inverse_transform(output)
|
||||
label = self.scaler.inverse_transform(label)
|
||||
|
||||
loss = self.loss(output, label)
|
||||
if optimizer_step and self.optimizer is not None:
|
||||
|
|
@ -65,6 +69,8 @@ class Trainer:
|
|||
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm'])
|
||||
self.optimizer.step()
|
||||
|
||||
step_time = time.time() - start_time
|
||||
self.stats.record_step_time(step_time, mode)
|
||||
total_loss += loss.item()
|
||||
|
||||
if mode == 'train' and (batch_idx + 1) % self.args['log_step'] == 0:
|
||||
|
|
@ -78,6 +84,8 @@ class Trainer:
|
|||
avg_loss = total_loss / len(dataloader)
|
||||
self.logger.info(
|
||||
f'{mode.capitalize()} Epoch {epoch}: average Loss: {avg_loss:.6f}, time: {time.time() - epoch_time:.2f} s')
|
||||
# 记录内存
|
||||
self.stats.record_memory_usage()
|
||||
return avg_loss
|
||||
|
||||
def train_epoch(self, epoch):
|
||||
|
|
@ -94,6 +102,7 @@ class Trainer:
|
|||
best_loss, best_test_loss = float('inf'), float('inf')
|
||||
not_improved_count = 0
|
||||
|
||||
self.stats.start_training()
|
||||
self.logger.info("Training process started")
|
||||
for epoch in range(1, self.args['epochs'] + 1):
|
||||
train_epoch_loss = self.train_epoch(epoch)
|
||||
|
|
@ -126,6 +135,14 @@ class Trainer:
|
|||
torch.save(best_test_model, self.best_test_path)
|
||||
self.logger.info(f"Best models saved at {self.best_path} and {self.best_test_path}")
|
||||
|
||||
# 输出统计与参数
|
||||
self.stats.end_training()
|
||||
self.stats.report(self.logger)
|
||||
try:
|
||||
total_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
|
||||
self.logger.info(f"Trainable params: {total_params}")
|
||||
except Exception:
|
||||
pass
|
||||
self._finalize_training(best_model, best_test_model)
|
||||
|
||||
def _finalize_training(self, best_model, best_test_model):
|
||||
|
|
@ -154,11 +171,11 @@ class Trainer:
|
|||
y_pred.append(output)
|
||||
y_true.append(label)
|
||||
|
||||
if args['real_value']:
|
||||
y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0))
|
||||
else:
|
||||
y_pred = torch.cat(y_pred, dim=0)
|
||||
y_pred = torch.cat(y_pred, dim=0)
|
||||
y_true = torch.cat(y_true, dim=0)
|
||||
if args['real_value']:
|
||||
y_pred = scaler.inverse_transform(y_pred)
|
||||
y_true = scaler.inverse_transform(y_true)
|
||||
|
||||
for t in range(y_true.shape[1]):
|
||||
mae, rmse, mape = all_metrics(y_pred[:, t, ...], y_true[:, t, ...],
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from tqdm import tqdm
|
|||
import torch
|
||||
from lib.logger import get_logger
|
||||
from lib.loss_function import all_metrics
|
||||
from lib.training_stats import TrainingStats
|
||||
|
||||
|
||||
class Trainer:
|
||||
|
|
@ -34,6 +35,8 @@ class Trainer:
|
|||
os.makedirs(args['log_dir'], exist_ok=True)
|
||||
self.logger = get_logger(args['log_dir'], name=self.model.__class__.__name__, debug=args['debug'])
|
||||
self.logger.info(f"Experiment log path in: {args['log_dir']}")
|
||||
# Stats tracker
|
||||
self.stats = TrainingStats(device=args['device'])
|
||||
|
||||
def _run_epoch(self, epoch, dataloader, mode):
|
||||
is_train = (mode == 'train')
|
||||
|
|
@ -45,6 +48,7 @@ class Trainer:
|
|||
tqdm(total=len(dataloader), desc=f'{mode.capitalize()} Epoch {epoch}') as pbar:
|
||||
|
||||
for batch_idx, batch in enumerate(dataloader):
|
||||
start_time = time.time()
|
||||
# unpack the new cycle_index
|
||||
data, target, cycle_index = batch
|
||||
data = data.to(self.args['device'])
|
||||
|
|
@ -72,6 +76,8 @@ class Trainer:
|
|||
self.args['max_grad_norm'])
|
||||
self.optimizer.step()
|
||||
|
||||
step_time = time.time() - start_time
|
||||
self.stats.record_step_time(step_time, mode)
|
||||
total_loss += loss.item()
|
||||
|
||||
# logging
|
||||
|
|
@ -86,6 +92,8 @@ class Trainer:
|
|||
avg_loss = total_loss / len(dataloader)
|
||||
self.logger.info(
|
||||
f'{mode.capitalize()} Epoch {epoch}: average Loss: {avg_loss:.6f}, time: {time.time() - epoch_time:.2f} s')
|
||||
# 记录内存
|
||||
self.stats.record_memory_usage()
|
||||
return avg_loss
|
||||
|
||||
def train_epoch(self, epoch):
|
||||
|
|
@ -102,6 +110,7 @@ class Trainer:
|
|||
best_loss, best_test_loss = float('inf'), float('inf')
|
||||
not_improved_count = 0
|
||||
|
||||
self.stats.start_training()
|
||||
self.logger.info("Training process started")
|
||||
for epoch in range(1, self.args['epochs'] + 1):
|
||||
train_epoch_loss = self.train_epoch(epoch)
|
||||
|
|
@ -134,6 +143,14 @@ class Trainer:
|
|||
torch.save(best_test_model, self.best_test_path)
|
||||
self.logger.info(f"Best models saved at {self.best_path} and {self.best_test_path}")
|
||||
|
||||
# 输出统计与参数
|
||||
self.stats.end_training()
|
||||
self.stats.report(self.logger)
|
||||
try:
|
||||
total_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
|
||||
self.logger.info(f"Trainable params: {total_params}")
|
||||
except Exception:
|
||||
pass
|
||||
self._finalize_training(best_model, best_test_model)
|
||||
|
||||
def _finalize_training(self, best_model, best_test_model):
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from tqdm import tqdm
|
|||
import torch
|
||||
from lib.logger import get_logger
|
||||
from lib.loss_function import all_metrics
|
||||
from lib.training_stats import TrainingStats
|
||||
|
||||
|
||||
class Trainer:
|
||||
|
|
@ -34,6 +35,8 @@ class Trainer:
|
|||
os.makedirs(args['log_dir'], exist_ok=True)
|
||||
self.logger = get_logger(args['log_dir'], name=self.model.__class__.__name__, debug=args['debug'])
|
||||
self.logger.info(f"Experiment log path in: {args['log_dir']}")
|
||||
# Stats tracker
|
||||
self.stats = TrainingStats(device=args['device'])
|
||||
|
||||
def _run_epoch(self, epoch, dataloader, mode):
|
||||
if mode == 'train':
|
||||
|
|
@ -49,6 +52,7 @@ class Trainer:
|
|||
with torch.set_grad_enabled(optimizer_step):
|
||||
with tqdm(total=len(dataloader), desc=f'{mode.capitalize()} Epoch {epoch}') as pbar:
|
||||
for batch_idx, (data, target) in enumerate(dataloader):
|
||||
start_time = time.time()
|
||||
label = target[..., :self.args['output_dim']]
|
||||
output = self.model(data).to(self.args['device'])
|
||||
|
||||
|
|
@ -64,6 +68,9 @@ class Trainer:
|
|||
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm'])
|
||||
self.optimizer.step()
|
||||
|
||||
step_time = time.time() - start_time
|
||||
self.stats.record_step_time(step_time, mode)
|
||||
|
||||
total_loss += loss.item()
|
||||
|
||||
if mode == 'train' and (batch_idx + 1) % self.args['log_step'] == 0:
|
||||
|
|
@ -77,6 +84,8 @@ class Trainer:
|
|||
avg_loss = total_loss / len(dataloader)
|
||||
self.logger.info(
|
||||
f'{mode.capitalize()} Epoch {epoch}: average Loss: {avg_loss:.6f}, time: {time.time() - epoch_time:.2f} s')
|
||||
# 记录内存
|
||||
self.stats.record_memory_usage()
|
||||
return avg_loss
|
||||
|
||||
def train_epoch(self, epoch):
|
||||
|
|
@ -93,6 +102,7 @@ class Trainer:
|
|||
best_loss, best_test_loss = float('inf'), float('inf')
|
||||
not_improved_count = 0
|
||||
|
||||
self.stats.start_training()
|
||||
self.logger.info("Training process started")
|
||||
for epoch in range(1, self.args['epochs'] + 1):
|
||||
train_epoch_loss = self.train_epoch(epoch)
|
||||
|
|
@ -124,7 +134,14 @@ class Trainer:
|
|||
torch.save(best_model, self.best_path)
|
||||
torch.save(best_test_model, self.best_test_path)
|
||||
self.logger.info(f"Best models saved at {self.best_path} and {self.best_test_path}")
|
||||
|
||||
# 输出统计与参数
|
||||
self.stats.end_training()
|
||||
self.stats.report(self.logger)
|
||||
try:
|
||||
total_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
|
||||
self.logger.info(f"Trainable params: {total_params}")
|
||||
except Exception:
|
||||
pass
|
||||
self._finalize_training(best_model, best_test_model)
|
||||
|
||||
def _finalize_training(self, best_model, best_test_model):
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from tqdm import tqdm
|
|||
import torch
|
||||
from lib.logger import get_logger
|
||||
from lib.loss_function import all_metrics
|
||||
from lib.training_stats import TrainingStats
|
||||
|
||||
|
||||
class Trainer:
|
||||
|
|
@ -35,6 +36,8 @@ class Trainer:
|
|||
os.makedirs(args['log_dir'], exist_ok=True)
|
||||
self.logger = get_logger(args['log_dir'], name=self.model.__class__.__name__, debug=args['debug'])
|
||||
self.logger.info(f"Experiment log path in: {args['log_dir']}")
|
||||
# Stats tracker
|
||||
self.stats = TrainingStats(device=args['device'])
|
||||
|
||||
def _run_epoch(self, epoch, dataloader, mode):
|
||||
if mode == 'train':
|
||||
|
|
@ -50,6 +53,7 @@ class Trainer:
|
|||
with torch.set_grad_enabled(optimizer_step):
|
||||
with tqdm(total=len(dataloader), desc=f'{mode.capitalize()} Epoch {epoch}') as pbar:
|
||||
for batch_idx, (data, target) in enumerate(dataloader):
|
||||
start_time = time.time()
|
||||
self.batches_seen += 1
|
||||
label = target[..., :self.args['output_dim']].clone()
|
||||
output = self.model(data, target, self.batches_seen).to(self.args['device'])
|
||||
|
|
@ -66,6 +70,9 @@ class Trainer:
|
|||
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm'])
|
||||
self.optimizer.step()
|
||||
|
||||
# record step time
|
||||
step_time = time.time() - start_time
|
||||
self.stats.record_step_time(step_time, mode)
|
||||
total_loss += loss.item()
|
||||
|
||||
if mode == 'train' and (batch_idx + 1) % self.args['log_step'] == 0:
|
||||
|
|
@ -79,6 +86,8 @@ class Trainer:
|
|||
avg_loss = total_loss / len(dataloader)
|
||||
self.logger.info(
|
||||
f'{mode.capitalize()} Epoch {epoch}: average Loss: {avg_loss:.6f}, time: {time.time() - epoch_time:.2f} s')
|
||||
# 记录内存
|
||||
self.stats.record_memory_usage()
|
||||
return avg_loss
|
||||
|
||||
def train_epoch(self, epoch):
|
||||
|
|
@ -95,6 +104,7 @@ class Trainer:
|
|||
best_loss, best_test_loss = float('inf'), float('inf')
|
||||
not_improved_count = 0
|
||||
|
||||
self.stats.start_training()
|
||||
self.logger.info("Training process started")
|
||||
for epoch in range(1, self.args['epochs'] + 1):
|
||||
train_epoch_loss = self.train_epoch(epoch)
|
||||
|
|
@ -127,6 +137,14 @@ class Trainer:
|
|||
torch.save(best_test_model, self.best_test_path)
|
||||
self.logger.info(f"Best models saved at {self.best_path} and {self.best_test_path}")
|
||||
|
||||
# 输出统计与参数
|
||||
self.stats.end_training()
|
||||
self.stats.report(self.logger)
|
||||
try:
|
||||
total_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
|
||||
self.logger.info(f"Trainable params: {total_params}")
|
||||
except Exception:
|
||||
pass
|
||||
self._finalize_training(best_model, best_test_model)
|
||||
|
||||
def _finalize_training(self, best_model, best_test_model):
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ from tqdm import tqdm
|
|||
from lib.logger import get_logger
|
||||
from lib.loss_function import all_metrics
|
||||
from model.STMLP.STMLP import STMLP
|
||||
from lib.training_stats import TrainingStats
|
||||
|
||||
|
||||
class Trainer:
|
||||
|
|
@ -43,6 +44,8 @@ class Trainer:
|
|||
os.makedirs(self.pretrain_dir, exist_ok=True)
|
||||
self.logger = get_logger(self.args['log_dir'], name=self.model.__class__.__name__, debug=self.args['debug'])
|
||||
self.logger.info(f"Experiment log path in: {self.args['log_dir']}")
|
||||
# Stats tracker
|
||||
self.stats = TrainingStats(device=args['device'])
|
||||
|
||||
if self.args['teacher_stu']:
|
||||
self.tmodel = self.loadTeacher(args)
|
||||
|
|
@ -67,6 +70,7 @@ class Trainer:
|
|||
with torch.set_grad_enabled(optimizer_step):
|
||||
with tqdm(total=len(dataloader), desc=f'{mode.capitalize()} Epoch {epoch}') as pbar:
|
||||
for batch_idx, (data, target) in enumerate(dataloader):
|
||||
start_time = time.time()
|
||||
if self.args['teacher_stu']:
|
||||
label = target[..., :self.args['output_dim']]
|
||||
output, out_, _ = self.model(data)
|
||||
|
|
@ -100,6 +104,8 @@ class Trainer:
|
|||
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm'])
|
||||
self.optimizer.step()
|
||||
|
||||
step_time = time.time() - start_time
|
||||
self.stats.record_step_time(step_time, mode)
|
||||
total_loss += loss.item()
|
||||
|
||||
if mode == 'train' and (batch_idx + 1) % self.args['log_step'] == 0:
|
||||
|
|
@ -113,6 +119,8 @@ class Trainer:
|
|||
avg_loss = total_loss / len(dataloader)
|
||||
self.logger.info(
|
||||
f'{mode.capitalize()} Epoch {epoch}: average Loss: {avg_loss:.6f}, time: {time.time() - epoch_time:.2f} s')
|
||||
# 记录内存
|
||||
self.stats.record_memory_usage()
|
||||
return avg_loss
|
||||
|
||||
def train_epoch(self, epoch):
|
||||
|
|
@ -129,6 +137,7 @@ class Trainer:
|
|||
best_loss, best_test_loss = float('inf'), float('inf')
|
||||
not_improved_count = 0
|
||||
|
||||
self.stats.start_training()
|
||||
self.logger.info("Training process started")
|
||||
for epoch in range(1, self.args['epochs'] + 1):
|
||||
train_epoch_loss = self.train_epoch(epoch)
|
||||
|
|
@ -165,6 +174,14 @@ class Trainer:
|
|||
torch.save(best_test_model, self.best_test_path)
|
||||
self.logger.info(f"Best models saved at {self.best_path} and {self.best_test_path}")
|
||||
|
||||
# 输出统计与参数
|
||||
self.stats.end_training()
|
||||
self.stats.report(self.logger)
|
||||
try:
|
||||
total_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
|
||||
self.logger.info(f"Trainable params: {total_params}")
|
||||
except Exception:
|
||||
pass
|
||||
self._finalize_training(best_model, best_test_model)
|
||||
|
||||
def _finalize_training(self, best_model, best_test_model):
|
||||
|
|
|
|||
|
|
@ -211,6 +211,13 @@ class Trainer:
|
|||
|
||||
self._finalize_training(best_model, best_test_model)
|
||||
|
||||
# 输出参数量
|
||||
try:
|
||||
total_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
|
||||
self.logger.info(f"Trainable params: {total_params}")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _finalize_training(self, best_model, best_test_model):
|
||||
self.model.load_state_dict(best_model)
|
||||
self.logger.info("Testing on best validation model")
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from tqdm import tqdm
|
|||
import torch
|
||||
from lib.logger import get_logger
|
||||
from lib.loss_function import all_metrics
|
||||
from lib.training_stats import TrainingStats
|
||||
|
||||
|
||||
class Trainer:
|
||||
|
|
@ -34,6 +35,8 @@ class Trainer:
|
|||
os.makedirs(args['log_dir'], exist_ok=True)
|
||||
self.logger = get_logger(args['log_dir'], name=self.model.__class__.__name__, debug=args['debug'])
|
||||
self.logger.info(f"Experiment log path in: {args['log_dir']}")
|
||||
# Stats tracker
|
||||
self.stats = TrainingStats(device=args['device'])
|
||||
|
||||
def _run_epoch(self, epoch, dataloader, mode):
|
||||
if mode == 'train':
|
||||
|
|
@ -49,6 +52,7 @@ class Trainer:
|
|||
with torch.set_grad_enabled(optimizer_step):
|
||||
with tqdm(total=len(dataloader), desc=f'{mode.capitalize()} Epoch {epoch}') as pbar:
|
||||
for batch_idx, (data, target) in enumerate(dataloader):
|
||||
start_time = time.time()
|
||||
label = target[..., :self.args['output_dim']]
|
||||
output = self.model(data).to(self.args['device'])
|
||||
|
||||
|
|
@ -64,6 +68,8 @@ class Trainer:
|
|||
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm'])
|
||||
self.optimizer.step()
|
||||
|
||||
step_time = time.time() - start_time
|
||||
self.stats.record_step_time(step_time, mode)
|
||||
total_loss += loss.item()
|
||||
|
||||
if mode == 'train' and (batch_idx + 1) % self.args['log_step'] == 0:
|
||||
|
|
@ -77,6 +83,8 @@ class Trainer:
|
|||
avg_loss = total_loss / len(dataloader)
|
||||
self.logger.info(
|
||||
f'{mode.capitalize()} Epoch {epoch}: average Loss: {avg_loss:.6f}, time: {time.time() - epoch_time:.2f} s')
|
||||
# 记录内存
|
||||
self.stats.record_memory_usage()
|
||||
return avg_loss
|
||||
|
||||
def train_epoch(self, epoch):
|
||||
|
|
@ -93,6 +101,7 @@ class Trainer:
|
|||
best_loss, best_test_loss = float('inf'), float('inf')
|
||||
not_improved_count = 0
|
||||
|
||||
self.stats.start_training()
|
||||
self.logger.info("Training process started")
|
||||
for epoch in range(1, self.args['epochs'] + 1):
|
||||
train_epoch_loss = self.train_epoch(epoch)
|
||||
|
|
@ -125,6 +134,14 @@ class Trainer:
|
|||
torch.save(best_test_model, self.best_test_path)
|
||||
self.logger.info(f"Best models saved at {self.best_path} and {self.best_test_path}")
|
||||
|
||||
# 输出统计与参数
|
||||
self.stats.end_training()
|
||||
self.stats.report(self.logger)
|
||||
try:
|
||||
total_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
|
||||
self.logger.info(f"Trainable params: {total_params}")
|
||||
except Exception:
|
||||
pass
|
||||
self._finalize_training(best_model, best_test_model)
|
||||
|
||||
def _finalize_training(self, best_model, best_test_model):
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from tqdm import tqdm
|
|||
import torch
|
||||
from lib.logger import get_logger
|
||||
from lib.loss_function import all_metrics
|
||||
from lib.training_stats import TrainingStats
|
||||
|
||||
|
||||
class Trainer:
|
||||
|
|
@ -35,6 +36,8 @@ class Trainer:
|
|||
os.makedirs(args['log_dir'], exist_ok=True)
|
||||
self.logger = get_logger(args['log_dir'], name=self.model.__class__.__name__, debug=args['debug'])
|
||||
self.logger.info(f"Experiment log path in: {args['log_dir']}")
|
||||
# Stats tracker
|
||||
self.stats = TrainingStats(device=args['device'])
|
||||
self.times = times.to(self.device, dtype=torch.float)
|
||||
self.w = w
|
||||
|
||||
|
|
@ -52,6 +55,7 @@ class Trainer:
|
|||
with torch.set_grad_enabled(optimizer_step):
|
||||
with tqdm(total=len(dataloader), desc=f'{mode.capitalize()} Epoch {epoch}') as pbar:
|
||||
for batch_idx, batch in enumerate(dataloader):
|
||||
start_time = time.time()
|
||||
batch = tuple(b.to(self.device, dtype=torch.float) for b in batch)
|
||||
*train_coeffs, target = batch
|
||||
label = target[..., :self.args['output_dim']]
|
||||
|
|
@ -69,6 +73,8 @@ class Trainer:
|
|||
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm'])
|
||||
self.optimizer.step()
|
||||
|
||||
step_time = time.time() - start_time
|
||||
self.stats.record_step_time(step_time, mode)
|
||||
total_loss += loss.item()
|
||||
|
||||
if mode == 'train' and (batch_idx + 1) % self.args['log_step'] == 0:
|
||||
|
|
@ -82,6 +88,8 @@ class Trainer:
|
|||
avg_loss = total_loss / len(dataloader)
|
||||
self.logger.info(
|
||||
f'{mode.capitalize()} Epoch {epoch}: average Loss: {avg_loss:.6f}, time: {time.time() - epoch_time:.2f} s')
|
||||
# 记录内存
|
||||
self.stats.record_memory_usage()
|
||||
return avg_loss
|
||||
def train_epoch(self, epoch):
|
||||
return self._run_epoch(epoch, self.train_loader, 'train')
|
||||
|
|
@ -97,6 +105,7 @@ class Trainer:
|
|||
best_loss, best_test_loss = float('inf'), float('inf')
|
||||
not_improved_count = 0
|
||||
|
||||
self.stats.start_training()
|
||||
self.logger.info("Training process started")
|
||||
for epoch in range(1, self.args['epochs'] + 1):
|
||||
train_epoch_loss = self.train_epoch(epoch)
|
||||
|
|
@ -129,6 +138,14 @@ class Trainer:
|
|||
torch.save(best_test_model, self.best_test_path)
|
||||
self.logger.info(f"Best models saved at {self.best_path} and {self.best_test_path}")
|
||||
|
||||
# 输出统计与参数
|
||||
self.stats.end_training()
|
||||
self.stats.report(self.logger)
|
||||
try:
|
||||
total_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
|
||||
self.logger.info(f"Trainable params: {total_params}")
|
||||
except Exception:
|
||||
pass
|
||||
self._finalize_training(best_model, best_test_model)
|
||||
|
||||
def _finalize_training(self, best_model, best_test_model):
|
||||
|
|
|
|||
|
|
@ -11,6 +11,8 @@ def select_trainer(model, loss, optimizer, train_loader, val_loader, test_loader
|
|||
match args['model']['type']:
|
||||
case "STGNCDE": return cdeTrainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args['train'],
|
||||
lr_scheduler, kwargs[0], None)
|
||||
case "STGNRDE": return cdeTrainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args['train'],
|
||||
lr_scheduler, kwargs[0], None)
|
||||
case 'DCRNN': return DCRNN_Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args['train'],
|
||||
lr_scheduler)
|
||||
case 'PDG2SEQ': return PDG2SEQ_Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args['train'],
|
||||
|
|
|
|||
Loading…
Reference in New Issue