TrafficWheel/model/EXP/EXP0.py

124 lines
6.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch, torch.nn as nn, torch.nn.functional as F
from collections import OrderedDict
class DGCRM(nn.Module):
    """Stack of ``DDGCRNCell`` recurrent layers unrolled over the time axis.

    Args:
        node_num: number of graph nodes N.
        dim_in: feature width D of the input to the first layer.
        dim_out: hidden width of every layer.
        cheb_k: graph-convolution order forwarded to each cell.
        embed_dim: node-embedding width forwarded to each cell.
        num_layers: number of stacked recurrent layers.
    """

    def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim, num_layers=1):
        super().__init__()
        self.node_num, self.input_dim, self.num_layers = node_num, dim_in, num_layers
        # Only the first layer consumes the raw input; deeper layers consume the
        # previous layer's hidden sequence, whose width is dim_out.
        self.cells = nn.ModuleList([
            DDGCRNCell(node_num, dim_in if i == 0 else dim_out, dim_out, cheb_k, embed_dim)
            for i in range(num_layers)
        ])

    def forward(self, x, init_state, node_embeddings):
        """Run every layer over the full input sequence.

        Args:
            x: input of shape (B, T, N, D) with N == node_num and D == input_dim.
            init_state: per-layer initial hidden states, indexable by layer
                (e.g. the result of ``init_hidden``). It is NOT modified.
            node_embeddings: ``[dynamic, static]``. If ``dynamic`` is 4-D it is
                treated as time-varying with time on dim 1 and sliced per step
                as ``dynamic[:, t]``; otherwise it is treated as time-invariant
                and passed unchanged to every step. ``static`` is passed
                unchanged to every step.

        Returns:
            ``(outputs, final_states)``: the last layer's hidden sequence,
            stacked along dim 1, and the final hidden state of every layer,
            stacked along dim 0.

        Raises:
            ValueError: if ``x`` is not 4-D or its node / feature sizes do not
                match the module configuration.
        """
        if x.dim() != 4 or x.shape[2] != self.node_num or x.shape[3] != self.input_dim:
            raise ValueError(
                f"expected x of shape (B, T, {self.node_num}, {self.input_dim}), "
                f"got {tuple(x.shape)}"
            )
        dynamic_embed, static_embed = node_embeddings[0], node_embeddings[1]
        time_varying = dynamic_embed.dim() == 4
        final_states = []
        for i, cell in enumerate(self.cells):
            state = init_state[i].to(x.device)
            outputs = []
            for t in range(x.shape[1]):
                step_embed = dynamic_embed[:, t] if time_varying else dynamic_embed
                state = cell(x[:, t], state, [step_embed, static_embed])
                outputs.append(state)
            # Collect final states instead of writing them back into the
            # caller's (possibly CPU-resident) init_state tensor.
            final_states.append(state)
            # This layer's hidden sequence feeds the next layer.
            x = torch.stack(outputs, dim=1)
        return x, torch.stack(final_states, dim=0)

    def init_hidden(self, bs, device=None):
        """Return every layer's initial hidden state, stacked along dim 0.

        Args:
            bs: batch size.
            device: optional target device; when ``None`` the states stay on
                whatever device ``init_hidden_state`` allocates them on.
        """
        states = torch.stack([cell.init_hidden_state(bs) for cell in self.cells], dim=0)
        return states if device is None else states.to(device)
class EXP(nn.Module):
    """Single-encoder dynamic-graph recurrent forecaster.

    ``args`` must be a mapping providing: num_nodes, input_dim, rnn_units,
    output_dim, horizon, num_layers, use_day, use_week, embed_dim and
    cheb_order.
    """

    def __init__(self, args):
        super().__init__()
        self.num_node, self.input_dim, self.hidden_dim = args['num_nodes'], args['input_dim'], args['rnn_units']
        self.output_dim, self.horizon, self.num_layers = args['output_dim'], args['horizon'], args['num_layers']
        self.use_day, self.use_week = args['use_day'], args['use_week']
        self.node_embeddings1 = nn.Parameter(torch.randn(self.num_node, args['embed_dim']))
        # The second node-embedding table is no longer used (fewer parameters);
        # node_embeddings1 doubles as the static embedding in forward().
        # NOTE(review): torch.empty leaves the two time tables uninitialised;
        # presumably an external init routine overwrites all parameters — confirm.
        self.T_i_D_emb = nn.Parameter(torch.empty(288, args['embed_dim']))
        self.D_i_W_emb = nn.Parameter(torch.empty(7, args['embed_dim']))
        self.drop = nn.Dropout(0.1)
        # A single encoder, so each forward call runs the recurrence only once.
        self.encoder = DGCRM(self.num_node, self.input_dim, self.hidden_dim,
                             args['cheb_order'], args['embed_dim'], self.num_layers)
        # Main head: base forecast from the last hidden state.
        self.base_conv = nn.Conv2d(1, self.horizon * self.output_dim, (1, self.hidden_dim))
        # Residual head: corrects the forecast using the most recent raw input
        # as well, hence a kernel width of hidden_dim + 1.
        self.res_conv = nn.Conv2d(1, self.horizon * self.output_dim, (1, self.hidden_dim + 1))

    def forward(self, source):
        """Forecast ``horizon`` steps ahead.

        Args:
            source: (B, T, N, D_total) tensor. Channel 0 is the main
                observation; channels 1 and 2 are time encodings.
                NOTE(review): channel 1 is presumably time-of-day as a fraction
                in [0, 1) and channel 2 a day-of-week index in [0, 7) — confirm.

        Returns:
            Tensor of shape (B, horizon, N, output_dim).
        """
        batch, steps = source.shape[0], source.shape[1]
        node_embed = self.node_embeddings1
        if self.use_day:
            node_embed = node_embed * self.T_i_D_emb[(source[..., 1] * 288).long()]
        if self.use_week:
            node_embed = node_embed * self.D_i_W_emb[source[..., 2].long()]
        if node_embed.dim() == 2:
            # With both time embeddings disabled the dynamic embedding is still
            # (N, E); give it the (B, T, N, E) layout the encoder slices per step.
            node_embed = node_embed.expand(batch, steps, -1, -1)
        node_embeddings = [node_embed, self.node_embeddings1]
        # Only the main observation channel is fed to the encoder.
        inp = source[..., 0].unsqueeze(-1)  # (B, T, N, 1)
        init = self.encoder.init_hidden(batch)
        enc_out, _ = self.encoder(inp, init, node_embeddings)
        # Last-step hidden state as the representation: (B, 1, N, hidden_dim).
        rep = self.drop(enc_out[:, -1:, :, :])
        base = self.base_conv(rep)  # (B, horizon*output_dim, N, 1)
        # Append the last raw input frame as residual-correction information.
        res_in = torch.cat([rep, inp[:, -1:, :, :]], dim=-1)  # (B, 1, N, hidden_dim+1)
        res = self.res_conv(res_in)
        out = base + res  # (B, horizon*output_dim, N, 1)
        # Split the channel axis into (horizon, output_dim) and move output_dim
        # last; for output_dim == 1 the values equal the raw conv output.
        out = out.squeeze(-1).reshape(batch, self.horizon, self.output_dim, -1)
        return out.permute(0, 1, 3, 2)
class DDGCRNCell(nn.Module):
    """GRU-style recurrent cell whose transforms are dynamic graph convolutions.

    One DGCN produces both gates (split in half along the last axis); a second
    DGCN produces the candidate state from the input and the gated previous
    state. The blended new state is layer-normalised before being returned.
    """

    def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim):
        super().__init__()
        self.node_num = node_num
        self.hidden_dim = dim_out
        concat_width = dim_in + dim_out
        self.gate = DGCN(concat_width, 2 * dim_out, cheb_k, embed_dim, node_num)
        self.update = DGCN(concat_width, dim_out, cheb_k, embed_dim, node_num)
        self.ln = nn.LayerNorm(dim_out)

    def forward(self, x, state, node_embeddings):
        """Advance the recurrence by one step and return the normalised new state."""
        gate_logits = self.gate(torch.cat((x, state), dim=-1), node_embeddings)
        z, r = torch.sigmoid(gate_logits).split(self.hidden_dim, dim=-1)
        candidate_in = torch.cat((x, z * state), dim=-1)
        candidate = torch.tanh(self.update(candidate_in, node_embeddings))
        mixed = r * state + (1 - r) * candidate
        return self.ln(mixed)

    def init_hidden_state(self, bs):
        """Return an all-zero hidden state of shape (bs, node_num, hidden_dim)."""
        shape = (bs, self.node_num, self.hidden_dim)
        return torch.zeros(shape)
class DGCN(nn.Module):
    """Dynamic graph convolution with node-adaptive weights.

    The input is propagated through ``cheb_k`` supports: the identity, a
    dynamic normalised adjacency built from the input-filtered node
    embeddings, and (for ``cheb_k > 2``) higher Chebyshev terms of that
    adjacency. The propagated copies are mixed by per-node weights generated
    from the static node embedding.

    Raises:
        ValueError: if ``cheb_k < 1``.
    """

    def __init__(self, dim_in, dim_out, cheb_k, embed_dim, num_nodes):
        super().__init__()
        if cheb_k < 1:
            raise ValueError(f"cheb_k must be >= 1, got {cheb_k}")
        self.cheb_k, self.embed_dim = cheb_k, embed_dim
        # NOTE(review): torch.FloatTensor(...) allocates uninitialised memory;
        # presumably an external init routine overwrites these — confirm.
        self.weights_pool = nn.Parameter(torch.FloatTensor(embed_dim, cheb_k, dim_in, dim_out))
        # NOTE(review): `weights` and `bias` are never read in forward(); they are
        # kept so the parameter set / state_dict keys stay unchanged.
        self.weights = nn.Parameter(torch.FloatTensor(cheb_k, dim_in, dim_out))
        self.bias_pool = nn.Parameter(torch.FloatTensor(embed_dim, dim_out))
        self.bias = nn.Parameter(torch.FloatTensor(dim_out))
        # Small MLP mapping input features to an embed_dim-wide filter that
        # modulates the dynamic node embedding.
        self.fc = nn.Sequential(OrderedDict([
            ('fc1', nn.Linear(dim_in, 16)),
            ('sigmoid1', nn.Sigmoid()),
            ('fc2', nn.Linear(16, 2)),
            ('sigmoid2', nn.Sigmoid()),
            ('fc3', nn.Linear(2, embed_dim))
        ]))
        # Pre-register the identity matrix so it is not rebuilt on every call.
        self.register_buffer('eye', torch.eye(num_nodes))

    def forward(self, x, node_embeddings):
        """Apply the dynamic graph convolution.

        Args:
            x: (B, N, dim_in) node features.
            node_embeddings: ``[dynamic, static]``; ``dynamic`` must broadcast
                against the (B, N, embed_dim) filter, ``static`` is (N, embed_dim).

        Returns:
            (B, N, dim_out) tensor.
        """
        identity = self.eye.to(node_embeddings[0].device)
        # 0th-order term: identity propagation.
        terms = [torch.einsum("nm,bmc->bnc", identity, x)]
        if self.cheb_k > 1:
            filt = self.fc(x)
            nodevec = torch.tanh(node_embeddings[0] * filt)
            adj = F.relu(torch.matmul(nodevec, nodevec.transpose(2, 1)))
            lap = self.get_laplacian(adj, identity)
            terms.append(torch.einsum("bnm,bmc->bnc", lap, x))
            # Chebyshev recursion applied to features: T_k x = 2 L T_{k-1} x - T_{k-2} x.
            for _ in range(2, self.cheb_k):
                terms.append(2 * torch.einsum("bnm,bmc->bnc", lap, terms[-1]) - terms[-2])
        x_g = torch.stack(terms, dim=1)  # (B, cheb_k, N, dim_in)
        weights = torch.einsum('nd,dkio->nkio', node_embeddings[1], self.weights_pool)
        bias = torch.matmul(node_embeddings[1], self.bias_pool)
        return torch.einsum('bnki,nkio->bno', x_g.permute(0, 2, 1, 3), weights) + bias

    @staticmethod
    def get_laplacian(graph, I, normalize=True):
        """Return the symmetrically normalised adjacency D^-1/2 A D^-1/2.

        Args:
            graph: adjacency (..., N, N) without self loops.
            I: identity matrix (N, N).
            normalize: if False, self loops are added first (A = graph + I) and
                the degree is taken from that augmented adjacency.

        Rows with non-positive degree get a zero scaling factor instead of
        ``inf``, so isolated nodes do not produce NaN values or gradients.
        """
        if not normalize:
            graph = graph + I
        degree = torch.sum(graph, -1)
        positive = degree > 0
        # Substitute 1 where the degree is not positive so pow() never sees 0,
        # which keeps the backward pass finite as well.
        safe_degree = torch.where(positive, degree, torch.ones_like(degree))
        d_inv_sqrt = torch.where(positive, safe_degree ** (-0.5), torch.zeros_like(degree))
        D_inv = torch.diag_embed(d_inv_sqrt)
        return torch.matmul(torch.matmul(D_inv, graph), D_inv)