TrafficWheel/model/PDG2SEQ/PDG2Seq.py

162 lines
8.4 KiB
Python
Executable File
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import torch.nn as nn
from model.PDG2SEQ.PDG2SeqCell import PDG2SeqCell
import numpy as np
class PDG2Seq_Encoder(nn.Module):
    """Stack of PDG2SeqCell layers unrolled over an input sequence.

    The first cell is built for ``dim_in`` input features and every further
    cell for ``dim_out``, because each layer above the first consumes the
    hidden-state sequence produced by the layer below it.
    """

    def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim, time_dim, num_layers=1):
        super(PDG2Seq_Encoder, self).__init__()
        assert num_layers >= 1, 'At least one DCRNN layer in the Encoder.'
        self.node_num = node_num
        self.input_dim = dim_in
        self.num_layers = num_layers
        # Layer 0 reads the raw features; deeper layers read hidden states.
        layer_in_dims = [dim_in] + [dim_out for _ in range(1, num_layers)]
        self.PDG2Seq_cells = nn.ModuleList(
            [PDG2SeqCell(node_num, width, dim_out, cheb_k, embed_dim, time_dim)
             for width in layer_in_dims]
        )

    def forward(self, x, init_state, node_embeddings):
        """Run every layer over all time steps of ``x``.

        Args:
            x: Input indexed as (batch, step, node, feature). Axis 2 must
                equal ``node_num`` and axis 3 must equal ``input_dim``.
            init_state: Indexable by layer; ``init_state[i]`` is the starting
                hidden state of layer ``i`` (the original notes give the full
                shape as (num_layers, B, N, hidden_dim)).
            node_embeddings: Three items. The first two are sliced per step
                as ``[:, t, :]``; the third is handed to the cell unchanged.

        Returns:
            A pair ``(sequence, last_states)``: the top layer's hidden states
            stacked along axis 1, and a list with the final hidden state of
            each layer.
        """
        assert x.shape[2] == self.node_num and x.shape[3] == self.input_dim
        num_steps = x.shape[1]
        step_emb_a = node_embeddings[0]
        step_emb_b = node_embeddings[1]
        shared_emb = node_embeddings[2]

        layer_input = x
        last_states = []
        for layer in range(self.num_layers):
            cell = self.PDG2Seq_cells[layer]
            hidden = init_state[layer]
            per_step = []
            for step in range(num_steps):
                hidden = cell(
                    layer_input[:, step, :, :],
                    hidden,
                    [step_emb_a[:, step, :], step_emb_b[:, step, :], shared_emb],
                )
                # Keep the hidden state of every step of this layer.
                per_step.append(hidden)
            # The state after the final step is this layer's summary.
            last_states.append(hidden)
            # The stacked per-step states become the next layer's input sequence.
            layer_input = torch.stack(per_step, dim=1)
        return layer_input, last_states

    def init_hidden(self, batch_size):
        """Return the cells' initial hidden states stacked along a new axis 0."""
        per_layer = [
            self.PDG2Seq_cells[layer].init_hidden_state(batch_size)
            for layer in range(self.num_layers)
        ]
        return torch.stack(per_layer, dim=0)
class PDG2Seq_Dncoder(nn.Module):
    """Stack of PDG2SeqCell layers advanced by one time step per call.

    The first cell is built for ``dim_in`` input features. Every further cell
    is built for ``dim_out`` input features, because ``forward`` feeds each
    layer above the first with the hidden state returned by the layer below.
    """

    def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim, time_dim, num_layers=1):
        super(PDG2Seq_Dncoder, self).__init__()
        assert num_layers >= 1, 'At least one DCRNN layer in the Decoder.'
        self.node_num = node_num
        self.input_dim = dim_in
        self.num_layers = num_layers
        # Only layer 0 sees the raw decoder input; deeper layers receive the
        # previous layer's state, so their input width is dim_out (this mirrors
        # PDG2Seq_Encoder). Nothing changes when num_layers == 1.
        layer_in_dims = [dim_in] + [dim_out] * (num_layers - 1)
        self.PDG2Seq_cells = nn.ModuleList(
            [PDG2SeqCell(node_num, width, dim_out, cheb_k, embed_dim, time_dim)
             for width in layer_in_dims]
        )

    def forward(self, xt, init_state, node_embeddings):
        """Advance every layer by a single step.

        Args:
            xt: Input for this step, indexed as (batch, node, feature). Axis 1
                must equal ``node_num`` and axis 2 must equal ``input_dim``.
            init_state: Indexable by layer; ``init_state[i]`` is the hidden
                state layer ``i`` starts from.
            node_embeddings: Three items, all passed to each cell unchanged.

        Returns:
            A pair ``(top_state, states)``: the state returned by the last
            layer, and a list with the new state of every layer.
        """
        assert xt.shape[1] == self.node_num and xt.shape[2] == self.input_dim
        embeddings = [node_embeddings[0], node_embeddings[1], node_embeddings[2]]
        current_inputs = xt
        output_hidden = []
        for i in range(self.num_layers):
            state = self.PDG2Seq_cells[i](current_inputs, init_state[i], embeddings)
            output_hidden.append(state)
            # The fresh state of this layer is the input of the next one.
            current_inputs = state
        return current_inputs, output_hidden
class PDG2Seq(nn.Module):
    """Encoder-decoder forecaster built from PDG2Seq recurrent cells.

    The encoder consumes the first feature channel of ``source``. The decoder
    then produces ``horizon`` steps one at a time, feeding each projected
    output back in as the next input. During training the fed-back value is
    randomly replaced by the ground truth (scheduled sampling).

    ``args`` is a mapping that must provide the keys read in ``__init__``.
    """

    def __init__(self, args):
        super(PDG2Seq, self).__init__()
        self.num_node = args['num_nodes']
        self.input_dim = args['input_dim']
        self.hidden_dim = args['rnn_units']
        self.output_dim = args['output_dim']
        self.horizon = args['horizon']
        self.num_layers = args['num_layers']
        self.use_D = args['use_day']  # stored but not read anywhere in this class
        self.use_W = args['use_week']
        # Decay constant of the scheduled-sampling threshold (see
        # _compute_sampling_threshold); it is taken from the 'lr_decay_step' key.
        self.cl_decay_steps = args['lr_decay_step']
        # NOTE(review): every embedding below is created with torch.empty, so it
        # holds arbitrary values until something initialises it — presumably the
        # training script does; confirm.
        self.node_embeddings1 = nn.Parameter(torch.empty(self.num_node, args['embed_dim']))
        # Two independent sets of time embeddings: 288 time-of-day slots and
        # 7 day-of-week slots each.
        self.T_i_D_emb1 = nn.Parameter(torch.empty(288, args['time_dim']))
        self.D_i_W_emb1 = nn.Parameter(torch.empty(7, args['time_dim']))
        self.T_i_D_emb2 = nn.Parameter(torch.empty(288, args['time_dim']))
        self.D_i_W_emb2 = nn.Parameter(torch.empty(7, args['time_dim']))
        self.encoder = PDG2Seq_Encoder(args['num_nodes'], args['input_dim'], args['rnn_units'], args['cheb_k'],
                                       args['embed_dim'], args['time_dim'], args['num_layers'])
        # NOTE(review): the decoder is sized with 'input_dim' but is fed tensors
        # whose last axis is output_dim (or one target channel) in forward();
        # this only lines up when both are 1 — confirm.
        self.decoder = PDG2Seq_Dncoder(args['num_nodes'], args['input_dim'], args['rnn_units'], args['cheb_k'],
                                       args['embed_dim'], args['time_dim'], args['num_layers'])
        self.proj = nn.Sequential(nn.Linear(self.hidden_dim, self.output_dim, bias=True))
        # Not used by forward(); kept so existing checkpoints still load.
        self.end_conv = nn.Conv2d(1, args['horizon'] * self.output_dim, kernel_size=(1, self.hidden_dim), bias=True)

    def _time_embeddings(self, tod_index, dow_index):
        """Look up both embedding sets for the given time indices.

        ``dow_index`` may be ``None``; then only the time-of-day tables are
        used. Otherwise each time-of-day embedding is multiplied element-wise
        by the matching day-of-week embedding.
        """
        emb1 = self.T_i_D_emb1[tod_index]
        emb2 = self.T_i_D_emb2[tod_index]
        if dow_index is not None:
            emb1 = torch.mul(emb1, self.D_i_W_emb1[dow_index])
            emb2 = torch.mul(emb2, self.D_i_W_emb2[dow_index])
        return emb1, emb2

    def forward(self, source, traget=None, batches_seen=None):
        """Forecast ``horizon`` steps.

        Args:
            source: History indexed as (batch, step, node, feature). Feature 0
                is the signal fed to the encoder. Feature -2 of node 0 is read
                as a time-of-day fraction and feature -1 of node 0 as a
                day-of-week index.
            traget: Future window with the same feature layout. It is always
                required, because its time features drive the decoder; its
                feature 0 is additionally used for teacher forcing in training.
            batches_seen: Number of batches processed so far; required in
                training mode to compute the teacher-forcing probability.

        Returns:
            Tensor of the decoder outputs stacked along axis 1.

        Raises:
            ValueError: if ``traget`` is missing, or if ``batches_seen`` is
                missing while the module is in training mode.
        """
        if traget is None:
            raise ValueError(
                "PDG2Seq.forward needs `traget`: its time features are used by the decoder.")
        if self.training and batches_seen is None:
            raise ValueError(
                "PDG2Seq.forward needs `batches_seen` in training mode for scheduled sampling.")

        # Fractions of a day are scaled to a slot index of the time-of-day table.
        # .long() truncates exactly like the former .type(torch.LongTensor) but
        # keeps the index on the input's device.
        # NOTE(review): truncation maps a fraction stored slightly below k/288
        # to slot k-1 — confirm the data pipeline produces exact fractions.
        slots_per_day = self.T_i_D_emb1.shape[0]
        tod_src = (source[..., 0, -2] * slots_per_day).long()
        tod_tgt = (traget[..., 0, -2] * slots_per_day).long()
        if self.use_W:
            dow_src = source[..., 0, -1].long()
            dow_tgt = traget[..., 0, -1].long()
        else:
            dow_src = dow_tgt = None
        node_embedding_en1, node_embedding_en2 = self._time_embeddings(tod_src, dow_src)
        node_embedding_de1, node_embedding_de2 = self._time_embeddings(tod_tgt, dow_tgt)

        en_node_embeddings = [node_embedding_en1, node_embedding_en2, self.node_embeddings1]
        # Only the first feature channel is encoded.
        source = source[..., 0].unsqueeze(-1)
        init_state = self.encoder.init_hidden(source.shape[0]).to(source.device)
        state, _ = self.encoder(source, init_state, en_node_embeddings)
        # Hidden state of the top encoder layer at the last input step.
        state = state[:, -1:, :, :].squeeze(1)

        # Every decoder layer starts from that same encoder state.
        ht_list = [state] * self.num_layers
        go = torch.zeros((source.shape[0], self.num_node, self.output_dim), device=source.device)
        # The threshold depends only on batches_seen, so compute it once.
        threshold = self._compute_sampling_threshold(batches_seen) if self.training else None
        out = []
        for t in range(self.horizon):
            state, ht_list = self.decoder(
                go, ht_list,
                [node_embedding_de1[:, t, :], node_embedding_de2[:, t, :], self.node_embeddings1])
            go = self.proj(state)
            out.append(go)
            # Scheduled sampling: with probability `threshold`, feed the ground
            # truth instead of the model's own prediction into the next step.
            if self.training and np.random.uniform(0, 1) < threshold:
                go = traget[:, t, :, 0].unsqueeze(-1)
        output = torch.stack(out, dim=1)
        return output

    def _compute_sampling_threshold(self, batches_seen):
        """Return the teacher-forcing probability ``k / (k + exp(batches_seen / k))``.

        ``k`` is ``cl_decay_steps``; the value falls towards 0 as
        ``batches_seen`` grows.
        """
        x = self.cl_decay_steps / (
            self.cl_decay_steps + np.exp(batches_seen / self.cl_decay_steps))
        return x