From c07bf05324841fccc5bca1d98423e9bbf728de9b Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Mon, 10 Mar 2025 19:02:42 +0800 Subject: [PATCH] add PDF2SeQ --- config/PDG2SEQ/PEMSD3.yaml | 51 ++++++++++ config/PDG2SEQ/PEMSD4.yaml | 50 ++++++++++ config/PDG2SEQ/PEMSD7.yaml | 51 ++++++++++ config/PDG2SEQ/PEMSD8.yaml | 47 +++++++++ model/DDGCRN/DDGCRN.py | 2 +- model/PDG2SEQ/PDG2Seq.py | 161 ++++++++++++++++++++++++++++++ model/PDG2SEQ/PDG2SeqCell.py | 73 ++++++++++++++ model/PDG2SEQ/PDG2Seq_DGCN.py | 96 ++++++++++++++++++ model/model_selector.py | 2 + run.py | 15 +++ trainer/PDG2SEQ_Trainer.py | 178 ++++++++++++++++++++++++++++++++++ trainer/trainer_selector.py | 3 + 12 files changed, 728 insertions(+), 1 deletion(-) create mode 100644 config/PDG2SEQ/PEMSD3.yaml create mode 100644 config/PDG2SEQ/PEMSD4.yaml create mode 100644 config/PDG2SEQ/PEMSD7.yaml create mode 100644 config/PDG2SEQ/PEMSD8.yaml create mode 100644 model/PDG2SEQ/PDG2Seq.py create mode 100644 model/PDG2SEQ/PDG2SeqCell.py create mode 100644 model/PDG2SEQ/PDG2Seq_DGCN.py create mode 100644 trainer/PDG2SEQ_Trainer.py diff --git a/config/PDG2SEQ/PEMSD3.yaml b/config/PDG2SEQ/PEMSD3.yaml new file mode 100644 index 0000000..438ba5c --- /dev/null +++ b/config/PDG2SEQ/PEMSD3.yaml @@ -0,0 +1,51 @@ +data: + num_nodes: 358 + lag: 12 + horizon: 12 + val_ratio: 0.2 + test_ratio: 0.2 + tod: False + normalizer: std + column_wise: False + default_graph: True + add_time_in_day: True + add_day_in_week: True + steps_per_day: 288 + days_per_week: 7 + +model: + cheb_k: 2 + embed_dim: 12 + input_dim: 1 + num_layers: 1 + output_dim: 1 + rnn_units: 64 + use_day: true + use_week: true + lr_decay_step: 10000 + lr_decay_step1: 75,90,120 + time_dim: 8 + +train: + loss_func: mae + seed: 10 + batch_size: 64 + epochs: 50 + lr_init: 0.003 + weight_decay: 0 + lr_decay: False + lr_decay_rate: 0.3 + lr_decay_step: "5,20,40,70" + early_stop: True + early_stop_patience: 15 + grad_norm: False + max_grad_norm: 5 + real_value: True + +test: + mae_thresh: null + mape_thresh: 0.0 + +log: + log_step: 10000 + plot: False diff --git a/config/PDG2SEQ/PEMSD4.yaml b/config/PDG2SEQ/PEMSD4.yaml new file mode 100644 index 0000000..6ccd606 --- /dev/null +++ b/config/PDG2SEQ/PEMSD4.yaml @@ -0,0 +1,50 @@ +data: + num_nodes: 307 + lag: 12 + horizon: 12 + val_ratio: 0.2 + test_ratio: 0.2 + tod: False + normalizer: std + column_wise: False + default_graph: True + add_time_in_day: True + add_day_in_week: True + steps_per_day: 288 + days_per_week: 7 + +model: + cheb_k: 2 + embed_dim: 12 + input_dim: 1 + num_layers: 1 + output_dim: 1 + rnn_units: 64 + use_day: true + use_week: true + lr_decay_step: 1500 + lr_decay_step1: 60,75,90,120 + time_dim: 16 +train: + loss_func: mae + seed: 10 + batch_size: 64 + epochs: 50 + lr_init: 0.003 + weight_decay: 0 + lr_decay: False + lr_decay_rate: 0.3 + lr_decay_step: "5,20,40,70" + early_stop: True + early_stop_patience: 15 + grad_norm: False + max_grad_norm: 5 + real_value: True + +test: + mae_thresh: null + mape_thresh: 0.0 + +log: + log_step: 200 + plot: False diff --git a/config/PDG2SEQ/PEMSD7.yaml b/config/PDG2SEQ/PEMSD7.yaml new file mode 100644 index 0000000..3592e80 --- /dev/null +++ b/config/PDG2SEQ/PEMSD7.yaml @@ -0,0 +1,51 @@ +data: + add_day_in_week: true + add_time_in_day: true + column_wise: false + days_per_week: 7 + default_graph: true + horizon: 12 + lag: 12 + normalizer: std + num_nodes: 883 + steps_per_day: 288 + test_ratio: 0.2 + tod: false + val_ratio: 0.2 +log: + log_step: 3000 + plot: false +model: + cheb_k: 2 + embed_dim: 12 + input_dim: 1 + num_layers: 1 + output_dim: 1 + rnn_units: 64 + use_day: true + use_week: true + lr_decay_step: 12000 + lr_decay_step1: 80,100,120 + time_dim: 20 +test: + mae_thresh: None + mape_thresh: 0.0 +train: + batch_size: 16 + early_stop: true + early_stop_patience: 10 + epochs: 200 + grad_norm: false + loss_func: mae + lr_decay: false + lr_decay_rate: 0.3 + lr_decay_step: + - '5' + - '20' + - '40' + - '70' + lr_init: 0.00075 + max_grad_norm: 5 + real_value: true + seed: 10 + weight_decay: 0 diff --git a/config/PDG2SEQ/PEMSD8.yaml b/config/PDG2SEQ/PEMSD8.yaml new file mode 100644 index 0000000..e4c4b95 --- /dev/null +++ b/config/PDG2SEQ/PEMSD8.yaml @@ -0,0 +1,47 @@ +data: + add_day_in_week: true + add_time_in_day: true + column_wise: false + days_per_week: 7 + default_graph: true + horizon: 12 + lag: 12 + normalizer: std + num_nodes: 170 + steps_per_day: 288 + test_ratio: 0.2 + tod: false + val_ratio: 0.2 +log: + log_step: 2000 + plot: false +model: + cheb_k: 2 + embed_dim: 12 + input_dim: 1 + num_layers: 1 + output_dim: 1 + rnn_units: 64 + use_day: true + use_week: true + lr_decay_step: 2000 + lr_decay_step1: 50,75 + time_dim: 16 +test: + mae_thresh: None + mape_thresh: 0.001 +train: + batch_size: 64 + early_stop: true + early_stop_patience: 15 + epochs: 300 + grad_norm: false + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: "5,20,40,70" + lr_init: 0.003 + max_grad_norm: 5 + real_value: true + seed: 12 + weight_decay: 0 diff --git a/model/DDGCRN/DDGCRN.py b/model/DDGCRN/DDGCRN.py index 35274b0..33baf08 100644 --- a/model/DDGCRN/DDGCRN.py +++ b/model/DDGCRN/DDGCRN.py @@ -98,7 +98,7 @@ class DDGCRN(nn.Module): self.end_conv2 = nn.Conv2d(1, self.horizon * self.output_dim, kernel_size=(1, self.hidden_dim), bias=True) self.end_conv3 = nn.Conv2d(1, self.horizon * self.output_dim, kernel_size=(1, self.hidden_dim), bias=True) - def forward(self, source): + def forward(self, source, **kwargs): """ Forward pass of the DDGCRN model. diff --git a/model/PDG2SEQ/PDG2Seq.py b/model/PDG2SEQ/PDG2Seq.py new file mode 100644 index 0000000..260d2ce --- /dev/null +++ b/model/PDG2SEQ/PDG2Seq.py @@ -0,0 +1,161 @@ +import torch +import torch.nn as nn +from model.PDG2SEQ.PDG2SeqCell import PDG2SeqCell +import numpy as np +class PDG2Seq_Encoder(nn.Module): + def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim, time_dim, num_layers=1): + super(PDG2Seq_Encoder, self).__init__() + assert num_layers >= 1, 'At least one DCRNN layer in the Encoder.' + self.node_num = node_num + self.input_dim = dim_in + self.num_layers = num_layers + self.PDG2Seq_cells = nn.ModuleList() + self.PDG2Seq_cells.append(PDG2SeqCell(node_num, dim_in, dim_out, cheb_k, embed_dim, time_dim)) + for _ in range(1, num_layers): + self.PDG2Seq_cells.append(PDG2SeqCell(node_num, dim_out, dim_out, cheb_k, embed_dim, time_dim)) + + def forward(self, x, init_state, node_embeddings): + #shape of x: (B, T, N, D) + #shape of init_state: (num_layers, B, N, hidden_dim) + assert x.shape[2] == self.node_num and x.shape[3] == self.input_dim + seq_length = x.shape[1] #x=[batch,steps,nodes,input_dim] + current_inputs = x + output_hidden = [] + for i in range(self.num_layers): + state = init_state[i] #state=[batch,steps,nodes,input_dim] + inner_states = [] + for t in range(seq_length): #如果有两层GRU,则第二层的GGRU的输入是前一层的隐藏状态 + state = self.PDG2Seq_cells[i](current_inputs[:, t, :, :], state, [node_embeddings[0][:, t, :], node_embeddings[1][:, t, :], node_embeddings[2]])#state=[batch,steps,nodes,input_dim] + # state = self.dcrnn_cells[i](current_inputs[:, t, :, :], state,[node_embeddings[0], node_embeddings[1]]) + inner_states.append(state) #一个list,里面是每一步的GRU的hidden状态 + output_hidden.append(state) #每层最后一个GRU单元的hidden状态 + current_inputs = torch.stack(inner_states, dim=1) + #拼接成完整的上一层GRU的hidden状态,作为下一层GRRU的输入[batch,steps,nodes,hiddensize] + #current_inputs: the outputs of last layer: (B, T, N, hidden_dim) + #output_hidden: the last state for each layer: (num_layers, B, N, hidden_dim) + #last_state: (B, N, hidden_dim) + return current_inputs, output_hidden + + def init_hidden(self, batch_size): + init_states = [] + for i in range(self.num_layers): + init_states.append(self.PDG2Seq_cells[i].init_hidden_state(batch_size)) + return torch.stack(init_states, dim=0) #(num_layers, B, N, hidden_dim) + + +class PDG2Seq_Dncoder(nn.Module): + def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim, time_dim, num_layers=1): + super(PDG2Seq_Dncoder, self).__init__() + assert num_layers >= 1, 'At least one DCRNN layer in the Decoder.' + self.node_num = node_num + self.input_dim = dim_in + self.num_layers = num_layers + self.PDG2Seq_cells = nn.ModuleList() + self.PDG2Seq_cells.append(PDG2SeqCell(node_num, dim_in, dim_out, cheb_k, embed_dim, time_dim)) + for _ in range(1, num_layers): + self.PDG2Seq_cells.append(PDG2SeqCell(node_num, dim_in, dim_out, cheb_k, embed_dim, time_dim)) + + def forward(self, xt, init_state, node_embeddings): + # xt: (B, N, D) + # init_state: (num_layers, B, N, hidden_dim) + assert xt.shape[1] == self.node_num and xt.shape[2] == self.input_dim + current_inputs = xt + output_hidden = [] + for i in range(self.num_layers): + state = self.PDG2Seq_cells[i](current_inputs, init_state[i], [node_embeddings[0], node_embeddings[1], node_embeddings[2]]) + output_hidden.append(state) + current_inputs = state + return current_inputs, output_hidden + + +class PDG2Seq(nn.Module): + def __init__(self, args): + super(PDG2Seq, self).__init__() + self.num_node = args['num_nodes'] + self.input_dim = args['input_dim'] + self.hidden_dim = args['rnn_units'] + self.output_dim = args['output_dim'] + self.horizon = args['horizon'] + self.num_layers = args['num_layers'] + self.use_D = args['use_day'] + self.use_W = args['use_week'] + self.cl_decay_steps = args['lr_decay_step'] + self.node_embeddings1 = nn.Parameter(torch.empty(self.num_node, args['embed_dim'])) + self.T_i_D_emb1 = nn.Parameter(torch.empty(288, args['time_dim'])) + self.D_i_W_emb1 = nn.Parameter(torch.empty(7, args['time_dim'])) + self.T_i_D_emb2 = nn.Parameter(torch.empty(288, args['time_dim'])) + self.D_i_W_emb2 = nn.Parameter(torch.empty(7, args['time_dim'])) + + self.encoder = PDG2Seq_Encoder(args['num_nodes'], args['input_dim'], args['rnn_units'], args['cheb_k'], + args['embed_dim'], args['time_dim'], args['num_layers']) + self.decoder = PDG2Seq_Dncoder(args['num_nodes'], args['input_dim'], args['rnn_units'], args['cheb_k'], + args['embed_dim'], args['time_dim'], args['num_layers']) + self.proj = nn.Sequential(nn.Linear(self.hidden_dim, self.output_dim, bias=True)) + self.end_conv = nn.Conv2d(1, args['horizon'] * self.output_dim, kernel_size=(1, self.hidden_dim), bias=True) + + def forward(self, source, traget=None, batches_seen=None): + #source: B, T_1, N, D + #target: B, T_2, N, D + + t_i_d_data1 = source[..., 0,-2] + t_i_d_data2 = traget[..., 0,-2] + # T_i_D_emb = self.T_i_D_emb[(t_i_d_data[:, -1, :] * 288).type(torch.LongTensor)] + T_i_D_emb1_en = self.T_i_D_emb1[(t_i_d_data1 * 288).type(torch.LongTensor)] + T_i_D_emb2_en = self.T_i_D_emb2[(t_i_d_data1 * 288).type(torch.LongTensor)] + + T_i_D_emb1_de = self.T_i_D_emb1[(t_i_d_data2 * 288).type(torch.LongTensor)] + T_i_D_emb2_de = self.T_i_D_emb2[(t_i_d_data2 * 288).type(torch.LongTensor)] + if self.use_W: + d_i_w_data1 = source[..., 0,-1] + d_i_w_data2 = traget[..., 0,-1] + # D_i_W_emb = self.D_i_W_emb[(d_i_w_data[:, -1, :]).type(torch.LongTensor)] + D_i_W_emb1_en = self.D_i_W_emb1[(d_i_w_data1).type(torch.LongTensor)] + D_i_W_emb2_en = self.D_i_W_emb2[(d_i_w_data1).type(torch.LongTensor)] + + D_i_W_emb1_de = self.D_i_W_emb1[(d_i_w_data2).type(torch.LongTensor)] + D_i_W_emb2_de = self.D_i_W_emb2[(d_i_w_data2).type(torch.LongTensor)] + + node_embedding_en1 = torch.mul(T_i_D_emb1_en, D_i_W_emb1_en) + node_embedding_en2 = torch.mul(T_i_D_emb2_en, D_i_W_emb2_en) + + node_embedding_de1 = torch.mul(T_i_D_emb1_de, D_i_W_emb1_de) + node_embedding_de2 = torch.mul(T_i_D_emb2_de, D_i_W_emb2_de) + else: + node_embedding_en1 = T_i_D_emb1_en + node_embedding_en2 = T_i_D_emb2_en + + node_embedding_de1 = T_i_D_emb1_de + node_embedding_de2 = T_i_D_emb2_de + + + en_node_embeddings=[node_embedding_en1, node_embedding_en2, self.node_embeddings1] + + source = source[..., 0].unsqueeze(-1) + + init_state = self.encoder.init_hidden(source.shape[0]).to(source.device) # [2,64,307,64] 前面是2是因为有两层GRU + state, _ = self.encoder(source, init_state, en_node_embeddings) # B, T, N, hidden + state = state[:, -1:, :, :].squeeze(1) + + ht_list = [state] * self.num_layers + + go = torch.zeros((source.shape[0], self.num_node, self.output_dim), device=source.device) + out = [] + for t in range(self.horizon): + state, ht_list = self.decoder(go, ht_list, [node_embedding_de1[:, t, :], node_embedding_de2[:, t, :], self.node_embeddings1]) + go = self.proj(state) + out.append(go) + if self.training: #这里的课程学习用了给予一定概率用真实值代替预测值来学习的教师-学生学习法(名字忘了,大概跟着有关) + c = np.random.uniform(0, 1) + if c < self._compute_sampling_threshold(batches_seen): #如果满足条件,则用真实值代替预测值训练 + go = traget[:, t, :, 0].unsqueeze(-1) + output = torch.stack(out, dim=1) + + + return output + + def _compute_sampling_threshold(self, batches_seen): + x = self.cl_decay_steps / ( + self.cl_decay_steps + np.exp(batches_seen / self.cl_decay_steps)) + return x + + diff --git a/model/PDG2SEQ/PDG2SeqCell.py b/model/PDG2SEQ/PDG2SeqCell.py new file mode 100644 index 0000000..ba8f51e --- /dev/null +++ b/model/PDG2SEQ/PDG2SeqCell.py @@ -0,0 +1,73 @@ +import torch +import torch.nn as nn +from model.PDG2SEQ.PDG2Seq_DGCN import PDG2Seq_GCN +from collections import OrderedDict +import torch.nn.functional as F +class FC(nn.Module): + def __init__(self, dim_in, dim_out): + super(FC, self).__init__() + self.hyperGNN_dim = 16 + self.middle_dim = 2 + self.mlp=nn.Sequential( #疑问,这里为什么要用三层linear来做,为什么激活函数是sigmoid + OrderedDict([('fc1', nn.Linear(dim_in, self.hyperGNN_dim)), + #('sigmoid1', nn.ReLU()), + ('sigmoid1', nn.Sigmoid()), + ('fc2', nn.Linear(self.hyperGNN_dim, self.middle_dim)), + #('sigmoid1', nn.ReLU()), + ('sigmoid2', nn.Sigmoid()), + ('fc3', nn.Linear(self.middle_dim, dim_out))])) + + def forward(self, x): + + ho = self.mlp(x) + + return ho + +class PDG2SeqCell(nn.Module): #这个模块只进行GRU内部的更新,所以需要修改的是AGCN里面的东西 + def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim, time_dim): + super(PDG2SeqCell, self).__init__() + self.node_num = node_num + self.hidden_dim = dim_out + self.gate = PDG2Seq_GCN(dim_in + self.hidden_dim, 2 * dim_out, cheb_k, embed_dim, time_dim) + self.update = PDG2Seq_GCN(dim_in + self.hidden_dim, dim_out, cheb_k, embed_dim, time_dim) + self.fc1 = FC(dim_in + self.hidden_dim, time_dim) + self.fc2 = FC(dim_in + self.hidden_dim, time_dim) + + def forward(self, x, state, node_embeddings): + #x: B, num_nodes, input_dim + #state: B, num_nodes, hidden_dim + state = state.to(x.device) + input_and_state = torch.cat((x, state), dim=-1) + filter1 = self.fc1(input_and_state) + filter2 = self.fc2(input_and_state) + + nodevec1 = torch.tanh(torch.einsum('bd,bnd->bnd', node_embeddings[0], filter1)) #[B,N,dim_in] + nodevec2 = torch.tanh(torch.einsum('bd,bnd->bnd', node_embeddings[1], filter2)) # [B,N,dim_in] + + + adj = torch.matmul(nodevec1, nodevec2.transpose(2, 1)) - torch.matmul( + nodevec2, nodevec1.transpose(2, 1)) + + adj1 = PDG2SeqCell.preprocessing(F.relu(adj)) + adj2 = PDG2SeqCell.preprocessing(F.relu(-adj.transpose(-2, -1))) + + + adj = [adj1, adj2] + + + z_r = torch.sigmoid(self.gate(input_and_state, adj, node_embeddings[2])) + z, r = torch.split(z_r, self.hidden_dim, dim=-1) + candidate = torch.cat((x, z*state), dim=-1) + hc = torch.tanh(self.update(candidate, adj, node_embeddings[2])) + h = r*state + (1-r)*hc + return h + + def init_hidden_state(self, batch_size): + return torch.zeros(batch_size, self.node_num, self.hidden_dim) + + def preprocessing(adj): #处理动态矩阵可能不含有对角线元素的问题 + num_nodes= adj.shape[-1] + adj = adj + torch.eye(num_nodes).to(adj.device) + x= torch.unsqueeze(adj.sum(-1), -1) + adj = adj / x # D = torch.diag_embed(torch.sum(adj, dim=-1) ** (-1)) adj =torch.matmul(D, adj) + return adj \ No newline at end of file diff --git a/model/PDG2SEQ/PDG2Seq_DGCN.py b/model/PDG2SEQ/PDG2Seq_DGCN.py new file mode 100644 index 0000000..3059db9 --- /dev/null +++ b/model/PDG2SEQ/PDG2Seq_DGCN.py @@ -0,0 +1,96 @@ +import torch +import torch.nn.functional as F +import torch.nn as nn +import math +import numpy as np +import time +from collections import OrderedDict + +class FC(nn.Module): + def __init__(self, dim_in, dim_out): + super(FC, self).__init__() + self.hyperGNN_dim = 16 + self.middle_dim = 2 + self.mlp=nn.Sequential( #疑问,这里为什么要用三层linear来做,为什么激活函数是sigmoid + OrderedDict([('fc1', nn.Linear(dim_in, self.hyperGNN_dim)), + #('sigmoid1', nn.ReLU()), + ('sigmoid1', nn.Sigmoid()), + ('fc2', nn.Linear(self.hyperGNN_dim, self.middle_dim)), + #('sigmoid1', nn.ReLU()), + ('sigmoid2', nn.Sigmoid()), + ('fc3', nn.Linear(self.middle_dim, dim_out))])) + + def forward(self, x): + + ho = self.mlp(x) + + return ho + + + + +class PDG2Seq_GCN(nn.Module): + def __init__(self, dim_in, dim_out, cheb_k, embed_dim, time_dim): + super(PDG2Seq_GCN, self).__init__() + self.cheb_k = cheb_k + self.weights_pool = nn.Parameter(torch.FloatTensor(embed_dim, cheb_k*2+1, dim_in, dim_out)) + self.weights = nn.Parameter(torch.FloatTensor(cheb_k*2+1,dim_in, dim_out)) + # self.weights_pool = nn.Parameter(torch.FloatTensor(embed_dim, cheb_k, dim_in, dim_out)) + # self.weights = nn.Parameter(torch.FloatTensor(cheb_k,dim_in, dim_out)) + self.bias_pool = nn.Parameter(torch.FloatTensor(embed_dim, dim_out)) + self.bias = nn.Parameter(torch.FloatTensor(dim_out)) + self.hyperGNN_dim = 16 + self.middle_dim = 2 + self.embed_dim = embed_dim + self.time_dim = time_dim + self.gcn = gcn(cheb_k) + self.fc1 = FC(dim_in, time_dim) + self.fc2 = FC(dim_in, time_dim) + + def forward(self, x, adj, node_embedding): + #x shaped[B, N, C], node_embeddings shaped [N, D] -> supports shaped [N, N] + #output shape [B, N, C] + + + x_g = self.gcn(x, adj) + + weights = torch.einsum('nd,dkio->nkio', node_embedding, self.weights_pool) #[B,N,embed_dim]*[embed_dim,chen_k,dim_in,dim_out] =[B,N,cheb_k,dim_in,dim_out] + #[N, cheb_k, dim_in, dim_out]=[nodes,cheb_k,hidden_size,output_dim] + bias = torch.matmul(node_embedding, self.bias_pool) #N, dim_out #[che_k,nodes,nodes]* [batch,nodes,dim_in]=[B, cheb_k, N, dim_in] + + x_g = x_g.permute(0, 2, 1, 3) # B, N, cheb_k, dim_in + # x_gconv = torch.einsum('bnki,bnkio->bno', x_g, weights) + bias #b, N, dim_out + x_gconv = torch.einsum('bnki,nkio->bno', x_g, weights) + bias #b, N, dim_out + # x_gconv = torch.einsum('bnki,kio->bno', x_g, self.weights) + self.bias #[B,N,cheb_k,dim_in] *[N,cheb_k,dim_in,dim_out] =[B,N,dim_out] + + return x_gconv + + +class nconv(nn.Module): + def __init__(self): + super(nconv,self).__init__() + + def forward(self, x, A): + # x = torch.einsum("bnm,bmc->bnc", A, x)#[batch_size, D, num_nodes, num_steps] [N,N] [batch_size, num_steps, num_nodes, D] + x = torch.einsum("bnm,bmc->bnc", A,x) # [batch_size, D, num_nodes, num_steps] [N,N] [batch_size, num_steps, num_nodes, D] + return x.contiguous() + +class gcn(nn.Module): + def __init__(self,k=2): + super(gcn,self).__init__() + self.nconv = nconv() + self.k = k + + def forward(self,x,support): + out = [x] + for a in support: + x1 = self.nconv(x,a) #先做一次图扩散卷积 + out.append(x1) #放入输出列表中 + for k in range(2, self.k + 1): #在对经过卷积的X1进行多级运算,得到一系列扩散卷积结果,都存入out中 + x2 = self.nconv(x1,a) #这里的order应该就是进行多少次扩散卷积运算,默认是2,那么range(2, self.order + 1)就是(2,3)也就是算两次就结束了 + out.append(x2) + x1 = x2 + h = torch.stack(out, dim=1) + #h = torch.cat(out,dim=1) #拼接结果 + return h + diff --git a/model/model_selector.py b/model/model_selector.py index dab0b0b..52e6fcb 100644 --- a/model/model_selector.py +++ b/model/model_selector.py @@ -12,6 +12,7 @@ from model.GWN.GraphWaveNet import gwnet from model.STFGNN.STFGNN import STFGNN from model.STSGCN.STSGCN import STSGCN from model.STGODE.STGODE import ODEGCN +from model.PDG2SEQ.PDG2Seq import PDG2Seq def model_selector(model): match model['type']: @@ -29,4 +30,5 @@ def model_selector(model): case 'STFGNN': return STFGNN(model) case 'STSGCN': return STSGCN(model) case 'STGODE': return ODEGCN(model) + case 'PDG2SEQ': return PDG2Seq(model) diff --git a/run.py b/run.py index b6bd82c..ffb66f2 100644 --- a/run.py +++ b/run.py @@ -1,5 +1,7 @@ import os import shutil +from torchview import draw_graph + # 检查数据集完整性 from lib.Download_data import check_and_download_data @@ -34,6 +36,19 @@ def main(): # Initialize model model = init_model(args['model'], device=args['device']) + if args['mode'] == "draw": + dummy_input = torch.randn(64,12,307,3) + model_graph = draw_graph(model, + input_data = dummy_input, + device=args['device'], + show_shapes=True, + save_graph=True, + graph_name=f"{args['model']['type']}_graph", + directory="./", + format="png" + ) + return 0 + # Load dataset train_loader, val_loader, test_loader, scaler, *extra_data = get_dataloader( args, diff --git a/trainer/PDG2SEQ_Trainer.py b/trainer/PDG2SEQ_Trainer.py new file mode 100644 index 0000000..00750a1 --- /dev/null +++ b/trainer/PDG2SEQ_Trainer.py @@ -0,0 +1,178 @@ +import math +import os +import time +import copy +from tqdm import tqdm + +import torch +from lib.logger import get_logger +from lib.loss_function import all_metrics + + +class Trainer: + def __init__(self, model, loss, optimizer, train_loader, val_loader, test_loader, + scaler, args, lr_scheduler=None): + self.model = model + self.loss = loss + self.optimizer = optimizer + self.train_loader = train_loader + self.val_loader = val_loader + self.test_loader = test_loader + self.scaler = scaler + self.args = args + self.lr_scheduler = lr_scheduler + self.train_per_epoch = len(train_loader) + self.val_per_epoch = len(val_loader) if val_loader else 0 + self.batches_seen = 0 + + # Paths for saving models and logs + self.best_path = os.path.join(args['log_dir'], 'best_model.pth') + self.best_test_path = os.path.join(args['log_dir'], 'best_test_model.pth') + self.loss_figure_path = os.path.join(args['log_dir'], 'loss.png') + + # Initialize logger + if not os.path.isdir(args['log_dir']) and not args['debug']: + os.makedirs(args['log_dir'], exist_ok=True) + self.logger = get_logger(args['log_dir'], name=self.model.__class__.__name__, debug=args['debug']) + self.logger.info(f"Experiment log path in: {args['log_dir']}") + + def _run_epoch(self, epoch, dataloader, mode): + if mode == 'train': + self.model.train() + optimizer_step = True + else: + self.model.eval() + optimizer_step = False + + total_loss = 0 + epoch_time = time.time() + + with torch.set_grad_enabled(optimizer_step): + with tqdm(total=len(dataloader), desc=f'{mode.capitalize()} Epoch {epoch}') as pbar: + for batch_idx, (data, target) in enumerate(dataloader): + self.batches_seen += 1 + label = target[..., :self.args['output_dim']].clone() + output = self.model(data, target, self.batches_seen).to(self.args['device']) + + if self.args['real_value']: + output = self.scaler.inverse_transform(output) + + loss = self.loss(output, label) + if optimizer_step and self.optimizer is not None: + self.optimizer.zero_grad() + loss.backward() + + if self.args['grad_norm']: + torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm']) + self.optimizer.step() + + total_loss += loss.item() + + if mode == 'train' and (batch_idx + 1) % self.args['log_step'] == 0: + self.logger.info( + f'Train Epoch {epoch}: {batch_idx + 1}/{len(dataloader)} Loss: {loss.item():.6f}') + + # 更新 tqdm 的进度 + pbar.update(1) + pbar.set_postfix(loss=loss.item()) + + avg_loss = total_loss / len(dataloader) + self.logger.info( + f'{mode.capitalize()} Epoch {epoch}: average Loss: {avg_loss:.6f}, time: {time.time() - epoch_time:.2f} s') + return avg_loss + + def train_epoch(self, epoch): + return self._run_epoch(epoch, self.train_loader, 'train') + + def val_epoch(self, epoch): + return self._run_epoch(epoch, self.val_loader or self.test_loader, 'val') + + def test_epoch(self, epoch): + return self._run_epoch(epoch, self.test_loader, 'test') + + def train(self): + best_model, best_test_model = None, None + best_loss, best_test_loss = float('inf'), float('inf') + not_improved_count = 0 + + self.logger.info("Training process started") + for epoch in range(1, self.args['epochs'] + 1): + train_epoch_loss = self.train_epoch(epoch) + val_epoch_loss = self.val_epoch(epoch) + test_epoch_loss = self.test_epoch(epoch) + + if train_epoch_loss > 1e6: + self.logger.warning('Gradient explosion detected. Ending...') + break + + if val_epoch_loss < best_loss: + best_loss = val_epoch_loss + not_improved_count = 0 + best_model = copy.deepcopy(self.model.state_dict()) + self.logger.info('Best validation model saved!') + else: + not_improved_count += 1 + + if self.args['early_stop'] and not_improved_count == self.args['early_stop_patience']: + self.logger.info( + f"Validation performance didn't improve for {self.args['early_stop_patience']} epochs. Training stops.") + break + + if test_epoch_loss < best_test_loss: + best_test_loss = test_epoch_loss + best_test_model = copy.deepcopy(self.model.state_dict()) + + if not self.args['debug']: + torch.save(best_model, self.best_path) + torch.save(best_test_model, self.best_test_path) + self.logger.info(f"Best models saved at {self.best_path} and {self.best_test_path}") + + self._finalize_training(best_model, best_test_model) + + def _finalize_training(self, best_model, best_test_model): + self.model.load_state_dict(best_model) + self.logger.info("Testing on best validation model") + self.test(self.model, self.args, self.test_loader, self.scaler, self.logger) + + self.model.load_state_dict(best_test_model) + self.logger.info("Testing on best test model") + self.test(self.model, self.args, self.test_loader, self.scaler, self.logger) + + @staticmethod + def test(model, args, data_loader, scaler, logger, path=None): + if path: + checkpoint = torch.load(path) + model.load_state_dict(checkpoint['state_dict']) + model.to(args['device']) + + model.eval() + y_pred, y_true = [], [] + + with torch.no_grad(): + for data, target in data_loader: + label = target[..., :args['output_dim']].clone() + output = model(data, target) + y_pred.append(output) + y_true.append(label) + + if args['real_value']: + y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0)) + else: + y_pred = torch.cat(y_pred, dim=0) + y_true = torch.cat(y_true, dim=0) + + # 你在这里需要把y_pred和y_true保存下来 + # torch.save(y_pred, "./test/PEMS07/y_pred_D.pt") # [3566,12,170,1] + # torch.save(y_true, "./test/PEMS08/y_true.pt") # [3566,12,170,1] + + for t in range(y_true.shape[1]): + mae, rmse, mape = all_metrics(y_pred[:, t, ...], y_true[:, t, ...], + args['mae_thresh'], args['mape_thresh']) + logger.info(f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}") + + mae, rmse, mape = all_metrics(y_pred, y_true, args['mae_thresh'], args['mape_thresh']) + logger.info(f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}") + + @staticmethod + def _compute_sampling_threshold(global_step, k): + return k / (k + math.exp(global_step / k)) diff --git a/trainer/trainer_selector.py b/trainer/trainer_selector.py index 24eb1a9..eaad3ab 100644 --- a/trainer/trainer_selector.py +++ b/trainer/trainer_selector.py @@ -1,6 +1,7 @@ from trainer.Trainer import Trainer from trainer.cdeTrainer.cdetrainer import Trainer as cdeTrainer from trainer.DCRNN_Trainer import Trainer as DCRNN_Trainer +from trainer.PDG2SEQ_Trainer import Trainer as PDG2SEQ_Trainer def select_trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args, @@ -10,5 +11,7 @@ def select_trainer(model, loss, optimizer, train_loader, val_loader, test_loader lr_scheduler, kwargs[0], None) case 'DCRNN': return DCRNN_Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args['train'], lr_scheduler) + case 'PDG2SEQ': return PDG2SEQ_Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args['train'], + lr_scheduler) case _: return Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args['train'], lr_scheduler)