import torch
import torch.nn as nn
from model.PDG2SEQ.PDG2SeqCell import PDG2SeqCell
import numpy as np

class PDG2Seq_Encoder(nn.Module):
    """Multi-layer recurrent encoder built from PDG2SeqCell units.

    The first cell consumes raw inputs (``dim_in`` features); every deeper
    cell consumes the hidden-state sequence emitted by the layer below it
    (``dim_out`` features).
    """

    def __init__(
        self, node_num, dim_in, dim_out, cheb_k, embed_dim, time_dim, num_layers=1
    ):
        super(PDG2Seq_Encoder, self).__init__()
        assert num_layers >= 1, "At least one DCRNN layer in the Encoder."
        self.node_num = node_num
        self.input_dim = dim_in
        self.num_layers = num_layers
        self.PDG2Seq_cells = nn.ModuleList()
        # Layer 0: raw inputs in, hidden states out.
        self.PDG2Seq_cells.append(
            PDG2SeqCell(node_num, dim_in, dim_out, cheb_k, embed_dim, time_dim)
        )
        # Layers 1..num_layers-1: hidden states in, hidden states out.
        for _ in range(1, num_layers):
            self.PDG2Seq_cells.append(
                PDG2SeqCell(node_num, dim_out, dim_out, cheb_k, embed_dim, time_dim)
            )

    def forward(self, x, init_state, node_embeddings):
        """Unroll every layer over the full input sequence.

        Args:
            x: (B, T, N, dim_in) input sequence.
            init_state: (num_layers, B, N, hidden_dim) initial hidden states.
            node_embeddings: three-element list; the first two carry a
                per-time-step embedding (sliced with ``[:, t, :]`` below),
                the last a static node embedding shared by every step.

        Returns:
            Tuple of the top layer's hidden-state sequence
            (B, T, N, hidden_dim) and a list holding each layer's final
            hidden state, one (B, N, hidden_dim) tensor per layer.
        """
        assert x.shape[2] == self.node_num and x.shape[3] == self.input_dim
        num_steps = x.shape[1]
        layer_input = x
        final_states = []
        for layer in range(self.num_layers):
            hidden = init_state[layer]
            step_outputs = []
            for step in range(num_steps):
                step_embeddings = [
                    node_embeddings[0][:, step, :],
                    node_embeddings[1][:, step, :],
                    node_embeddings[2],
                ]
                hidden = self.PDG2Seq_cells[layer](
                    layer_input[:, step, :, :], hidden, step_embeddings
                )
                step_outputs.append(hidden)
            final_states.append(hidden)
            # This layer's hidden-state sequence feeds the next layer.
            layer_input = torch.stack(step_outputs, dim=1)
        return layer_input, final_states

    def init_hidden(self, batch_size):
        """Stack each cell's initial state into (num_layers, B, N, hidden_dim)."""
        states = [
            self.PDG2Seq_cells[layer].init_hidden_state(batch_size)
            for layer in range(self.num_layers)
        ]
        return torch.stack(states, dim=0)
class PDG2Seq_Dncoder(nn.Module):
    """Multi-layer recurrent decoder built from PDG2SeqCell units.

    Unlike the encoder, ``forward`` advances a single time step per call;
    the autoregressive loop lives in the caller (PDG2Seq.forward).

    NOTE(review): the class name keeps the original misspelling ("Dncoder")
    so existing callers keep working.
    """

    def __init__(
        self, node_num, dim_in, dim_out, cheb_k, embed_dim, time_dim, num_layers=1
    ):
        super(PDG2Seq_Dncoder, self).__init__()
        assert num_layers >= 1, "At least one DCRNN layer in the Decoder."
        self.node_num = node_num
        self.input_dim = dim_in
        self.num_layers = num_layers
        self.PDG2Seq_cells = nn.ModuleList()
        # Layer 0 consumes the decoder input (dim_in features).
        self.PDG2Seq_cells.append(
            PDG2SeqCell(node_num, dim_in, dim_out, cheb_k, embed_dim, time_dim)
        )
        # BUG FIX: deeper layers receive the previous layer's hidden state,
        # which has dim_out features (see forward: current_inputs = state).
        # The original built them with dim_in, which only works when
        # num_layers == 1; mirror the encoder and use dim_out instead.
        for _ in range(1, num_layers):
            self.PDG2Seq_cells.append(
                PDG2SeqCell(node_num, dim_out, dim_out, cheb_k, embed_dim, time_dim)
            )

    def forward(self, xt, init_state, node_embeddings):
        """Advance every layer by one time step.

        Args:
            xt: (B, N, dim_in) decoder input for the current step.
            init_state: (num_layers, B, N, hidden_dim) hidden states carried
                over from the previous step (or the encoder).
            node_embeddings: three-element list — two step-specific dynamic
                embeddings and one static node embedding.

        Returns:
            Tuple of the top layer's new hidden state (B, N, hidden_dim)
            and a list with each layer's new hidden state.
        """
        assert xt.shape[1] == self.node_num and xt.shape[2] == self.input_dim
        current_inputs = xt
        output_hidden = []
        for i in range(self.num_layers):
            state = self.PDG2Seq_cells[i](
                current_inputs,
                init_state[i],
                [node_embeddings[0], node_embeddings[1], node_embeddings[2]],
            )
            output_hidden.append(state)
            # Each layer's fresh hidden state is the next layer's input.
            current_inputs = state
        return current_inputs, output_hidden
class PDG2Seq(nn.Module):
    """Encoder-decoder model with dynamic time-conditioned node embeddings.

    The encoder reads the history; the decoder autoregressively emits
    ``horizon`` steps, with scheduled sampling during training.
    """

    def __init__(self, args):
        super(PDG2Seq, self).__init__()
        self.num_node = args["num_nodes"]
        self.input_dim = args["input_dim"]
        self.hidden_dim = args["rnn_units"]
        self.output_dim = args["output_dim"]
        self.horizon = args["horizon"]
        self.num_layers = args["num_layers"]
        self.use_D = args["use_day"]  # NOTE(review): read but never used in forward
        self.use_W = args["use_week"]
        # Decay constant for the scheduled-sampling threshold.
        self.cl_decay_steps = args["lr_decay_step"]
        # NOTE(review): parameters are created with torch.empty and are NOT
        # initialized here — presumably an external routine initializes them;
        # verify before training from scratch.
        self.node_embeddings1 = nn.Parameter(
            torch.empty(self.num_node, args["embed_dim"])
        )
        # 288 time-of-day slots (24 h at 5-minute resolution), 7 days of week.
        self.T_i_D_emb1 = nn.Parameter(torch.empty(288, args["time_dim"]))
        self.D_i_W_emb1 = nn.Parameter(torch.empty(7, args["time_dim"]))
        self.T_i_D_emb2 = nn.Parameter(torch.empty(288, args["time_dim"]))
        self.D_i_W_emb2 = nn.Parameter(torch.empty(7, args["time_dim"]))

        self.encoder = PDG2Seq_Encoder(
            args["num_nodes"],
            args["input_dim"],
            args["rnn_units"],
            args["cheb_k"],
            args["embed_dim"],
            args["time_dim"],
            args["num_layers"],
        )
        self.decoder = PDG2Seq_Dncoder(
            args["num_nodes"],
            args["input_dim"],
            args["rnn_units"],
            args["cheb_k"],
            args["embed_dim"],
            args["time_dim"],
            args["num_layers"],
        )
        self.proj = nn.Sequential(
            nn.Linear(self.hidden_dim, self.output_dim, bias=True)
        )
        # NOTE(review): end_conv is never used in forward; kept so existing
        # checkpoints (state dicts) still load.
        self.end_conv = nn.Conv2d(
            1,
            args["horizon"] * self.output_dim,
            kernel_size=(1, self.hidden_dim),
            bias=True,
        )

    def forward(self, source, traget=None, batches_seen=None):
        """Predict ``horizon`` future frames from the history.

        Args:
            source: (B, T_1, N, D) history; feature 0 is the signal, the last
                two features hold time-of-day (fraction of a day) and
                day-of-week.
            traget: (B, T_2, N, D) future frame. Despite the ``None`` default
                it is dereferenced unconditionally for its time features, so
                it must always be supplied; only feature 0 is consumed, and
                only for teacher forcing while training. The misspelled name
                is kept for caller compatibility.
            batches_seen: global training step driving the sampling decay.

        Returns:
            (B, horizon, N, output_dim) predictions.
        """
        # Time-of-day values are day fractions; scale to the 288-slot table.
        t_i_d_data1 = source[..., 0, -2]
        t_i_d_data2 = traget[..., 0, -2]
        # BUG FIX: `.type(torch.LongTensor)` silently moved the indices to
        # the CPU, breaking embedding lookups when the model runs on GPU;
        # `.long()` casts in place on the tensor's own device.
        en_tod_idx = (t_i_d_data1 * 288).long()
        de_tod_idx = (t_i_d_data2 * 288).long()
        T_i_D_emb1_en = self.T_i_D_emb1[en_tod_idx]
        T_i_D_emb2_en = self.T_i_D_emb2[en_tod_idx]

        T_i_D_emb1_de = self.T_i_D_emb1[de_tod_idx]
        T_i_D_emb2_de = self.T_i_D_emb2[de_tod_idx]
        if self.use_W:
            # Day-of-week values are already integer-valued class indices.
            d_i_w_data1 = source[..., 0, -1]
            d_i_w_data2 = traget[..., 0, -1]
            D_i_W_emb1_en = self.D_i_W_emb1[d_i_w_data1.long()]
            D_i_W_emb2_en = self.D_i_W_emb2[d_i_w_data1.long()]

            D_i_W_emb1_de = self.D_i_W_emb1[d_i_w_data2.long()]
            D_i_W_emb2_de = self.D_i_W_emb2[d_i_w_data2.long()]

            # Fuse time-of-day and day-of-week embeddings elementwise.
            node_embedding_en1 = torch.mul(T_i_D_emb1_en, D_i_W_emb1_en)
            node_embedding_en2 = torch.mul(T_i_D_emb2_en, D_i_W_emb2_en)

            node_embedding_de1 = torch.mul(T_i_D_emb1_de, D_i_W_emb1_de)
            node_embedding_de2 = torch.mul(T_i_D_emb2_de, D_i_W_emb2_de)
        else:
            node_embedding_en1 = T_i_D_emb1_en
            node_embedding_en2 = T_i_D_emb2_en

            node_embedding_de1 = T_i_D_emb1_de
            node_embedding_de2 = T_i_D_emb2_de

        en_node_embeddings = [
            node_embedding_en1,
            node_embedding_en2,
            self.node_embeddings1,
        ]

        # Only the signal channel feeds the recurrent encoder.
        source = source[..., 0].unsqueeze(-1)

        # (num_layers, B, N, hidden); one initial state per stacked layer.
        init_state = self.encoder.init_hidden(source.shape[0]).to(source.device)
        state, _ = self.encoder(
            source, init_state, en_node_embeddings
        )  # B, T, N, hidden
        # Keep only the final encoder step: (B, N, hidden).
        state = state[:, -1, :, :]

        # Seed every decoder layer with the encoder's final state.
        ht_list = [state] * self.num_layers

        # GO symbol: an all-zero frame starts the autoregressive decoding.
        go = torch.zeros(
            (source.shape[0], self.num_node, self.output_dim), device=source.device
        )
        out = []
        for t in range(self.horizon):
            state, ht_list = self.decoder(
                go,
                ht_list,
                [
                    node_embedding_de1[:, t, :],
                    node_embedding_de2[:, t, :],
                    self.node_embeddings1,
                ],
            )
            go = self.proj(state)
            out.append(go)
            # Scheduled sampling (curriculum learning): while training, with
            # a probability that decays as batches_seen grows, replace the
            # model's own prediction with the ground truth for the next step.
            if self.training:
                c = np.random.uniform(0, 1)
                if c < self._compute_sampling_threshold(batches_seen):
                    go = traget[:, t, :, 0].unsqueeze(-1)
        output = torch.stack(out, dim=1)

        return output

    def _compute_sampling_threshold(self, batches_seen):
        """Inverse-sigmoid-style decay: k / (k + exp(step / k))."""
        x = self.cl_decay_steps / (
            self.cl_decay_steps + np.exp(batches_seen / self.cl_decay_steps)
        )
        return x