TrafficWheel/model/DDGCRN/DDGCRN.py

from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F


class DGCRM(nn.Module):
    """Recurrent encoder: stacked DDGCRN cells unrolled over the input sequence."""

    def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim, num_layers=1):
        super().__init__()
        self.node_num, self.input_dim, self.num_layers = node_num, dim_in, num_layers
        self.cells = nn.ModuleList(
            [DDGCRNCell(node_num, dim_in if i == 0 else dim_out, dim_out, cheb_k, embed_dim)
             for i in range(num_layers)]
        )

    def forward(self, x, init_state, node_embeddings):
        # x: [B, T, N, C]; init_state: [num_layers, B, N, hidden_dim]
        assert x.shape[2] == self.node_num and x.shape[3] == self.input_dim
        for i in range(self.num_layers):
            state, inner = init_state[i].to(x.device), []
            for t in range(x.shape[1]):
                # Each step gets the time-specific node embedding for step t plus the static one.
                state = self.cells[i](x[:, t, :, :], state,
                                      [node_embeddings[0][:, t, :, :], node_embeddings[1]])
                inner.append(state)
            init_state[i] = state  # keep the last hidden state of this layer
            x = torch.stack(inner, dim=1)  # outputs of layer i feed layer i + 1
        return x, init_state

    def init_hidden(self, bs):
        return torch.stack([cell.init_hidden_state(bs) for cell in self.cells], dim=0)


class DDGCRN(nn.Module):
    """DDGCRN: two DGCRM encoders, the second one run on the residual signal (decomposition)."""

    def __init__(self, args):
        super().__init__()
        self.num_node, self.input_dim, self.hidden_dim = args['num_nodes'], args['input_dim'], args['rnn_units']
        self.output_dim, self.horizon, self.num_layers = args['output_dim'], args['horizon'], args['num_layers']
        self.use_day, self.use_week = args['use_day'], args['use_week']
        # Static node embeddings; node_embeddings2 is declared but not used in forward().
        self.node_embeddings1 = nn.Parameter(torch.randn(self.num_node, args['embed_dim']), requires_grad=True)
        self.node_embeddings2 = nn.Parameter(torch.randn(self.num_node, args['embed_dim']), requires_grad=True)
        # Time-of-day (288 slots per day) and day-of-week embeddings, allocated uninitialized.
        self.T_i_D_emb = nn.Parameter(torch.empty(288, args['embed_dim']))
        self.D_i_W_emb = nn.Parameter(torch.empty(7, args['embed_dim']))
        self.drop1, self.drop2 = nn.Dropout(0.1), nn.Dropout(0.1)
        self.encoder1 = DGCRM(self.num_node, self.input_dim, self.hidden_dim, args['cheb_order'],
                              args['embed_dim'], self.num_layers)
        self.encoder2 = DGCRM(self.num_node, self.input_dim, self.hidden_dim, args['cheb_order'],
                              args['embed_dim'], self.num_layers)
        # Conv2d with kernel (1, hidden_dim): [B, 1, N, hidden_dim] -> [B, horizon * output_dim, N, 1]
        self.end_conv1 = nn.Conv2d(1, self.horizon * self.output_dim, (1, self.hidden_dim))
        self.end_conv2 = nn.Conv2d(1, self.horizon * self.output_dim, (1, self.hidden_dim))
        self.end_conv3 = nn.Conv2d(1, self.horizon * self.output_dim, (1, self.hidden_dim))

    def forward(self, source):
        # source: [B, T, N, C], channel 0 = signal, 1 = time-of-day fraction, 2 = day-of-week index
        node_embed = self.node_embeddings1
        if self.use_day:
            node_embed = node_embed * self.T_i_D_emb[(source[..., 1] * 288).long()]
        if self.use_week:
            node_embed = node_embed * self.D_i_W_emb[source[..., 2].long()]
        # [time-varying embedding, static embedding]
        node_embeddings = [node_embed, self.node_embeddings1]
        source = source[..., 0].unsqueeze(-1)
        # First branch: encode the raw signal and predict from the last hidden state.
        init1 = self.encoder1.init_hidden(source.shape[0])
        out, _ = self.encoder1(source, init1, node_embeddings)
        out = self.drop1(out[:, -1:, :, :])
        out1 = self.end_conv1(out)
        # Second branch: re-encode the residual after subtracting end_conv2's output from the
        # last `horizon` input steps (the shapes line up when output_dim == 1).
        src1 = self.end_conv2(out)
        src2 = source[:, -self.horizon:, ...] - src1
        init2 = self.encoder2.init_hidden(source.shape[0])
        out2, _ = self.encoder2(src2, init2, node_embeddings)
        out2 = self.drop2(out2[:, -1:, :, :])
        return out1 + self.end_conv3(out2)


class DDGCRNCell(nn.Module):
    """GRU-style cell whose gate and update transforms are dynamic graph convolutions."""

    def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim):
        super().__init__()
        self.node_num, self.hidden_dim = node_num, dim_out
        self.gate = DGCN(dim_in + dim_out, 2 * dim_out, cheb_k, embed_dim, node_num)
        self.update = DGCN(dim_in + dim_out, dim_out, cheb_k, embed_dim, node_num)

    def forward(self, x, state, node_embeddings):
        # x: [B, N, dim_in], state: [B, N, hidden_dim]
        inp = torch.cat((x, state), -1)
        z_r = torch.sigmoid(self.gate(inp, node_embeddings))
        z, r = torch.split(z_r, self.hidden_dim, -1)
        hc = torch.tanh(self.update(torch.cat((x, z * state), -1), node_embeddings))
        return r * state + (1 - r) * hc

    def init_hidden_state(self, bs):
        return torch.zeros(bs, self.node_num, self.hidden_dim)


class DGCN(nn.Module):
    """Dynamic graph convolution: combines a static identity support with a per-sample
    adjacency built from input-conditioned node embeddings."""

    def __init__(self, dim_in, dim_out, cheb_k, embed_dim, num_nodes):
        super().__init__()
        self.cheb_k, self.embed_dim = cheb_k, embed_dim
        # Node-specific weights/biases are generated from the static node embedding via the
        # pools; `weights` and `bias` are declared here but not used in forward().
        self.weights_pool = nn.Parameter(torch.FloatTensor(embed_dim, cheb_k, dim_in, dim_out))
        self.weights = nn.Parameter(torch.FloatTensor(cheb_k, dim_in, dim_out))
        self.bias_pool = nn.Parameter(torch.FloatTensor(embed_dim, dim_out))
        self.bias = nn.Parameter(torch.FloatTensor(dim_out))
        # Small MLP that maps the input features to a per-node filter over the embedding dims.
        self.fc = nn.Sequential(OrderedDict([
            ('fc1', nn.Linear(dim_in, 16)),
            ('sigmoid1', nn.Sigmoid()),
            ('fc2', nn.Linear(16, 2)),
            ('sigmoid2', nn.Sigmoid()),
            ('fc3', nn.Linear(2, embed_dim)),
        ]))
        # Pre-register the constant identity matrix so it moves with the module's device.
        self.register_buffer('eye', torch.eye(num_nodes))

    def forward(self, x, node_embeddings):
        # x: [B, N, dim_in]; node_embeddings = [time-varying [B, N, d], static [N, d]]
        supp1 = self.eye.to(node_embeddings[0].device)
        filt = self.fc(x)
        nodevec = torch.tanh(node_embeddings[0] * filt)
        # Per-sample symmetrically normalized adjacency built from the filtered embeddings.
        supp2 = self.get_laplacian(F.relu(torch.matmul(nodevec, nodevec.transpose(2, 1))), supp1)
        # Two supports (identity + dynamic graph); this stacking requires cheb_k == 2.
        x_g = torch.stack([torch.einsum("nm,bmc->bnc", supp1, x),
                           torch.einsum("bnm,bmc->bnc", supp2, x)], dim=1)
        weights = torch.einsum('nd,dkio->nkio', node_embeddings[1], self.weights_pool)  # [N, k, in, out]
        bias = torch.matmul(node_embeddings[1], self.bias_pool)                         # [N, out]
        return torch.einsum('bnki,nkio->bno', x_g.permute(0, 2, 1, 3), weights) + bias
    @staticmethod
    def get_laplacian(graph, I, normalize=True):
        # Symmetric normalization D^(-1/2) G D^(-1/2). In the normalize=False branch,
        # self-loops I are added to G, but D is still computed from G alone; only the
        # normalize=True branch is used in forward() above.
        D_inv = torch.diag_embed(torch.sum(graph, -1) ** (-0.5))
        if normalize:
            return torch.matmul(torch.matmul(D_inv, graph), D_inv)
        return torch.matmul(torch.matmul(D_inv, graph + I), D_inv)
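

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original file. The config values and tensor
    # shapes below are assumptions inferred from __init__/forward, not settings shipped with
    # TrafficWheel. The embedding/weight pools above are allocated uninitialized
    # (torch.empty / torch.FloatTensor), so they are explicitly initialized here before the
    # forward pass.
    args = {
        'num_nodes': 10, 'input_dim': 1, 'rnn_units': 64, 'output_dim': 1,
        'horizon': 12, 'num_layers': 1, 'use_day': True, 'use_week': True,
        'embed_dim': 10, 'cheb_order': 2,  # cheb_order must be 2: DGCN stacks exactly two supports
    }
    model = DDGCRN(args)
    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
        else:
            nn.init.uniform_(p)
    # source: [batch, T_in, num_nodes, 3] with channels
    # (signal, time-of-day fraction in [0, 1), day-of-week index in 0..6)
    source = torch.rand(4, 12, args['num_nodes'], 3)
    source[..., 2] = torch.randint(0, 7, source.shape[:-1]).float()
    pred = model(source)
    print(pred.shape)  # torch.Size([4, 12, 10, 1]) == [batch, horizon * output_dim, num_nodes, 1]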