TrafficWheel/model/EXPB/EXP_b.py

129 lines
6.2 KiB
Python
Executable File
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch, torch.nn as nn, torch.nn.functional as F
from collections import OrderedDict
class DGCRM(nn.Module):
    """Multi-layer recurrent encoder built from ``DDGCRNCell`` instances.

    Each layer unrolls its cell over the time axis of its input sequence; the
    per-step outputs of layer ``i`` form the input sequence of layer ``i + 1``.
    """

    def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim, num_layers=1):
        super().__init__()
        self.node_num, self.input_dim, self.num_layers = node_num, dim_in, num_layers
        # Only the first layer sees the raw input width; deeper layers consume
        # the dim_out-wide hidden states of the layer below.
        self.cells = nn.ModuleList([
            DDGCRNCell(node_num, dim_in if i == 0 else dim_out, dim_out, cheb_k, embed_dim)
            for i in range(num_layers)
        ])

    def forward(self, x, init_state, node_embeddings):
        """Run every layer over the sequence ``x``.

        Args:
            x: input sequence, indexed as ``x[:, t, :, :]``; dim 2 must equal
                ``node_num`` and dim 3 must equal ``input_dim``.
            init_state: per-layer initial hidden states, indexed as
                ``init_state[i]`` (e.g. the tensor from ``init_hidden``).
                It is NOT modified.
            node_embeddings: pair ``[dynamic, static]``. ``dynamic`` is sliced
                as ``dynamic[:, t, :, :]`` for each step; ``static`` is passed
                to every cell unchanged.

        Returns:
            ``(outputs, final_states)``: ``outputs`` stacks the last layer's
            per-step states along dim 1; ``final_states`` stacks each layer's
            last hidden state along dim 0, on ``x``'s device.
        """
        assert x.shape[2] == self.node_num and x.shape[3] == self.input_dim, (
            f"expected x with shape (B, T, {self.node_num}, {self.input_dim}), "
            f"got {tuple(x.shape)}"
        )
        dyn_emb, static_emb = node_embeddings[0], node_embeddings[1]
        final_states = []
        for i, cell in enumerate(self.cells):
            state = init_state[i].to(x.device)
            inner = []
            for t in range(x.shape[1]):
                state = cell(x[:, t, :, :], state, [dyn_emb[:, t, :, :], static_emb])
                inner.append(state)
            # Collect rather than writing back into init_state: when .to() is a
            # no-op, the t == 0 state is a view of init_state that the cell saves
            # for backward (z * state, r * state), so an in-place write would make
            # backward() fail; on another device it also forced a cross-device copy.
            final_states.append(state)
            x = torch.stack(inner, dim=1)
        return x, torch.stack(final_states, dim=0)

    def init_hidden(self, bs):
        """Return zero initial states for all layers, stacked along dim 0."""
        return torch.stack([cell.init_hidden_state(bs) for cell in self.cells], dim=0)
class EXPB(nn.Module):
    """Patch-based recurrent graph forecaster.

    The input window is split into non-overlapping patches of ``patch_size``
    steps. Each patch is reduced to a single step: the mean of channel 0, plus
    the time encodings of its last step. The patched sequence is encoded by
    ``DGCRM``. The last hidden state is then projected to ``horizon`` steps by
    two convolutions whose outputs are summed:
    - a (1, hidden) convolution on the hidden state;
    - a (1, hidden + 1) convolution on the hidden state concatenated with the
      last patched observation.

    ``args`` is a dict. Required keys: ``num_nodes``, ``input_dim``,
    ``rnn_units``, ``output_dim``, ``horizon``, ``num_layers``, ``use_day``,
    ``use_week``, ``embed_dim``, ``cheb_order``. Optional keys: ``patch_size``
    (default 1) and ``steps_per_day`` (default 288).
    """

    def __init__(self, args):
        super().__init__()
        self.patch_size = args.get('patch_size', 1)
        # Number of time-of-day slots; the day channel is multiplied by this
        # before indexing T_i_D_emb.
        self.steps_per_day = args.get('steps_per_day', 288)
        # DGCRM asserts its input width equals input_dim, and the patched input
        # has a single channel, so input_dim is expected to be 1.
        self.num_node, self.input_dim, self.hidden_dim = args['num_nodes'], args['input_dim'], args['rnn_units']
        self.output_dim, self.horizon, self.num_layers = args['output_dim'], args['horizon'], args['num_layers']
        self.use_day, self.use_week = args['use_day'], args['use_week']
        self.node_embeddings1 = nn.Parameter(torch.randn(self.num_node, args['embed_dim']))
        # NOTE(review): allocated uninitialized (torch.empty); presumably the
        # training script initializes all parameters after construction -- confirm.
        self.T_i_D_emb = nn.Parameter(torch.empty(self.steps_per_day, args['embed_dim']))
        self.D_i_W_emb = nn.Parameter(torch.empty(7, args['embed_dim']))
        self.drop = nn.Dropout(0.1)
        self.encoder = DGCRM(self.num_node, self.input_dim, self.hidden_dim,
                             args['cheb_order'], args['embed_dim'], self.num_layers)
        self.base_conv = nn.Conv2d(1, self.horizon * self.output_dim, (1, self.hidden_dim))
        self.res_conv = nn.Conv2d(1, self.horizon * self.output_dim, (1, self.hidden_dim + 1))

    def forward(self, source):
        """Forecast ``horizon`` steps from an input window.

        Args:
            source: tensor of shape (B, T, N, D_total). Channel 0 is the main
                observation. Channel 1 (time of day, read only when
                ``use_day``) and channel 2 (day of week, read only when
                ``use_week``) are time encodings. When T is not a multiple of
                ``patch_size`` the OLDEST ``T % patch_size`` steps are dropped,
                so the last patch ends at the most recent observation.

        Returns:
            Tensor of shape (B, horizon, N, output_dim).

        Raises:
            ValueError: if T is shorter than one patch.
        """
        B, T, N, D_total = source.shape
        p = self.patch_size
        num_patches = T // p
        if num_patches == 0:
            raise ValueError(f"input length {T} is shorter than patch_size {p}")
        # Keep the most recent num_patches * p steps; reshape (not view) also
        # handles non-contiguous input. Shape: (B, P, p, N, D_total).
        patches = source[:, T - num_patches * p:].reshape(B, num_patches, p, N, D_total)
        # Mean of the main observation over each patch -> (B, P, N, 1).
        inp = patches[..., 0].mean(dim=2, keepdim=True).permute(0, 1, 3, 2)

        # Time-modulated node embeddings, using the encodings at the last step
        # of each patch.
        node_embed = self.node_embeddings1
        if self.use_day:
            # (B, P, N); presumably a fraction of the day in [0, 1) -- confirm.
            time_day = patches[:, :, -1, :, 1]
            node_embed = node_embed * self.T_i_D_emb[(time_day * self.steps_per_day).long()]
        if self.use_week:
            # (B, P, N); used directly as an index into the 7-row table.
            time_week = patches[:, :, -1, :, 2]
            node_embed = node_embed * self.D_i_W_emb[time_week.long()]
        if node_embed.dim() == 2:
            # Neither flag set: DGCRM slices node_embeddings[0][:, t, :, :], so
            # broadcast the static (N, E) table to the per-step (B, P, N, E) shape.
            node_embed = node_embed.expand(B, num_patches, N, -1)
        node_embeddings = [node_embed, self.node_embeddings1]

        init = self.encoder.init_hidden(B)
        enc_out, _ = self.encoder(inp, init, node_embeddings)
        rep = self.drop(enc_out[:, -1:, :, :])  # last hidden state, (B, 1, N, hidden)
        base = self.base_conv(rep)  # (B, horizon * output_dim, N, 1)
        # Residual path additionally sees the last patched observation.
        res = self.res_conv(torch.cat([rep, inp[:, -1:, :, :]], dim=-1))
        out = base + res
        return out.squeeze(-1).reshape(B, self.horizon, self.output_dim, N).permute(0, 1, 3, 2)
class DDGCRNCell(nn.Module):
    """GRU-like recurrent cell whose gate and candidate transforms are ``DGCN`` layers.

    The hidden state is created with shape (batch, node_num, dim_out). The
    blended new state is passed through LayerNorm before being returned.
    """

    def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim):
        super().__init__()
        self.node_num = node_num
        self.hidden_dim = dim_out
        # Input and hidden state are concatenated on the last axis before each DGCN.
        joint_dim = dim_in + dim_out
        self.gate = DGCN(joint_dim, 2 * dim_out, cheb_k, embed_dim, node_num)
        self.update = DGCN(joint_dim, dim_out, cheb_k, embed_dim, node_num)
        self.ln = nn.LayerNorm(dim_out)

    def forward(self, x, state, node_embeddings):
        """Advance the hidden state by one step.

        Args:
            x: current input; concatenated with ``state`` along the last axis.
            state: previous hidden state (last axis of size ``hidden_dim``).
            node_embeddings: forwarded unchanged to both DGCN layers.

        Returns:
            ``LayerNorm(r * state + (1 - r) * candidate)``.
        """
        width = self.hidden_dim
        gates = torch.sigmoid(self.gate(torch.cat((x, state), -1), node_embeddings))
        # The gate DGCN emits 2 * hidden_dim channels: first half z, second half r.
        z, r = gates[..., :width], gates[..., width:]
        candidate_in = torch.cat((x, z * state), -1)
        candidate = torch.tanh(self.update(candidate_in, node_embeddings))
        return self.ln(r * state + (1 - r) * candidate)

    def init_hidden_state(self, bs):
        """Return an all-zero hidden state of shape (bs, node_num, hidden_dim)."""
        shape = (bs, self.node_num, self.hidden_dim)
        return torch.zeros(shape)
class DGCN(nn.Module):
    """Dynamic graph convolution with node-adaptive weights.

    Two supports are used:
    - the identity;
    - a normalized adaptive graph built from ``relu(v v^T)``, where ``v`` is the
      dynamic node embedding modulated by a small hyper-network applied to ``x``.

    Per-node weights and biases are generated from the static node embedding.
    Because ``forward`` stacks exactly two supports, ``cheb_k`` must be 2 for the
    final einsum to match ``weights_pool``.

    Args:
        dim_in, dim_out: input / output feature widths.
        cheb_k: number of supports (must be 2, see above).
        embed_dim: node-embedding width.
        num_nodes: number of graph nodes (size of the identity support).
        hyper_dim, middle_dim: hidden widths of the hyper-network ``fc``
            (defaults 16 and 2).
    """

    def __init__(self, dim_in, dim_out, cheb_k, embed_dim, num_nodes, hyper_dim=16, middle_dim=2):
        super().__init__()
        self.cheb_k, self.embed_dim = cheb_k, embed_dim
        # NOTE(review): these are allocated uninitialized; presumably the
        # training script initializes all parameters after construction -- confirm.
        self.weights_pool = nn.Parameter(torch.FloatTensor(embed_dim, cheb_k, dim_in, dim_out))
        # Declared but not used by forward(); kept so state_dict keys stay compatible.
        self.weights = nn.Parameter(torch.FloatTensor(cheb_k, dim_in, dim_out))
        self.bias_pool = nn.Parameter(torch.FloatTensor(embed_dim, dim_out))
        # Declared but not used by forward(); kept so state_dict keys stay compatible.
        self.bias = nn.Parameter(torch.FloatTensor(dim_out))
        self.fc = nn.Sequential(OrderedDict([
            ('fc1', nn.Linear(dim_in, hyper_dim)),
            ('sigmoid1', nn.Sigmoid()),
            ('fc2', nn.Linear(hyper_dim, middle_dim)),
            ('sigmoid2', nn.Sigmoid()),
            ('fc3', nn.Linear(middle_dim, embed_dim)),
        ]))
        self.register_buffer('eye', torch.eye(num_nodes))

    def forward(self, x, node_embeddings):
        """Apply the dynamic graph convolution.

        Args:
            x: node features, (B, N, dim_in).
            node_embeddings: pair ``[dynamic, static]``. ``dynamic`` is
                multiplied elementwise with ``fc(x)``; ``static`` is (N, embed_dim).

        Returns:
            Tensor of shape (B, N, dim_out).
        """
        dyn_emb, static_emb = node_embeddings[0], node_embeddings[1]
        identity = self.eye.to(dyn_emb.device)
        filt = self.fc(x)
        nodevec = torch.tanh(dyn_emb * filt)
        adaptive = self.get_laplacian(F.relu(torch.matmul(nodevec, nodevec.transpose(2, 1))), identity)
        x_g = torch.stack([
            torch.einsum("nm,bmc->bnc", identity, x),
            torch.einsum("bnm,bmc->bnc", adaptive, x),
        ], dim=1)
        weights = torch.einsum('nd,dkio->nkio', static_emb, self.weights_pool)
        bias = torch.matmul(static_emb, self.bias_pool)
        return torch.einsum('bnki,nkio->bno', x_g.permute(0, 2, 1, 3), weights) + bias

    @staticmethod
    def get_laplacian(graph, I, normalize=True):
        """Symmetrically normalize an adjacency matrix.

        Args:
            graph: adjacency matrix (optionally batched), rows summed on the last axis.
            I: identity matrix added as self-loops when ``normalize`` is False.
            normalize: if True return ``D^-1/2 A D^-1/2``; otherwise return
                ``D~^-1/2 (A + I) D~^-1/2`` where ``D~`` is the degree of ``A + I``.

        Nodes with zero degree keep a scale of 1 instead of ``0 ** -0.5 = inf``.
        For a non-negative graph such a node has an all-zero row, so the result
        stays finite (previously it became NaN).
        """
        if not normalize:
            # The degree must be computed on the self-looped graph it is applied to.
            graph = graph + I
        deg = torch.sum(graph, -1)
        deg = deg.masked_fill(deg == 0, 1.0)
        D_inv = torch.diag_embed(deg ** (-0.5))
        return torch.matmul(torch.matmul(D_inv, graph), D_inv)