Add configuration files for several PEMS datasets (PEMSD3, PEMSD4, PEMSD7, PEMSD8) and for the STAWnet, STGNRDE and ST_SSL models; tune the model training parameters.
parent b820b867fb
commit 29fd709c8c
@@ -32,7 +32,7 @@ pip install pyyaml tqdm statsmodels h5py kagglehub torch torchvision torchaudio
 python run.py --model {model_name} --dataset {dataset_name} --mode {train, test} --device {cuda:0}
 ```

-- model_name: currently supported: DSANET, STGCN, DCRNN, GWN (GraphWaveNet), STSGCN, AGCRN, STFGNN, STGODE, STGNCDE, DDGCRN, TWDGCN
+- model_name: currently supported: DSANET, STGCN, DCRNN, GWN (GraphWaveNet), STSGCN, AGCRN, STFGNN, STGODE, STGNCDE, DDGCRN, TWDGCN, STAWnet
 - dataset_name: currently supported: PEMSD3, PEMSD4, PEMSD7, PEMSD8
 - mode: train trains a model; test evaluates one and requires the model's .pth checkpoint in the pre-train directory.
 - device: 'cpu', 'cuda:0', 'cuda:1', ... depending on how many GPUs the machine has.
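As an aside, a concrete invocation combining the newly listed options would be `python run.py --model STAWnet --dataset PEMSD4 --mode train --device cuda:0`.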
@@ -6,7 +6,7 @@ data:
   test_ratio: 0.2
   tod: False
   normalizer: std
-  column_wise: False
+  column_wise: True
   default_graph: True
   add_time_in_day: True
   add_day_in_week: True
@@ -21,26 +21,26 @@ model:
   max_diffusion_step: 2
   cl_decay_steps: 1000
   filter_type: dual_random_walk
-  num_rnn_layers: 1
+  num_rnn_layers: 2
   rnn_units: 64
   seq_len: 12
   use_curriculum_learning: True

 train:
-  loss_func: mae
+  loss_func: mask_mae
   seed: 10
   batch_size: 64
   epochs: 300
-  lr_init: 0.003
+  lr_init: 0.001
-  weight_decay: 0
+  weight_decay: 0.0001
-  lr_decay: False
+  lr_decay: True
-  lr_decay_rate: 0.3
+  lr_decay_rate: 0.1
-  lr_decay_step: "5,20,40,70"
+  lr_decay_step: "10,20,40,80"
   early_stop: True
-  early_stop_patience: 15
+  early_stop_patience: 25
-  grad_norm: False
+  grad_norm: True
   max_grad_norm: 5
-  real_value: True
+  real_value: False

 test:
   mae_thresh: null
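The train block above drives the optimiser setup. A minimal sketch of how a comma-separated lr_decay_step such as "10,20,40,80" is typically wired into a PyTorch scheduler; the actual wiring lives elsewhere in the repository and is not part of this diff, so the helper name and config keys below are illustrative only:

```python
import torch

def build_scheduler(optimizer, train_cfg):
    # Illustrative helper: parse "10,20,40,80" into epoch milestones and
    # decay the learning rate by lr_decay_rate at each milestone.
    if not train_cfg['lr_decay']:
        return None
    milestones = [int(s) for s in str(train_cfg['lr_decay_step']).split(',')]
    return torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=milestones, gamma=train_cfg['lr_decay_rate'])

# Example with the values from this config:
# optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0001)
# scheduler = build_scheduler(optimizer, {'lr_decay': True,
#                                         'lr_decay_step': "10,20,40,80",
#                                         'lr_decay_rate': 0.1})
```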
@@ -2,10 +2,12 @@ from dataloader.cde_loader.cdeDataloader import get_dataloader as cde_loader
 from dataloader.PeMSDdataloader import get_dataloader as normal_loader
 from dataloader.DCRNNdataloader import get_dataloader as DCRNN_loader
 from dataloader.EXPdataloader import get_dataloader as EXP_loader
+from dataloader.cde_loader.cdeDataloader import get_dataloader as nrde_loader

 def get_dataloader(config, normalizer, single):
     match config['model']['type']:
         case 'STGNCDE': return cde_loader(config['data'], normalizer, single)
+        case 'STGNRDE': return nrde_loader(config['data'], normalizer, single)
         case 'DCRNN': return DCRNN_loader(config['data'], normalizer, single)
         case 'EXP': return EXP_loader(config['data'], normalizer, single)
         case _: return normal_loader(config['data'], normalizer, single)
@@ -1,9 +1,9 @@

 def masked_mae_loss(scaler, mask_value):
     def loss(preds, labels):
+        # Only inverse-transform the predictions; labels keep their original scale in the data pipeline
         if scaler:
             preds = scaler.inverse_transform(preds)
-            labels = scaler.inverse_transform(labels)
         return mae_torch(pred=preds, true=labels, mask_value=mask_value)

     return loss
@@ -11,7 +11,8 @@ def masked_mae_loss(scaler, mask_value):

 def get_loss_function(args, scaler):
     if args['loss_func'] == 'mask_mae':
-        return masked_mae_loss(scaler, mask_value=0.0).to(args['device'])
+        # Return callable loss (no .to for function closures); disable masking by default
+        return masked_mae_loss(scaler, mask_value=None)
     elif args['loss_func'] == 'mae':
         return torch.nn.L1Loss().to(args['device'])
     elif args['loss_func'] == 'mse':
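mask_mae now routes through masked_mae_loss with mask_value=None. mae_torch itself is defined elsewhere in lib and is not shown in this diff; the sketch below only illustrates the usual masked-MAE behaviour such a helper implements (mask_value filters out readings equal to or below that value, None disables masking):

```python
import torch

def mae_torch_sketch(pred, true, mask_value=None):
    # Illustrative only; the project's mae_torch may differ in details.
    if mask_value is not None:
        mask = true > mask_value          # e.g. drop zero readings from missing sensors
        pred, true = pred[mask], true[mask]
    return torch.mean(torch.abs(pred - true))
```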
@@ -34,7 +34,7 @@ class LayerParams:


 class DCGRUCell(torch.nn.Module):
-    def __init__(self, num_units, adj_mx, max_diffusion_step, num_nodes, nonlinearity='tanh',
+    def __init__(self, num_units, adj_mx, max_diffusion_step, num_nodes, input_dim=None, nonlinearity='tanh',
                  filter_type="laplacian", use_gc_for_ru=True):
         """
@@ -55,6 +55,7 @@ class DCGRUCell(torch.nn.Module):
         self._max_diffusion_step = max_diffusion_step
         self._supports = []
         self._use_gc_for_ru = use_gc_for_ru
+        self._input_dim = input_dim  # optional; if None, will be inferred at first forward
         supports = []
         if filter_type == "laplacian":
             supports.append(utils.calculate_scaled_laplacian(adj_mx, lambda_max=None))
@@ -71,6 +72,22 @@ class DCGRUCell(torch.nn.Module):
         self._fc_params = LayerParams(self, 'fc')
         self._gconv_params = LayerParams(self, 'gconv')

+        # Pre-create parameters if input_dim is known
+        if self._input_dim is not None:
+            num_matrices = len(self._supports) * self._max_diffusion_step + 1
+            input_size = self._input_dim + self._num_units
+            # FC weights/biases for RU gates (2 * num_units)
+            self._fc_params.get_weights((input_size, 2 * self._num_units))
+            self._fc_params.get_biases(2 * self._num_units, bias_start=1.0)
+            # Optionally for candidate (num_units) if FC path is used
+            self._fc_params.get_weights((input_size, self._num_units))
+            self._fc_params.get_biases(self._num_units, bias_start=0.0)
+            # GConv weights/biases for RU gates and candidate
+            self._gconv_params.get_weights((input_size * num_matrices, 2 * self._num_units))
+            self._gconv_params.get_biases(2 * self._num_units, bias_start=1.0)
+            self._gconv_params.get_weights((input_size * num_matrices, self._num_units))
+            self._gconv_params.get_biases(self._num_units, bias_start=0.0)
+
     @staticmethod
     def _build_sparse_matrix(L):
         L = L.tocoo()
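For reference, the shapes pre-created above follow directly from the hunk: num_matrices = len(supports) * max_diffusion_step + 1. A quick back-of-the-envelope check, assuming dual_random_walk yields two support matrices (forward and backward transitions, as in standard DCRNN) and taking input_dim = 2 purely as an example value, since it is configured elsewhere:

```python
num_supports = 2                 # assumption: dual_random_walk -> forward + backward transition matrices
max_diffusion_step = 2           # from the config hunk above
num_matrices = num_supports * max_diffusion_step + 1      # = 5 (identity + K diffusion hops per support)

input_dim, num_units = 2, 64     # example input_dim; rnn_units: 64 from the config
input_size = input_dim + num_units                         # = 66
gconv_ru_weight_shape = (input_size * num_matrices, 2 * num_units)   # (330, 128) for the reset/update gates
gconv_c_weight_shape = (input_size * num_matrices, num_units)        # (330, 64) for the candidate state
print(num_matrices, gconv_ru_weight_shape, gconv_c_weight_shape)
```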
@@ -15,6 +15,10 @@ class Seq2SeqAttrs:
         self.num_nodes = args.get('num_nodes', 1)
         self.num_rnn_layers = args.get('num_rnn_layers', 1)
         self.rnn_units = args.get('rnn_units')
+        self.input_dim = args.get('input_dim', 1)
+        self.output_dim = args.get('output_dim', 1)
+        self.horizon = args.get('horizon', 12)
+        self.seq_len = args.get('seq_len', 12)
         self.hidden_state_size = self.num_nodes * self.rnn_units

@@ -26,7 +30,7 @@ class EncoderModel(nn.Module, Seq2SeqAttrs):
         self.seq_len = args.get('seq_len')  # for the encoder
         self.dcgru_layers = nn.ModuleList(
             [DCGRUCell(self.rnn_units, adj_mx, self.max_diffusion_step, self.num_nodes,
-                       filter_type=self.filter_type) for _ in range(self.num_rnn_layers)])
+                       input_dim=self.input_dim, filter_type=self.filter_type) for _ in range(self.num_rnn_layers)])

     def forward(self, inputs, hidden_state=None):
         """
@@ -63,7 +67,7 @@ class DecoderModel(nn.Module, Seq2SeqAttrs):
         self.projection_layer = nn.Linear(self.rnn_units, self.output_dim)
         self.dcgru_layers = nn.ModuleList(
             [DCGRUCell(self.rnn_units, adj_mx, self.max_diffusion_step, self.num_nodes,
-                       filter_type=self.filter_type) for _ in range(self.num_rnn_layers)])
+                       input_dim=self.output_dim, filter_type=self.filter_type) for _ in range(self.num_rnn_layers)])

     def forward(self, inputs, hidden_state=None):
         """
@@ -146,17 +150,19 @@ class DCRNNModel(nn.Module, Seq2SeqAttrs):

     def forward(self, inputs, labels=None):
         """
-        seq2seq forward pass 64 12 307 3
-        :param inputs: shape (seq_len, batch_size, num_sensor * input_dim) 12 64 307 * 1
-        :param labels: shape (horizon, batch_size, num_sensor * output) 12 64 307 1
-        :param batches_seen: batches seen till now
-        :return: output: (self.horizon, batch_size, self.num_nodes * self.output_dim)
+        seq2seq forward pass. inputs: [B, T, N, C]
         """
-        inputs = inputs[..., 0].permute(1, 0, 2)
-        labels = labels[..., 0].permute(1, 0, 2)
-        encoder_hidden_state = self.encoder(inputs)
-        outputs = self.decoder(encoder_hidden_state, labels, batches_seen=self.batch_seen)
+        x = inputs[..., :self.input_dim]
+        x = x.permute(1, 0, 2, 3).contiguous().view(self.seq_len, -1, self.num_nodes * self.input_dim)
+        y = None
+        if labels is not None:
+            y = labels[..., :self.output_dim]
+            y = y.permute(1, 0, 2, 3).contiguous().view(self.horizon, -1, self.num_nodes * self.output_dim)
+
+        encoder_hidden_state = self.encoder(x)
+        outputs = self.decoder(encoder_hidden_state, y, batches_seen=self.batch_seen)
         self.batch_seen += 1
-        outputs = outputs.unsqueeze(dim=-1)  # [12,64,307,1]
-        outputs = outputs.permute(1, 0, 2, 3)  # [64,12,307,1]
+        outputs = outputs.view(self.horizon, -1, self.num_nodes, self.output_dim)
+        outputs = outputs.permute(1, 0, 2, 3).contiguous()
         return outputs
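The rewritten forward pass moves from hard-coded slicing ([..., 0]) to a shape-aware reshape of the batch-first tensor [B, T, N, C] into the time-major, flattened layout the DCGRU stack expects. A standalone check of that round trip (the sizes are examples only):

```python
import torch

B, T, N, C = 64, 12, 307, 1                      # example sizes only
inputs = torch.randn(B, T, N, C)

# [B, T, N, C] -> [T, B, N*C] for the encoder/decoder recurrence
x = inputs.permute(1, 0, 2, 3).contiguous().view(T, B, N * C)

# and back to [B, T, N, C] at the output side
out = x.view(T, B, N, C).permute(1, 0, 2, 3).contiguous()
assert torch.equal(out, inputs)
```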
@@ -7,6 +7,7 @@ import numpy as np
 from lib.logger import get_logger
 from lib.metrics import All_Metrics
 from lib.TrainInits import print_model_parameters
+from lib.training_stats import TrainingStats

 class Trainer(object):
     def __init__(self, model, vector_field_f, vector_field_g, loss, optimizer, train_loader, val_loader, test_loader,
@@ -42,6 +43,8 @@ class Trainer(object):
         self.device = device
         self.times = times.to(self.device, dtype=torch.float)
         self.w = w
+        # Stats tracker
+        self.stats = TrainingStats(device=device)

     def val_epoch(self, epoch, val_dataloader):
         self.model.eval()
@@ -49,6 +52,7 @@ class Trainer(object):

         with torch.no_grad():
             for batch_idx, batch in enumerate(self.val_loader):
+                start_time = time.time()
                 # for iter, batch in enumerate(val_dataloader):
                 batch = tuple(b.to(self.device, dtype=torch.float) for b in batch)
                 *valid_coeffs, target = batch
@@ -61,8 +65,11 @@ class Trainer(object):
                 #a whole batch of Metr_LA is filtered
                 if not torch.isnan(loss):
                     total_val_loss += loss.item()
+                step_time = time.time() - start_time
+                self.stats.record_step_time(step_time, 'val')
         val_loss = total_val_loss / len(val_dataloader)
         self.logger.info('**********Val Epoch {}: average Loss: {:.6f}'.format(epoch, val_loss))
+        self.stats.record_memory_usage()
         if self.args.tensorboard:
             self.w.add_scalar(f'valid/loss', val_loss, epoch)
         return val_loss
@@ -73,6 +80,7 @@ class Trainer(object):
         # for batch_idx, (data, target) in enumerate(self.train_loader):
         # for batch_idx, (data, target) in enumerate(self.train_loader):
         for batch_idx, batch in enumerate(self.train_loader):
+            start_time = time.time()
             batch = tuple(b.to(self.device, dtype=torch.float) for b in batch)
             *train_coeffs, target = batch
             # data = data[..., :self.args.input_dim]
@@ -101,6 +109,8 @@ class Trainer(object):
             if self.args.grad_norm:
                 torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm)
             self.optimizer.step()
+            step_time = time.time() - start_time
+            self.stats.record_step_time(step_time, 'train')
             total_loss += loss.item()

             #log information
@@ -109,6 +119,7 @@ class Trainer(object):
                 epoch, batch_idx, self.train_per_epoch, loss.item()))
         train_epoch_loss = total_loss/self.train_per_epoch
         self.logger.info('**********Train Epoch {}: averaged Loss: {:.6f}'.format(epoch, train_epoch_loss))
+        self.stats.record_memory_usage()
         if self.args.tensorboard:
             self.w.add_scalar(f'train/loss', train_epoch_loss, epoch)

@@ -123,6 +134,7 @@ class Trainer(object):
         not_improved_count = 0
         train_loss_list = []
         val_loss_list = []
+        self.stats.start_training()
         start_time = time.time()
         for epoch in range(1, self.args.epochs + 1):
             #epoch_time = time.time()
@@ -168,6 +180,13 @@ class Trainer(object):

         training_time = time.time() - start_time
         self.logger.info("Total training time: {:.4f}min, best loss: {:.6f}".format((training_time / 60), best_loss))
+        self.stats.end_training()
+        self.stats.report(self.logger)
+        try:
+            total_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
+            self.logger.info(f"Trainable params: {total_params}")
+        except Exception:
+            pass

         #save the best model to file
         if not self.args.debug:
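This trainer, like the ones below, now instruments training with a TrainingStats object from lib/training_stats. That module is not included in this diff, so the following is only a sketch of the interface the trainers call (start_training, end_training, record_step_time, record_memory_usage, report); the internal bookkeeping is an assumption:

```python
import time
import torch

class TrainingStats:
    """Illustrative stand-in for lib/training_stats.TrainingStats (interface inferred from the trainers)."""

    def __init__(self, device=None):
        self.device = device
        self.step_times = {'train': [], 'val': [], 'test': []}
        self.peak_mem_mb = 0.0
        self.t_start = self.t_end = None

    def start_training(self):
        self.t_start = time.time()

    def end_training(self):
        self.t_end = time.time()

    def record_step_time(self, seconds, mode):
        self.step_times.setdefault(mode, []).append(seconds)

    def record_memory_usage(self):
        # Only meaningful on CUDA devices; otherwise the peak stays at zero.
        if self.device and str(self.device).startswith('cuda') and torch.cuda.is_available():
            self.peak_mem_mb = max(self.peak_mem_mb,
                                   torch.cuda.max_memory_allocated(self.device) / 2 ** 20)

    def report(self, logger):
        for mode, times in self.step_times.items():
            if times:
                logger.info(f"{mode}: {len(times)} steps, avg {sum(times) / len(times):.4f}s/step")
        if self.t_start and self.t_end:
            logger.info(f"Wall-clock training time: {(self.t_end - self.t_start) / 60:.2f} min")
        logger.info(f"Peak CUDA memory: {self.peak_mem_mb:.1f} MB")
```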
@@ -18,6 +18,10 @@ from model.STIDGCN.STIDGCN import STIDGCN
 from model.STID.STID import STID
 from model.STAEFormer.STAEFormer import STAEformer
 from model.EXP.EXP32 import EXP as EXP
+from model.MegaCRN.MegaCRNModel import MegaCRNModel
+from model.ST_SSL.ST_SSL import STSSLModel
+from model.STGNRDE.Make_model import make_model as make_nrde_model
+from model.STAWnet.STAWnet import STAWnet

 def model_selector(model):
     match model['type']:
@@ -41,4 +45,8 @@ def model_selector(model):
         case 'STID': return STID(model)
         case 'STAEFormer': return STAEformer(model)
         case 'EXP': return EXP(model)
+        case 'MegaCRN': return MegaCRNModel(model)
+        case 'ST_SSL': return STSSLModel(model)
+        case 'STGNRDE': return make_nrde_model(model)
+        case 'STAWnet': return STAWnet(model)
@@ -7,6 +7,7 @@ from tqdm import tqdm
 import torch
 from lib.logger import get_logger
 from lib.loss_function import all_metrics
+from lib.training_stats import TrainingStats


 class Trainer:
@@ -34,6 +35,8 @@ class Trainer:
         os.makedirs(args['log_dir'], exist_ok=True)
         self.logger = get_logger(args['log_dir'], name=self.model.__class__.__name__, debug=args['debug'])
         self.logger.info(f"Experiment log path in: {args['log_dir']}")
+        # Stats tracker
+        self.stats = TrainingStats(device=args['device'])

     def _run_epoch(self, epoch, dataloader, mode):
         if mode == 'train':
@@ -49,12 +52,13 @@ class Trainer:
         with torch.set_grad_enabled(optimizer_step):
             with tqdm(total=len(dataloader), desc=f'{mode.capitalize()} Epoch {epoch}') as pbar:
                 for batch_idx, (data, target) in enumerate(dataloader):
+                    start_time = time.time()
                     label = target[..., :self.args['output_dim']]
-                    # label = target[..., :self.args['output_dim']]
                     output = self.model(data, labels=label.clone()).to(self.args['device'])

                     if self.args['real_value']:
                         output = self.scaler.inverse_transform(output)
+                        label = self.scaler.inverse_transform(label)

                     loss = self.loss(output, label)
                     if optimizer_step and self.optimizer is not None:
|
||||||
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm'])
|
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm'])
|
||||||
self.optimizer.step()
|
self.optimizer.step()
|
||||||
|
|
||||||
|
step_time = time.time() - start_time
|
||||||
|
self.stats.record_step_time(step_time, mode)
|
||||||
total_loss += loss.item()
|
total_loss += loss.item()
|
||||||
|
|
||||||
if mode == 'train' and (batch_idx + 1) % self.args['log_step'] == 0:
|
if mode == 'train' and (batch_idx + 1) % self.args['log_step'] == 0:
|
||||||
|
|
@@ -78,6 +84,8 @@ class Trainer:
         avg_loss = total_loss / len(dataloader)
         self.logger.info(
             f'{mode.capitalize()} Epoch {epoch}: average Loss: {avg_loss:.6f}, time: {time.time() - epoch_time:.2f} s')
+        # Record memory usage
+        self.stats.record_memory_usage()
         return avg_loss

     def train_epoch(self, epoch):
@@ -94,6 +102,7 @@ class Trainer:
         best_loss, best_test_loss = float('inf'), float('inf')
         not_improved_count = 0

+        self.stats.start_training()
         self.logger.info("Training process started")
         for epoch in range(1, self.args['epochs'] + 1):
             train_epoch_loss = self.train_epoch(epoch)
|
||||||
torch.save(best_test_model, self.best_test_path)
|
torch.save(best_test_model, self.best_test_path)
|
||||||
self.logger.info(f"Best models saved at {self.best_path} and {self.best_test_path}")
|
self.logger.info(f"Best models saved at {self.best_path} and {self.best_test_path}")
|
||||||
|
|
||||||
|
# 输出统计与参数
|
||||||
|
self.stats.end_training()
|
||||||
|
self.stats.report(self.logger)
|
||||||
|
try:
|
||||||
|
total_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
|
||||||
|
self.logger.info(f"Trainable params: {total_params}")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
self._finalize_training(best_model, best_test_model)
|
self._finalize_training(best_model, best_test_model)
|
||||||
|
|
||||||
def _finalize_training(self, best_model, best_test_model):
|
def _finalize_training(self, best_model, best_test_model):
|
||||||
|
|
@ -154,11 +171,11 @@ class Trainer:
|
||||||
y_pred.append(output)
|
y_pred.append(output)
|
||||||
y_true.append(label)
|
y_true.append(label)
|
||||||
|
|
||||||
if args['real_value']:
|
y_pred = torch.cat(y_pred, dim=0)
|
||||||
y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0))
|
|
||||||
else:
|
|
||||||
y_pred = torch.cat(y_pred, dim=0)
|
|
||||||
y_true = torch.cat(y_true, dim=0)
|
y_true = torch.cat(y_true, dim=0)
|
||||||
|
if args['real_value']:
|
||||||
|
y_pred = scaler.inverse_transform(y_pred)
|
||||||
|
y_true = scaler.inverse_transform(y_true)
|
||||||
|
|
||||||
for t in range(y_true.shape[1]):
|
for t in range(y_true.shape[1]):
|
||||||
mae, rmse, mape = all_metrics(y_pred[:, t, ...], y_true[:, t, ...],
|
mae, rmse, mape = all_metrics(y_pred[:, t, ...], y_true[:, t, ...],
|
||||||
|
|
|
||||||
|
|
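The evaluation change above makes un-normalisation symmetric: predictions and labels are either both left normalised or both mapped back to real units before metrics are computed. A self-contained illustration with a stand-in standard-score scaler (the project's scaler class is not shown in this diff):

```python
import torch

class StdScaler:
    # Minimal stand-in for the project's scaler (mean/std inverse transform).
    def __init__(self, mean, std):
        self.mean, self.std = mean, std

    def inverse_transform(self, x):
        return x * self.std + self.mean

scaler = StdScaler(mean=200.0, std=50.0)
y_pred_norm = torch.tensor([0.1, -0.2])
y_true_norm = torch.tensor([0.0, -0.1])

# Comparing a real-valued prediction against a still-normalised label would inflate the error;
# transforming both keeps MAE in consistent (real) units.
mae = torch.mean(torch.abs(scaler.inverse_transform(y_pred_norm) -
                           scaler.inverse_transform(y_true_norm)))
print(float(mae))   # 5.0
```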
@@ -7,6 +7,7 @@ from tqdm import tqdm
 import torch
 from lib.logger import get_logger
 from lib.loss_function import all_metrics
+from lib.training_stats import TrainingStats


 class Trainer:

@@ -34,6 +35,8 @@ class Trainer:
         os.makedirs(args['log_dir'], exist_ok=True)
         self.logger = get_logger(args['log_dir'], name=self.model.__class__.__name__, debug=args['debug'])
         self.logger.info(f"Experiment log path in: {args['log_dir']}")
+        # Stats tracker
+        self.stats = TrainingStats(device=args['device'])

     def _run_epoch(self, epoch, dataloader, mode):
         is_train = (mode == 'train')

@@ -45,6 +48,7 @@ class Trainer:
                 tqdm(total=len(dataloader), desc=f'{mode.capitalize()} Epoch {epoch}') as pbar:

             for batch_idx, batch in enumerate(dataloader):
+                start_time = time.time()
                 # unpack the new cycle_index
                 data, target, cycle_index = batch
                 data = data.to(self.args['device'])

@@ -72,6 +76,8 @@ class Trainer:
                                                        self.args['max_grad_norm'])
                     self.optimizer.step()

+                step_time = time.time() - start_time
+                self.stats.record_step_time(step_time, mode)
                 total_loss += loss.item()

                 # logging

@@ -86,6 +92,8 @@ class Trainer:
         avg_loss = total_loss / len(dataloader)
         self.logger.info(
             f'{mode.capitalize()} Epoch {epoch}: average Loss: {avg_loss:.6f}, time: {time.time() - epoch_time:.2f} s')
+        # Record memory usage
+        self.stats.record_memory_usage()
         return avg_loss

     def train_epoch(self, epoch):

@@ -102,6 +110,7 @@ class Trainer:
         best_loss, best_test_loss = float('inf'), float('inf')
         not_improved_count = 0

+        self.stats.start_training()
         self.logger.info("Training process started")
         for epoch in range(1, self.args['epochs'] + 1):
             train_epoch_loss = self.train_epoch(epoch)

@@ -134,6 +143,14 @@ class Trainer:
             torch.save(best_test_model, self.best_test_path)
             self.logger.info(f"Best models saved at {self.best_path} and {self.best_test_path}")

+        # Report stats and parameter count
+        self.stats.end_training()
+        self.stats.report(self.logger)
+        try:
+            total_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
+            self.logger.info(f"Trainable params: {total_params}")
+        except Exception:
+            pass
         self._finalize_training(best_model, best_test_model)

     def _finalize_training(self, best_model, best_test_model):
@@ -7,6 +7,7 @@ from tqdm import tqdm
 import torch
 from lib.logger import get_logger
 from lib.loss_function import all_metrics
+from lib.training_stats import TrainingStats


 class Trainer:

@@ -34,6 +35,8 @@ class Trainer:
         os.makedirs(args['log_dir'], exist_ok=True)
         self.logger = get_logger(args['log_dir'], name=self.model.__class__.__name__, debug=args['debug'])
         self.logger.info(f"Experiment log path in: {args['log_dir']}")
+        # Stats tracker
+        self.stats = TrainingStats(device=args['device'])

     def _run_epoch(self, epoch, dataloader, mode):
         if mode == 'train':

@@ -49,6 +52,7 @@ class Trainer:
         with torch.set_grad_enabled(optimizer_step):
             with tqdm(total=len(dataloader), desc=f'{mode.capitalize()} Epoch {epoch}') as pbar:
                 for batch_idx, (data, target) in enumerate(dataloader):
+                    start_time = time.time()
                     label = target[..., :self.args['output_dim']]
                     output = self.model(data).to(self.args['device'])

@@ -64,6 +68,9 @@ class Trainer:
                             torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm'])
                         self.optimizer.step()

+                    step_time = time.time() - start_time
+                    self.stats.record_step_time(step_time, mode)
+
                     total_loss += loss.item()

                     if mode == 'train' and (batch_idx + 1) % self.args['log_step'] == 0:

@@ -77,6 +84,8 @@ class Trainer:
         avg_loss = total_loss / len(dataloader)
         self.logger.info(
             f'{mode.capitalize()} Epoch {epoch}: average Loss: {avg_loss:.6f}, time: {time.time() - epoch_time:.2f} s')
+        # Record memory usage
+        self.stats.record_memory_usage()
         return avg_loss

     def train_epoch(self, epoch):

@@ -93,6 +102,7 @@ class Trainer:
         best_loss, best_test_loss = float('inf'), float('inf')
         not_improved_count = 0

+        self.stats.start_training()
         self.logger.info("Training process started")
         for epoch in range(1, self.args['epochs'] + 1):
             train_epoch_loss = self.train_epoch(epoch)

@@ -124,7 +134,14 @@ class Trainer:
             torch.save(best_model, self.best_path)
             torch.save(best_test_model, self.best_test_path)
             self.logger.info(f"Best models saved at {self.best_path} and {self.best_test_path}")
+        # Report stats and parameter count
+        self.stats.end_training()
+        self.stats.report(self.logger)
+        try:
+            total_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
+            self.logger.info(f"Trainable params: {total_params}")
+        except Exception:
+            pass
         self._finalize_training(best_model, best_test_model)

     def _finalize_training(self, best_model, best_test_model):
@@ -7,6 +7,7 @@ from tqdm import tqdm
 import torch
 from lib.logger import get_logger
 from lib.loss_function import all_metrics
+from lib.training_stats import TrainingStats


 class Trainer:

@@ -35,6 +36,8 @@ class Trainer:
         os.makedirs(args['log_dir'], exist_ok=True)
         self.logger = get_logger(args['log_dir'], name=self.model.__class__.__name__, debug=args['debug'])
         self.logger.info(f"Experiment log path in: {args['log_dir']}")
+        # Stats tracker
+        self.stats = TrainingStats(device=args['device'])

     def _run_epoch(self, epoch, dataloader, mode):
         if mode == 'train':

@@ -50,6 +53,7 @@ class Trainer:
         with torch.set_grad_enabled(optimizer_step):
             with tqdm(total=len(dataloader), desc=f'{mode.capitalize()} Epoch {epoch}') as pbar:
                 for batch_idx, (data, target) in enumerate(dataloader):
+                    start_time = time.time()
                     self.batches_seen += 1
                     label = target[..., :self.args['output_dim']].clone()
                     output = self.model(data, target, self.batches_seen).to(self.args['device'])

@@ -66,6 +70,9 @@ class Trainer:
                             torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm'])
                         self.optimizer.step()

+                    # record step time
+                    step_time = time.time() - start_time
+                    self.stats.record_step_time(step_time, mode)
                     total_loss += loss.item()

                     if mode == 'train' and (batch_idx + 1) % self.args['log_step'] == 0:

@@ -79,6 +86,8 @@ class Trainer:
         avg_loss = total_loss / len(dataloader)
         self.logger.info(
             f'{mode.capitalize()} Epoch {epoch}: average Loss: {avg_loss:.6f}, time: {time.time() - epoch_time:.2f} s')
+        # Record memory usage
+        self.stats.record_memory_usage()
         return avg_loss

     def train_epoch(self, epoch):

@@ -95,6 +104,7 @@ class Trainer:
         best_loss, best_test_loss = float('inf'), float('inf')
         not_improved_count = 0

+        self.stats.start_training()
         self.logger.info("Training process started")
         for epoch in range(1, self.args['epochs'] + 1):
             train_epoch_loss = self.train_epoch(epoch)

@@ -127,6 +137,14 @@ class Trainer:
             torch.save(best_test_model, self.best_test_path)
             self.logger.info(f"Best models saved at {self.best_path} and {self.best_test_path}")

+        # Report stats and parameter count
+        self.stats.end_training()
+        self.stats.report(self.logger)
+        try:
+            total_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
+            self.logger.info(f"Trainable params: {total_params}")
+        except Exception:
+            pass
         self._finalize_training(best_model, best_test_model)

     def _finalize_training(self, best_model, best_test_model):
@@ -11,6 +11,7 @@ from tqdm import tqdm
 from lib.logger import get_logger
 from lib.loss_function import all_metrics
 from model.STMLP.STMLP import STMLP
+from lib.training_stats import TrainingStats


 class Trainer:

@@ -43,6 +44,8 @@ class Trainer:
         os.makedirs(self.pretrain_dir, exist_ok=True)
         self.logger = get_logger(self.args['log_dir'], name=self.model.__class__.__name__, debug=self.args['debug'])
         self.logger.info(f"Experiment log path in: {self.args['log_dir']}")
+        # Stats tracker
+        self.stats = TrainingStats(device=args['device'])

         if self.args['teacher_stu']:
             self.tmodel = self.loadTeacher(args)

@@ -67,6 +70,7 @@ class Trainer:
         with torch.set_grad_enabled(optimizer_step):
             with tqdm(total=len(dataloader), desc=f'{mode.capitalize()} Epoch {epoch}') as pbar:
                 for batch_idx, (data, target) in enumerate(dataloader):
+                    start_time = time.time()
                     if self.args['teacher_stu']:
                         label = target[..., :self.args['output_dim']]
                         output, out_, _ = self.model(data)

@@ -100,6 +104,8 @@ class Trainer:
                             torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm'])
                         self.optimizer.step()

+                    step_time = time.time() - start_time
+                    self.stats.record_step_time(step_time, mode)
                     total_loss += loss.item()

                     if mode == 'train' and (batch_idx + 1) % self.args['log_step'] == 0:

@@ -113,6 +119,8 @@ class Trainer:
         avg_loss = total_loss / len(dataloader)
         self.logger.info(
             f'{mode.capitalize()} Epoch {epoch}: average Loss: {avg_loss:.6f}, time: {time.time() - epoch_time:.2f} s')
+        # Record memory usage
+        self.stats.record_memory_usage()
         return avg_loss

     def train_epoch(self, epoch):

@@ -129,6 +137,7 @@ class Trainer:
         best_loss, best_test_loss = float('inf'), float('inf')
         not_improved_count = 0

+        self.stats.start_training()
         self.logger.info("Training process started")
         for epoch in range(1, self.args['epochs'] + 1):
             train_epoch_loss = self.train_epoch(epoch)

@@ -165,6 +174,14 @@ class Trainer:
             torch.save(best_test_model, self.best_test_path)
             self.logger.info(f"Best models saved at {self.best_path} and {self.best_test_path}")

+        # Report stats and parameter count
+        self.stats.end_training()
+        self.stats.report(self.logger)
+        try:
+            total_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
+            self.logger.info(f"Trainable params: {total_params}")
+        except Exception:
+            pass
         self._finalize_training(best_model, best_test_model)

     def _finalize_training(self, best_model, best_test_model):

@@ -211,6 +211,13 @@ class Trainer:

         self._finalize_training(best_model, best_test_model)

+        # Report parameter count
+        try:
+            total_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
+            self.logger.info(f"Trainable params: {total_params}")
+        except Exception:
+            pass
+
     def _finalize_training(self, best_model, best_test_model):
         self.model.load_state_dict(best_model)
         self.logger.info("Testing on best validation model")
@@ -7,6 +7,7 @@ from tqdm import tqdm
 import torch
 from lib.logger import get_logger
 from lib.loss_function import all_metrics
+from lib.training_stats import TrainingStats


 class Trainer:

@@ -34,6 +35,8 @@ class Trainer:
         os.makedirs(args['log_dir'], exist_ok=True)
         self.logger = get_logger(args['log_dir'], name=self.model.__class__.__name__, debug=args['debug'])
         self.logger.info(f"Experiment log path in: {args['log_dir']}")
+        # Stats tracker
+        self.stats = TrainingStats(device=args['device'])

     def _run_epoch(self, epoch, dataloader, mode):
         if mode == 'train':

@@ -49,6 +52,7 @@ class Trainer:
         with torch.set_grad_enabled(optimizer_step):
             with tqdm(total=len(dataloader), desc=f'{mode.capitalize()} Epoch {epoch}') as pbar:
                 for batch_idx, (data, target) in enumerate(dataloader):
+                    start_time = time.time()
                     label = target[..., :self.args['output_dim']]
                     output = self.model(data).to(self.args['device'])

@@ -64,6 +68,8 @@ class Trainer:
                             torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm'])
                         self.optimizer.step()

+                    step_time = time.time() - start_time
+                    self.stats.record_step_time(step_time, mode)
                     total_loss += loss.item()

                     if mode == 'train' and (batch_idx + 1) % self.args['log_step'] == 0:

@@ -77,6 +83,8 @@ class Trainer:
         avg_loss = total_loss / len(dataloader)
         self.logger.info(
             f'{mode.capitalize()} Epoch {epoch}: average Loss: {avg_loss:.6f}, time: {time.time() - epoch_time:.2f} s')
+        # Record memory usage
+        self.stats.record_memory_usage()
         return avg_loss

     def train_epoch(self, epoch):

@@ -93,6 +101,7 @@ class Trainer:
         best_loss, best_test_loss = float('inf'), float('inf')
         not_improved_count = 0

+        self.stats.start_training()
         self.logger.info("Training process started")
         for epoch in range(1, self.args['epochs'] + 1):
             train_epoch_loss = self.train_epoch(epoch)

@@ -125,6 +134,14 @@ class Trainer:
             torch.save(best_test_model, self.best_test_path)
             self.logger.info(f"Best models saved at {self.best_path} and {self.best_test_path}")

+        # Report stats and parameter count
+        self.stats.end_training()
+        self.stats.report(self.logger)
+        try:
+            total_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
+            self.logger.info(f"Trainable params: {total_params}")
+        except Exception:
+            pass
         self._finalize_training(best_model, best_test_model)

     def _finalize_training(self, best_model, best_test_model):
@@ -7,6 +7,7 @@ from tqdm import tqdm
 import torch
 from lib.logger import get_logger
 from lib.loss_function import all_metrics
+from lib.training_stats import TrainingStats


 class Trainer:

@@ -35,6 +36,8 @@ class Trainer:
         os.makedirs(args['log_dir'], exist_ok=True)
         self.logger = get_logger(args['log_dir'], name=self.model.__class__.__name__, debug=args['debug'])
         self.logger.info(f"Experiment log path in: {args['log_dir']}")
+        # Stats tracker
+        self.stats = TrainingStats(device=args['device'])
         self.times = times.to(self.device, dtype=torch.float)
         self.w = w

@@ -52,6 +55,7 @@ class Trainer:
         with torch.set_grad_enabled(optimizer_step):
             with tqdm(total=len(dataloader), desc=f'{mode.capitalize()} Epoch {epoch}') as pbar:
                 for batch_idx, batch in enumerate(dataloader):
+                    start_time = time.time()
                     batch = tuple(b.to(self.device, dtype=torch.float) for b in batch)
                     *train_coeffs, target = batch
                     label = target[..., :self.args['output_dim']]

@@ -69,6 +73,8 @@ class Trainer:
                             torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm'])
                         self.optimizer.step()

+                    step_time = time.time() - start_time
+                    self.stats.record_step_time(step_time, mode)
                     total_loss += loss.item()

                     if mode == 'train' and (batch_idx + 1) % self.args['log_step'] == 0:

@@ -82,6 +88,8 @@ class Trainer:
         avg_loss = total_loss / len(dataloader)
         self.logger.info(
             f'{mode.capitalize()} Epoch {epoch}: average Loss: {avg_loss:.6f}, time: {time.time() - epoch_time:.2f} s')
+        # Record memory usage
+        self.stats.record_memory_usage()
         return avg_loss
     def train_epoch(self, epoch):
         return self._run_epoch(epoch, self.train_loader, 'train')

@@ -97,6 +105,7 @@ class Trainer:
         best_loss, best_test_loss = float('inf'), float('inf')
         not_improved_count = 0

+        self.stats.start_training()
         self.logger.info("Training process started")
         for epoch in range(1, self.args['epochs'] + 1):
             train_epoch_loss = self.train_epoch(epoch)

@@ -129,6 +138,14 @@ class Trainer:
             torch.save(best_test_model, self.best_test_path)
             self.logger.info(f"Best models saved at {self.best_path} and {self.best_test_path}")

+        # Report stats and parameter count
+        self.stats.end_training()
+        self.stats.report(self.logger)
+        try:
+            total_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
+            self.logger.info(f"Trainable params: {total_params}")
+        except Exception:
+            pass
         self._finalize_training(best_model, best_test_model)

     def _finalize_training(self, best_model, best_test_model):
@@ -11,6 +11,8 @@ def select_trainer(model, loss, optimizer, train_loader, val_loader, test_loader
     match args['model']['type']:
         case "STGNCDE": return cdeTrainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args['train'],
                                           lr_scheduler, kwargs[0], None)
+        case "STGNRDE": return cdeTrainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args['train'],
+                                          lr_scheduler, kwargs[0], None)
         case 'DCRNN': return DCRNN_Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args['train'],
                                            lr_scheduler)
         case 'PDG2SEQ': return PDG2SEQ_Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args['train'],