add PDG2Seq

czzhangheng 2025-03-10 19:02:42 +08:00
parent 0a9ac1a025
commit c07bf05324
12 changed files with 728 additions and 1 deletions


@ -0,0 +1,51 @@
data:
  num_nodes: 358
  lag: 12
  horizon: 12
  val_ratio: 0.2
  test_ratio: 0.2
  tod: False
  normalizer: std
  column_wise: False
  default_graph: True
  add_time_in_day: True
  add_day_in_week: True
  steps_per_day: 288
  days_per_week: 7
model:
  cheb_k: 2
  embed_dim: 12
  input_dim: 1
  num_layers: 1
  output_dim: 1
  rnn_units: 64
  use_day: true
  use_week: true
  lr_decay_step: 10000
  lr_decay_step1: 75,90,120
  time_dim: 8
train:
  loss_func: mae
  seed: 10
  batch_size: 64
  epochs: 50
  lr_init: 0.003
  weight_decay: 0
  lr_decay: False
  lr_decay_rate: 0.3
  lr_decay_step: "5,20,40,70"
  early_stop: True
  early_stop_patience: 15
  grad_norm: False
  max_grad_norm: 5
  real_value: True
test:
  mae_thresh: null
  mape_thresh: 0.0
log:
  log_step: 10000
  plot: False
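
For reference, a minimal sketch of how a config like the one above is typically consumed; the loader in this repo may differ, and the file path below is purely illustrative.

import yaml

# Illustrative only: the path and loading logic are assumptions, not the repo's actual code.
with open("PEMS_config.yaml") as f:
    args = yaml.safe_load(f)

print(args["model"]["rnn_units"])  # 64
print(args["train"]["lr_init"])    # 0.003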


@ -0,0 +1,50 @@
data:
  num_nodes: 307
  lag: 12
  horizon: 12
  val_ratio: 0.2
  test_ratio: 0.2
  tod: False
  normalizer: std
  column_wise: False
  default_graph: True
  add_time_in_day: True
  add_day_in_week: True
  steps_per_day: 288
  days_per_week: 7
model:
  cheb_k: 2
  embed_dim: 12
  input_dim: 1
  num_layers: 1
  output_dim: 1
  rnn_units: 64
  use_day: true
  use_week: true
  lr_decay_step: 1500
  lr_decay_step1: 60,75,90,120
  time_dim: 16
train:
  loss_func: mae
  seed: 10
  batch_size: 64
  epochs: 50
  lr_init: 0.003
  weight_decay: 0
  lr_decay: False
  lr_decay_rate: 0.3
  lr_decay_step: "5,20,40,70"
  early_stop: True
  early_stop_patience: 15
  grad_norm: False
  max_grad_norm: 5
  real_value: True
test:
  mae_thresh: null
  mape_thresh: 0.0
log:
  log_step: 200
  plot: False


@ -0,0 +1,51 @@
data:
  add_day_in_week: true
  add_time_in_day: true
  column_wise: false
  days_per_week: 7
  default_graph: true
  horizon: 12
  lag: 12
  normalizer: std
  num_nodes: 883
  steps_per_day: 288
  test_ratio: 0.2
  tod: false
  val_ratio: 0.2
log:
  log_step: 3000
  plot: false
model:
  cheb_k: 2
  embed_dim: 12
  input_dim: 1
  num_layers: 1
  output_dim: 1
  rnn_units: 64
  use_day: true
  use_week: true
  lr_decay_step: 12000
  lr_decay_step1: 80,100,120
  time_dim: 20
test:
  mae_thresh: null
  mape_thresh: 0.0
train:
  batch_size: 16
  early_stop: true
  early_stop_patience: 10
  epochs: 200
  grad_norm: false
  loss_func: mae
  lr_decay: false
  lr_decay_rate: 0.3
  lr_decay_step:
  - '5'
  - '20'
  - '40'
  - '70'
  lr_init: 0.00075
  max_grad_norm: 5
  real_value: true
  seed: 10
  weight_decay: 0


@ -0,0 +1,47 @@
data:
  add_day_in_week: true
  add_time_in_day: true
  column_wise: false
  days_per_week: 7
  default_graph: true
  horizon: 12
  lag: 12
  normalizer: std
  num_nodes: 170
  steps_per_day: 288
  test_ratio: 0.2
  tod: false
  val_ratio: 0.2
log:
  log_step: 2000
  plot: false
model:
  cheb_k: 2
  embed_dim: 12
  input_dim: 1
  num_layers: 1
  output_dim: 1
  rnn_units: 64
  use_day: true
  use_week: true
  lr_decay_step: 2000
  lr_decay_step1: 50,75
  time_dim: 16
test:
  mae_thresh: null
  mape_thresh: 0.001
train:
  batch_size: 64
  early_stop: true
  early_stop_patience: 15
  epochs: 300
  grad_norm: false
  loss_func: mae
  lr_decay: true
  lr_decay_rate: 0.3
  lr_decay_step: "5,20,40,70"
  lr_init: 0.003
  max_grad_norm: 5
  real_value: true
  seed: 12
  weight_decay: 0


@ -98,7 +98,7 @@ class DDGCRN(nn.Module):
        self.end_conv2 = nn.Conv2d(1, self.horizon * self.output_dim, kernel_size=(1, self.hidden_dim), bias=True)
        self.end_conv3 = nn.Conv2d(1, self.horizon * self.output_dim, kernel_size=(1, self.hidden_dim), bias=True)

    def forward(self, source):
    def forward(self, source, **kwargs):
        """
        Forward pass of the DDGCRN model.

model/PDG2SEQ/PDG2Seq.py

@ -0,0 +1,161 @@
import torch
import torch.nn as nn
from model.PDG2SEQ.PDG2SeqCell import PDG2SeqCell
import numpy as np


class PDG2Seq_Encoder(nn.Module):
    def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim, time_dim, num_layers=1):
        super(PDG2Seq_Encoder, self).__init__()
        assert num_layers >= 1, 'At least one DCRNN layer in the Encoder.'
        self.node_num = node_num
        self.input_dim = dim_in
        self.num_layers = num_layers
        self.PDG2Seq_cells = nn.ModuleList()
        self.PDG2Seq_cells.append(PDG2SeqCell(node_num, dim_in, dim_out, cheb_k, embed_dim, time_dim))
        for _ in range(1, num_layers):
            self.PDG2Seq_cells.append(PDG2SeqCell(node_num, dim_out, dim_out, cheb_k, embed_dim, time_dim))

    def forward(self, x, init_state, node_embeddings):
        # shape of x: (B, T, N, D)
        # shape of init_state: (num_layers, B, N, hidden_dim)
        assert x.shape[2] == self.node_num and x.shape[3] == self.input_dim
        seq_length = x.shape[1]  # x = [batch, steps, nodes, input_dim]
        current_inputs = x
        output_hidden = []
        for i in range(self.num_layers):
            state = init_state[i]  # state = [batch, nodes, hidden_dim]
            inner_states = []
            for t in range(seq_length):
                # With two GRU layers, the second layer takes the previous layer's hidden states as input.
                state = self.PDG2Seq_cells[i](current_inputs[:, t, :, :], state,
                                              [node_embeddings[0][:, t, :], node_embeddings[1][:, t, :], node_embeddings[2]])
                # state = self.dcrnn_cells[i](current_inputs[:, t, :, :], state, [node_embeddings[0], node_embeddings[1]])
                inner_states.append(state)  # list of the GRU hidden state at every step
            output_hidden.append(state)  # hidden state of the last GRU cell of each layer
            current_inputs = torch.stack(inner_states, dim=1)
            # Stack the previous layer's hidden states as the input of the next layer: [batch, steps, nodes, hidden_size].
        # current_inputs: the outputs of last layer: (B, T, N, hidden_dim)
        # output_hidden: the last state for each layer: (num_layers, B, N, hidden_dim)
        # last_state: (B, N, hidden_dim)
        return current_inputs, output_hidden

    def init_hidden(self, batch_size):
        init_states = []
        for i in range(self.num_layers):
            init_states.append(self.PDG2Seq_cells[i].init_hidden_state(batch_size))
        return torch.stack(init_states, dim=0)  # (num_layers, B, N, hidden_dim)


class PDG2Seq_Dncoder(nn.Module):
    def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim, time_dim, num_layers=1):
        super(PDG2Seq_Dncoder, self).__init__()
        assert num_layers >= 1, 'At least one DCRNN layer in the Decoder.'
        self.node_num = node_num
        self.input_dim = dim_in
        self.num_layers = num_layers
        self.PDG2Seq_cells = nn.ModuleList()
        self.PDG2Seq_cells.append(PDG2SeqCell(node_num, dim_in, dim_out, cheb_k, embed_dim, time_dim))
        for _ in range(1, num_layers):
            self.PDG2Seq_cells.append(PDG2SeqCell(node_num, dim_in, dim_out, cheb_k, embed_dim, time_dim))

    def forward(self, xt, init_state, node_embeddings):
        # xt: (B, N, D)
        # init_state: (num_layers, B, N, hidden_dim)
        assert xt.shape[1] == self.node_num and xt.shape[2] == self.input_dim
        current_inputs = xt
        output_hidden = []
        for i in range(self.num_layers):
            state = self.PDG2Seq_cells[i](current_inputs, init_state[i],
                                          [node_embeddings[0], node_embeddings[1], node_embeddings[2]])
            output_hidden.append(state)
            current_inputs = state
        return current_inputs, output_hidden


class PDG2Seq(nn.Module):
    def __init__(self, args):
        super(PDG2Seq, self).__init__()
        self.num_node = args['num_nodes']
        self.input_dim = args['input_dim']
        self.hidden_dim = args['rnn_units']
        self.output_dim = args['output_dim']
        self.horizon = args['horizon']
        self.num_layers = args['num_layers']
        self.use_D = args['use_day']
        self.use_W = args['use_week']
        self.cl_decay_steps = args['lr_decay_step']
        self.node_embeddings1 = nn.Parameter(torch.empty(self.num_node, args['embed_dim']))
        self.T_i_D_emb1 = nn.Parameter(torch.empty(288, args['time_dim']))
        self.D_i_W_emb1 = nn.Parameter(torch.empty(7, args['time_dim']))
        self.T_i_D_emb2 = nn.Parameter(torch.empty(288, args['time_dim']))
        self.D_i_W_emb2 = nn.Parameter(torch.empty(7, args['time_dim']))
        self.encoder = PDG2Seq_Encoder(args['num_nodes'], args['input_dim'], args['rnn_units'], args['cheb_k'],
                                       args['embed_dim'], args['time_dim'], args['num_layers'])
        self.decoder = PDG2Seq_Dncoder(args['num_nodes'], args['input_dim'], args['rnn_units'], args['cheb_k'],
                                       args['embed_dim'], args['time_dim'], args['num_layers'])
        self.proj = nn.Sequential(nn.Linear(self.hidden_dim, self.output_dim, bias=True))
        self.end_conv = nn.Conv2d(1, args['horizon'] * self.output_dim, kernel_size=(1, self.hidden_dim), bias=True)

    def forward(self, source, target=None, batches_seen=None):
        # source: B, T_1, N, D
        # target: B, T_2, N, D
        t_i_d_data1 = source[..., 0, -2]
        t_i_d_data2 = target[..., 0, -2]
        # T_i_D_emb = self.T_i_D_emb[(t_i_d_data[:, -1, :] * 288).type(torch.LongTensor)]
        T_i_D_emb1_en = self.T_i_D_emb1[(t_i_d_data1 * 288).type(torch.LongTensor)]
        T_i_D_emb2_en = self.T_i_D_emb2[(t_i_d_data1 * 288).type(torch.LongTensor)]
        T_i_D_emb1_de = self.T_i_D_emb1[(t_i_d_data2 * 288).type(torch.LongTensor)]
        T_i_D_emb2_de = self.T_i_D_emb2[(t_i_d_data2 * 288).type(torch.LongTensor)]
        if self.use_W:
            d_i_w_data1 = source[..., 0, -1]
            d_i_w_data2 = target[..., 0, -1]
            # D_i_W_emb = self.D_i_W_emb[(d_i_w_data[:, -1, :]).type(torch.LongTensor)]
            D_i_W_emb1_en = self.D_i_W_emb1[(d_i_w_data1).type(torch.LongTensor)]
            D_i_W_emb2_en = self.D_i_W_emb2[(d_i_w_data1).type(torch.LongTensor)]
            D_i_W_emb1_de = self.D_i_W_emb1[(d_i_w_data2).type(torch.LongTensor)]
            D_i_W_emb2_de = self.D_i_W_emb2[(d_i_w_data2).type(torch.LongTensor)]
            node_embedding_en1 = torch.mul(T_i_D_emb1_en, D_i_W_emb1_en)
            node_embedding_en2 = torch.mul(T_i_D_emb2_en, D_i_W_emb2_en)
            node_embedding_de1 = torch.mul(T_i_D_emb1_de, D_i_W_emb1_de)
            node_embedding_de2 = torch.mul(T_i_D_emb2_de, D_i_W_emb2_de)
        else:
            node_embedding_en1 = T_i_D_emb1_en
            node_embedding_en2 = T_i_D_emb2_en
            node_embedding_de1 = T_i_D_emb1_de
            node_embedding_de2 = T_i_D_emb2_de
        en_node_embeddings = [node_embedding_en1, node_embedding_en2, self.node_embeddings1]
        source = source[..., 0].unsqueeze(-1)
        init_state = self.encoder.init_hidden(source.shape[0]).to(source.device)  # e.g. [2, 64, 307, 64]; the leading 2 appears when two GRU layers are stacked
        state, _ = self.encoder(source, init_state, en_node_embeddings)  # B, T, N, hidden
        state = state[:, -1:, :, :].squeeze(1)
        ht_list = [state] * self.num_layers
        go = torch.zeros((source.shape[0], self.num_node, self.output_dim), device=source.device)
        out = []
        for t in range(self.horizon):
            state, ht_list = self.decoder(go, ht_list,
                                          [node_embedding_de1[:, t, :], node_embedding_de2[:, t, :], self.node_embeddings1])
            go = self.proj(state)
            out.append(go)
            if self.training:
                # Curriculum learning (a teacher-forcing-style scheme): with some probability the
                # ground truth replaces the prediction as the next decoder input.
                c = np.random.uniform(0, 1)
                if c < self._compute_sampling_threshold(batches_seen):  # if the condition holds, feed the ground truth instead of the prediction
                    go = target[:, t, :, 0].unsqueeze(-1)
        output = torch.stack(out, dim=1)
        return output

    def _compute_sampling_threshold(self, batches_seen):
        x = self.cl_decay_steps / (
            self.cl_decay_steps + np.exp(batches_seen / self.cl_decay_steps))
        return x
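
A small illustration (not part of the repo) of the scheduled-sampling probability computed by _compute_sampling_threshold above: with cl_decay_steps taken from the model's lr_decay_step (e.g. 2000, as in one of the configs above), the chance of feeding the ground truth to the decoder starts near 1 and decays towards 0 as training progresses.

import numpy as np

def sampling_threshold(batches_seen, cl_decay_steps=2000):
    # Same formula as PDG2Seq._compute_sampling_threshold; 2000 mirrors one of the configs above.
    return cl_decay_steps / (cl_decay_steps + np.exp(batches_seen / cl_decay_steps))

for step in (0, 5000, 10000, 15000, 20000):
    print(step, round(sampling_threshold(step), 4))
# prints roughly 1.00, 0.99, 0.93, 0.53, 0.08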


@ -0,0 +1,73 @@
import torch
import torch.nn as nn
from model.PDG2SEQ.PDG2Seq_DGCN import PDG2Seq_GCN
from collections import OrderedDict
import torch.nn.functional as F


class FC(nn.Module):
    def __init__(self, dim_in, dim_out):
        super(FC, self).__init__()
        self.hyperGNN_dim = 16
        self.middle_dim = 2
        # Open question: why three linear layers here, and why sigmoid as the activation?
        self.mlp = nn.Sequential(
            OrderedDict([('fc1', nn.Linear(dim_in, self.hyperGNN_dim)),
                         # ('sigmoid1', nn.ReLU()),
                         ('sigmoid1', nn.Sigmoid()),
                         ('fc2', nn.Linear(self.hyperGNN_dim, self.middle_dim)),
                         # ('sigmoid1', nn.ReLU()),
                         ('sigmoid2', nn.Sigmoid()),
                         ('fc3', nn.Linear(self.middle_dim, dim_out))]))

    def forward(self, x):
        ho = self.mlp(x)
        return ho


class PDG2SeqCell(nn.Module):
    # This module only performs the GRU-internal update; the graph-specific parts live in PDG2Seq_GCN.
    def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim, time_dim):
        super(PDG2SeqCell, self).__init__()
        self.node_num = node_num
        self.hidden_dim = dim_out
        self.gate = PDG2Seq_GCN(dim_in + self.hidden_dim, 2 * dim_out, cheb_k, embed_dim, time_dim)
        self.update = PDG2Seq_GCN(dim_in + self.hidden_dim, dim_out, cheb_k, embed_dim, time_dim)
        self.fc1 = FC(dim_in + self.hidden_dim, time_dim)
        self.fc2 = FC(dim_in + self.hidden_dim, time_dim)

    def forward(self, x, state, node_embeddings):
        # x: B, num_nodes, input_dim
        # state: B, num_nodes, hidden_dim
        state = state.to(x.device)
        input_and_state = torch.cat((x, state), dim=-1)
        filter1 = self.fc1(input_and_state)
        filter2 = self.fc2(input_and_state)
        nodevec1 = torch.tanh(torch.einsum('bd,bnd->bnd', node_embeddings[0], filter1))  # [B, N, time_dim]
        nodevec2 = torch.tanh(torch.einsum('bd,bnd->bnd', node_embeddings[1], filter2))  # [B, N, time_dim]
        adj = torch.matmul(nodevec1, nodevec2.transpose(2, 1)) - torch.matmul(
            nodevec2, nodevec1.transpose(2, 1))
        adj1 = PDG2SeqCell.preprocessing(F.relu(adj))
        adj2 = PDG2SeqCell.preprocessing(F.relu(-adj.transpose(-2, -1)))
        adj = [adj1, adj2]
        z_r = torch.sigmoid(self.gate(input_and_state, adj, node_embeddings[2]))
        z, r = torch.split(z_r, self.hidden_dim, dim=-1)
        candidate = torch.cat((x, z * state), dim=-1)
        hc = torch.tanh(self.update(candidate, adj, node_embeddings[2]))
        h = r * state + (1 - r) * hc
        return h

    def init_hidden_state(self, batch_size):
        return torch.zeros(batch_size, self.node_num, self.hidden_dim)

    @staticmethod
    def preprocessing(adj):
        # The dynamic adjacency matrix may lack diagonal entries, so add self-loops and row-normalize.
        num_nodes = adj.shape[-1]
        adj = adj + torch.eye(num_nodes).to(adj.device)
        x = torch.unsqueeze(adj.sum(-1), -1)
        adj = adj / x  # D = torch.diag_embed(torch.sum(adj, dim=-1) ** (-1)); adj = torch.matmul(D, adj)
        return adj
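
For intuition, a tiny standalone check (not repo code) of what preprocessing does: adding self-loops and dividing by row sums yields a row-stochastic, random-walk style normalization of the dynamic adjacency.

import torch

raw = torch.tensor([[[0.0, 2.0, 0.0],
                     [1.0, 0.0, 1.0],
                     [0.0, 0.0, 0.0]]])   # (B=1, N=3, N=3), e.g. ReLU'ed scores
adj = raw + torch.eye(3)                  # add self-loops
adj = adj / adj.sum(-1, keepdim=True)     # row-normalize
print(adj.sum(-1))                        # tensor([[1., 1., 1.]]): every row sums to 1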


@ -0,0 +1,96 @@
import torch
import torch.nn.functional as F
import torch.nn as nn
import math
import numpy as np
import time
from collections import OrderedDict


class FC(nn.Module):
    def __init__(self, dim_in, dim_out):
        super(FC, self).__init__()
        self.hyperGNN_dim = 16
        self.middle_dim = 2
        # Open question: why three linear layers here, and why sigmoid as the activation?
        self.mlp = nn.Sequential(
            OrderedDict([('fc1', nn.Linear(dim_in, self.hyperGNN_dim)),
                         # ('sigmoid1', nn.ReLU()),
                         ('sigmoid1', nn.Sigmoid()),
                         ('fc2', nn.Linear(self.hyperGNN_dim, self.middle_dim)),
                         # ('sigmoid1', nn.ReLU()),
                         ('sigmoid2', nn.Sigmoid()),
                         ('fc3', nn.Linear(self.middle_dim, dim_out))]))

    def forward(self, x):
        ho = self.mlp(x)
        return ho


class PDG2Seq_GCN(nn.Module):
    def __init__(self, dim_in, dim_out, cheb_k, embed_dim, time_dim):
        super(PDG2Seq_GCN, self).__init__()
        self.cheb_k = cheb_k
        self.weights_pool = nn.Parameter(torch.FloatTensor(embed_dim, cheb_k * 2 + 1, dim_in, dim_out))
        self.weights = nn.Parameter(torch.FloatTensor(cheb_k * 2 + 1, dim_in, dim_out))
        # self.weights_pool = nn.Parameter(torch.FloatTensor(embed_dim, cheb_k, dim_in, dim_out))
        # self.weights = nn.Parameter(torch.FloatTensor(cheb_k, dim_in, dim_out))
        self.bias_pool = nn.Parameter(torch.FloatTensor(embed_dim, dim_out))
        self.bias = nn.Parameter(torch.FloatTensor(dim_out))
        self.hyperGNN_dim = 16
        self.middle_dim = 2
        self.embed_dim = embed_dim
        self.time_dim = time_dim
        self.gcn = gcn(cheb_k)
        self.fc1 = FC(dim_in, time_dim)
        self.fc2 = FC(dim_in, time_dim)

    def forward(self, x, adj, node_embedding):
        # x: [B, N, C]; node_embedding: [N, D]; adj: list of supports shaped [B, N, N]
        # output: [B, N, C]
        x_g = self.gcn(x, adj)  # [B, cheb_k*2+1, N, dim_in]
        weights = torch.einsum('nd,dkio->nkio', node_embedding, self.weights_pool)
        # [N, embed_dim] * [embed_dim, cheb_k*2+1, dim_in, dim_out] = [N, cheb_k*2+1, dim_in, dim_out]
        bias = torch.matmul(node_embedding, self.bias_pool)  # [N, dim_out]
        x_g = x_g.permute(0, 2, 1, 3)  # B, N, cheb_k*2+1, dim_in
        # x_gconv = torch.einsum('bnki,bnkio->bno', x_g, weights) + bias  # B, N, dim_out
        x_gconv = torch.einsum('bnki,nkio->bno', x_g, weights) + bias  # B, N, dim_out
        # x_gconv = torch.einsum('bnki,kio->bno', x_g, self.weights) + self.bias  # [B, N, k, dim_in] * [k, dim_in, dim_out] = [B, N, dim_out]
        return x_gconv


class nconv(nn.Module):
    def __init__(self):
        super(nconv, self).__init__()

    def forward(self, x, A):
        # One graph diffusion step: [B, N, N] x [B, N, C] -> [B, N, C]
        x = torch.einsum("bnm,bmc->bnc", A, x)
        return x.contiguous()


class gcn(nn.Module):
    def __init__(self, k=2):
        super(gcn, self).__init__()
        self.nconv = nconv()
        self.k = k

    def forward(self, x, support):
        out = [x]
        for a in support:
            x1 = self.nconv(x, a)  # first diffusion-convolution step
            out.append(x1)  # collect the result
            for k in range(2, self.k + 1):
                # Repeated diffusion on x1 gives the higher-order terms; with the default k=2,
                # range(2, self.k + 1) covers only order 2, i.e. one extra step per support.
                x2 = self.nconv(x1, a)
                out.append(x2)
                x1 = x2
        h = torch.stack(out, dim=1)
        # h = torch.cat(out, dim=1)  # concatenate the results
        return h
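
To see why the weight tensors in PDG2Seq_GCN are sized cheb_k*2+1: PDG2SeqCell passes two directed supports, and gcn stacks the identity term plus cheb_k diffusion orders per support. A standalone shape check (random tensors, illustrative only):

import torch

B, N, C, cheb_k = 4, 10, 8, 2
x = torch.randn(B, N, C)
supports = [torch.softmax(torch.randn(B, N, N), dim=-1) for _ in range(2)]  # two directed adjacencies

out = [x]
for a in supports:
    x1 = torch.einsum("bnm,bmc->bnc", a, x)
    out.append(x1)
    for _ in range(2, cheb_k + 1):
        x1 = torch.einsum("bnm,bmc->bnc", a, x1)
        out.append(x1)
h = torch.stack(out, dim=1)
print(h.shape)  # torch.Size([4, 5, 10, 8]); 5 == cheb_k * 2 + 1, matching weights_pool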


@ -12,6 +12,7 @@ from model.GWN.GraphWaveNet import gwnet
from model.STFGNN.STFGNN import STFGNN
from model.STSGCN.STSGCN import STSGCN
from model.STGODE.STGODE import ODEGCN
from model.PDG2SEQ.PDG2Seq import PDG2Seq
def model_selector(model):
    match model['type']:
@ -29,4 +30,5 @@ def model_selector(model):
        case 'STFGNN': return STFGNN(model)
        case 'STSGCN': return STSGCN(model)
        case 'STGODE': return ODEGCN(model)
        case 'PDG2SEQ': return PDG2Seq(model)
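
Assuming the framework merges the YAML data and model sections into one dict before this call (the merge itself lives elsewhere in the repo and is not shown here), selecting the new model might look like the sketch below; the key names mirror the configs above.

# Hypothetical invocation; not taken from the repo.
cfg = {
    'type': 'PDG2SEQ',
    'num_nodes': 307, 'input_dim': 1, 'output_dim': 1, 'horizon': 12,
    'rnn_units': 64, 'num_layers': 1, 'cheb_k': 2,
    'embed_dim': 12, 'time_dim': 16,
    'use_day': True, 'use_week': True, 'lr_decay_step': 1500,
}
model = model_selector(cfg)  # returns a PDG2Seq instance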

run.py

@ -1,5 +1,7 @@
import os
import shutil
from torchview import draw_graph

# Check dataset integrity
from lib.Download_data import check_and_download_data
@ -34,6 +36,19 @@ def main():
    # Initialize model
    model = init_model(args['model'], device=args['device'])

    if args['mode'] == "draw":
        dummy_input = torch.randn(64, 12, 307, 3)
        model_graph = draw_graph(model,
                                 input_data=dummy_input,
                                 device=args['device'],
                                 show_shapes=True,
                                 save_graph=True,
                                 graph_name=f"{args['model']['type']}_graph",
                                 directory="./",
                                 format="png"
                                 )
        return 0

    # Load dataset
    train_loader, val_loader, test_loader, scaler, *extra_data = get_dataloader(
        args,

trainer/PDG2SEQ_Trainer.py

@ -0,0 +1,178 @@
import math
import os
import time
import copy

from tqdm import tqdm
import torch

from lib.logger import get_logger
from lib.loss_function import all_metrics


class Trainer:
    def __init__(self, model, loss, optimizer, train_loader, val_loader, test_loader,
                 scaler, args, lr_scheduler=None):
        self.model = model
        self.loss = loss
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.scaler = scaler
        self.args = args
        self.lr_scheduler = lr_scheduler
        self.train_per_epoch = len(train_loader)
        self.val_per_epoch = len(val_loader) if val_loader else 0
        self.batches_seen = 0

        # Paths for saving models and logs
        self.best_path = os.path.join(args['log_dir'], 'best_model.pth')
        self.best_test_path = os.path.join(args['log_dir'], 'best_test_model.pth')
        self.loss_figure_path = os.path.join(args['log_dir'], 'loss.png')

        # Initialize logger
        if not os.path.isdir(args['log_dir']) and not args['debug']:
            os.makedirs(args['log_dir'], exist_ok=True)
        self.logger = get_logger(args['log_dir'], name=self.model.__class__.__name__, debug=args['debug'])
        self.logger.info(f"Experiment log path in: {args['log_dir']}")

    def _run_epoch(self, epoch, dataloader, mode):
        if mode == 'train':
            self.model.train()
            optimizer_step = True
        else:
            self.model.eval()
            optimizer_step = False

        total_loss = 0
        epoch_time = time.time()
        with torch.set_grad_enabled(optimizer_step):
            with tqdm(total=len(dataloader), desc=f'{mode.capitalize()} Epoch {epoch}') as pbar:
                for batch_idx, (data, target) in enumerate(dataloader):
                    self.batches_seen += 1
                    label = target[..., :self.args['output_dim']].clone()
                    output = self.model(data, target, self.batches_seen).to(self.args['device'])
                    if self.args['real_value']:
                        output = self.scaler.inverse_transform(output)
                    loss = self.loss(output, label)

                    if optimizer_step and self.optimizer is not None:
                        self.optimizer.zero_grad()
                        loss.backward()
                        if self.args['grad_norm']:
                            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm'])
                        self.optimizer.step()

                    total_loss += loss.item()
                    if mode == 'train' and (batch_idx + 1) % self.args['log_step'] == 0:
                        self.logger.info(
                            f'Train Epoch {epoch}: {batch_idx + 1}/{len(dataloader)} Loss: {loss.item():.6f}')

                    # Update the tqdm progress bar
                    pbar.update(1)
                    pbar.set_postfix(loss=loss.item())

        avg_loss = total_loss / len(dataloader)
        self.logger.info(
            f'{mode.capitalize()} Epoch {epoch}: average Loss: {avg_loss:.6f}, time: {time.time() - epoch_time:.2f} s')
        return avg_loss

    def train_epoch(self, epoch):
        return self._run_epoch(epoch, self.train_loader, 'train')

    def val_epoch(self, epoch):
        return self._run_epoch(epoch, self.val_loader or self.test_loader, 'val')

    def test_epoch(self, epoch):
        return self._run_epoch(epoch, self.test_loader, 'test')

    def train(self):
        best_model, best_test_model = None, None
        best_loss, best_test_loss = float('inf'), float('inf')
        not_improved_count = 0

        self.logger.info("Training process started")
        for epoch in range(1, self.args['epochs'] + 1):
            train_epoch_loss = self.train_epoch(epoch)
            val_epoch_loss = self.val_epoch(epoch)
            test_epoch_loss = self.test_epoch(epoch)

            if train_epoch_loss > 1e6:
                self.logger.warning('Gradient explosion detected. Ending...')
                break

            if val_epoch_loss < best_loss:
                best_loss = val_epoch_loss
                not_improved_count = 0
                best_model = copy.deepcopy(self.model.state_dict())
                self.logger.info('Best validation model saved!')
            else:
                not_improved_count += 1

            if self.args['early_stop'] and not_improved_count == self.args['early_stop_patience']:
                self.logger.info(
                    f"Validation performance didn't improve for {self.args['early_stop_patience']} epochs. Training stops.")
                break

            if test_epoch_loss < best_test_loss:
                best_test_loss = test_epoch_loss
                best_test_model = copy.deepcopy(self.model.state_dict())

        if not self.args['debug']:
            torch.save(best_model, self.best_path)
            torch.save(best_test_model, self.best_test_path)
            self.logger.info(f"Best models saved at {self.best_path} and {self.best_test_path}")

        self._finalize_training(best_model, best_test_model)

    def _finalize_training(self, best_model, best_test_model):
        self.model.load_state_dict(best_model)
        self.logger.info("Testing on best validation model")
        self.test(self.model, self.args, self.test_loader, self.scaler, self.logger)

        self.model.load_state_dict(best_test_model)
        self.logger.info("Testing on best test model")
        self.test(self.model, self.args, self.test_loader, self.scaler, self.logger)

    @staticmethod
    def test(model, args, data_loader, scaler, logger, path=None):
        if path:
            checkpoint = torch.load(path)
            model.load_state_dict(checkpoint['state_dict'])
            model.to(args['device'])
        model.eval()
        y_pred, y_true = [], []
        with torch.no_grad():
            for data, target in data_loader:
                label = target[..., :args['output_dim']].clone()
                output = model(data, target)
                y_pred.append(output)
                y_true.append(label)
        if args['real_value']:
            y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0))
        else:
            y_pred = torch.cat(y_pred, dim=0)
        y_true = torch.cat(y_true, dim=0)

        # Save y_pred and y_true here if needed:
        # torch.save(y_pred, "./test/PEMS07/y_pred_D.pt")  # [3566, 12, 170, 1]
        # torch.save(y_true, "./test/PEMS08/y_true.pt")  # [3566, 12, 170, 1]

        for t in range(y_true.shape[1]):
            mae, rmse, mape = all_metrics(y_pred[:, t, ...], y_true[:, t, ...],
                                          args['mae_thresh'], args['mape_thresh'])
            logger.info(f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}")
        mae, rmse, mape = all_metrics(y_pred, y_true, args['mae_thresh'], args['mape_thresh'])
        logger.info(f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}")

    @staticmethod
    def _compute_sampling_threshold(global_step, k):
        return k / (k + math.exp(global_step / k))


@ -1,6 +1,7 @@
from trainer.Trainer import Trainer
from trainer.cdeTrainer.cdetrainer import Trainer as cdeTrainer
from trainer.DCRNN_Trainer import Trainer as DCRNN_Trainer
from trainer.PDG2SEQ_Trainer import Trainer as PDG2SEQ_Trainer
def select_trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args,
@ -10,5 +11,7 @@ def select_trainer(model, loss, optimizer, train_loader, val_loader, test_loader
                                        lr_scheduler, kwargs[0], None)
        case 'DCRNN': return DCRNN_Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args['train'],
                                           lr_scheduler)
        case 'PDG2SEQ': return PDG2SEQ_Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args['train'],
                                               lr_scheduler)
        case _: return Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args['train'],
                               lr_scheduler)