import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict

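# Components: DGCRM (a stacked recurrent encoder built from DDGCRNCell), EXPB (the
# forecasting model that patches the input and wraps the encoder), DDGCRNCell (a
# GRU-style recurrent cell), and DGCN (the dynamic graph convolution used by the cell).
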
class DGCRM(nn.Module):
    def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim, num_layers=1):
        super().__init__()
        self.node_num, self.input_dim, self.num_layers = node_num, dim_in, num_layers
        self.cells = nn.ModuleList(
            [
                DDGCRNCell(
                    node_num, dim_in if i == 0 else dim_out, dim_out, cheb_k, embed_dim
                )
                for i in range(num_layers)
            ]
        )

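    # x: (B, T, N, dim_in); init_state: (num_layers, B, N, dim_out);
    # node_embeddings: [per-step embeddings (B, T, N, embed_dim), static embeddings (N, embed_dim)].
    # Returns the per-step hidden states (B, T, N, dim_out) of the last layer and the final states.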
    def forward(self, x, init_state, node_embeddings):
        assert x.shape[2] == self.node_num and x.shape[3] == self.input_dim
        for i in range(self.num_layers):
            state, inner = init_state[i].to(x.device), []
            for t in range(x.shape[1]):
                state = self.cells[i](
                    x[:, t, :, :],
                    state,
                    [node_embeddings[0][:, t, :, :], node_embeddings[1]],
                )
                inner.append(state)
            init_state[i] = state
            x = torch.stack(inner, dim=1)
        return x, init_state

    def init_hidden(self, bs):
        return torch.stack([cell.init_hidden_state(bs) for cell in self.cells], dim=0)

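# EXPB: averages the observation channel over fixed-size time patches, builds
# time-of-day / day-of-week aware node embeddings, encodes the patched series with
# DGCRM, and projects the last hidden state into a multi-step forecast.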
class EXPB(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.patch_size = args.get("patch_size", 1)
        self.num_node, self.input_dim, self.hidden_dim = (
            args["num_nodes"],
            args["input_dim"],
            args["rnn_units"],
        )
        self.output_dim, self.horizon, self.num_layers = (
            args["output_dim"],
            args["horizon"],
            args["num_layers"],
        )
        self.use_day, self.use_week = args["use_day"], args["use_week"]
        self.node_embeddings1 = nn.Parameter(
            torch.randn(self.num_node, args["embed_dim"])
        )
        self.T_i_D_emb = nn.Parameter(torch.empty(288, args["embed_dim"]))
        self.D_i_W_emb = nn.Parameter(torch.empty(7, args["embed_dim"]))
        self.drop = nn.Dropout(0.1)
        self.encoder = DGCRM(
            self.num_node,
            self.input_dim,
            self.hidden_dim,
            args["cheb_order"],
            args["embed_dim"],
            self.num_layers,
        )
        self.base_conv = nn.Conv2d(
            1, self.horizon * self.output_dim, (1, self.hidden_dim)
        )
        self.res_conv = nn.Conv2d(
            1, self.horizon * self.output_dim, (1, self.hidden_dim + 1)
        )

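    # source: (B, T, N, D_total) -> forecast of shape (B, horizon, N, output_dim).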
    def forward(self, source):
        # source: (B, T, N, D_total); channel 0 is the primary observation,
        # channels 1 and 2 are the time-of-day and day-of-week encodings.
        B, T, N, D_total = source.shape
        p = self.patch_size
        num_patches = T // p
        source = source[:, : num_patches * p, :, :].view(B, num_patches, p, N, D_total)
        # Average the observation channel within each patch and reshape to (B, num_patches, N, 1).
        inp = source[..., 0].mean(dim=2, keepdim=True).permute(0, 1, 3, 2)
        # Time encodings taken at the last step of each patch.
        time_day = source[:, :, -1, :, 1]  # (B, num_patches, N)
        time_week = source[:, :, -1, :, 2]  # (B, num_patches, N)
        patched_source = torch.cat(
            [inp, time_day.unsqueeze(-1), time_week.unsqueeze(-1)], dim=-1
        )
        node_embed = self.node_embeddings1
        if self.use_day:
            node_embed = (
                node_embed * self.T_i_D_emb[(patched_source[..., 1] * 288).long()]
            )
        if self.use_week:
            node_embed = node_embed * self.D_i_W_emb[patched_source[..., 2].long()]
        node_embeddings = [node_embed, self.node_embeddings1]
        init = self.encoder.init_hidden(B)
        enc_out, _ = self.encoder(inp, init, node_embeddings)
        rep = self.drop(enc_out[:, -1:, :, :])
        base = self.base_conv(rep)
        res_in = torch.cat([rep, inp[:, -1:, :, :]], dim=-1)
        res = self.res_conv(res_in)
        out = base + res
        out = (
            out.squeeze(-1)
            .view(B, self.horizon, self.output_dim, N)
            .permute(0, 1, 3, 2)
        )
        return out

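# DDGCRNCell: a GRU-style recurrent cell in which the gate and candidate transforms
# are dynamic graph convolutions (DGCN) rather than dense layers.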
class DDGCRNCell(nn.Module):
    def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim):
        super().__init__()
        self.node_num, self.hidden_dim = node_num, dim_out
        self.gate = DGCN(dim_in + dim_out, 2 * dim_out, cheb_k, embed_dim, node_num)
        self.update = DGCN(dim_in + dim_out, dim_out, cheb_k, embed_dim, node_num)
        self.ln = nn.LayerNorm(dim_out)

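    # x: (B, N, dim_in), state: (B, N, dim_out) -> next state (B, N, dim_out).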
    def forward(self, x, state, node_embeddings):
        inp = torch.cat((x, state), -1)
        z_r = torch.sigmoid(self.gate(inp, node_embeddings))
        z, r = torch.split(z_r, self.hidden_dim, -1)
        hc = torch.tanh(self.update(torch.cat((x, z * state), -1), node_embeddings))
        out = r * state + (1 - r) * hc
        return self.ln(out)

    def init_hidden_state(self, bs):
        return torch.zeros(bs, self.node_num, self.hidden_dim)

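# DGCN: graph convolution over two supports, the identity and a normalized adjacency
# built on the fly from the dynamic node embeddings; per-node weights and biases are
# generated from the static embedding via the weight/bias pools.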
class DGCN(nn.Module):
    def __init__(self, dim_in, dim_out, cheb_k, embed_dim, num_nodes):
        super().__init__()
        self.cheb_k, self.embed_dim = cheb_k, embed_dim
        # Per-node weights and biases are generated from these pools via the static
        # node embedding. Note that torch.FloatTensor allocates uninitialized memory,
        # so these parameters are expected to be initialized externally before training.
        self.weights_pool = nn.Parameter(
            torch.FloatTensor(embed_dim, cheb_k, dim_in, dim_out)
        )
        self.weights = nn.Parameter(torch.FloatTensor(cheb_k, dim_in, dim_out))
        self.bias_pool = nn.Parameter(torch.FloatTensor(embed_dim, dim_out))
        self.bias = nn.Parameter(torch.FloatTensor(dim_out))
        # Small MLP mapping the input features to a per-node filter on the embeddings.
        self.fc = nn.Sequential(
            OrderedDict(
                [
                    ("fc1", nn.Linear(dim_in, 16)),
                    ("sigmoid1", nn.Sigmoid()),
                    ("fc2", nn.Linear(16, 2)),
                    ("sigmoid2", nn.Sigmoid()),
                    ("fc3", nn.Linear(2, embed_dim)),
                ]
            )
        )
        self.register_buffer("eye", torch.eye(num_nodes))

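    # x: (B, N, dim_in); node_embeddings: [dynamic (B, N, embed_dim), static (N, embed_dim)].
    # The dynamic adjacency is relu(E @ E^T) with E = tanh(dynamic_emb * fc(x)), then
    # degree-normalized by get_laplacian; output shape is (B, N, dim_out).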
    def forward(self, x, node_embeddings):
        supp1 = self.eye.to(node_embeddings[0].device)
        filt = self.fc(x)
        nodevec = torch.tanh(node_embeddings[0] * filt)
        supp2 = self.get_laplacian(
            F.relu(torch.matmul(nodevec, nodevec.transpose(2, 1))), supp1
        )
        x_g = torch.stack(
            [
                torch.einsum("nm,bmc->bnc", supp1, x),
                torch.einsum("bnm,bmc->bnc", supp2, x),
            ],
            dim=1,
        )
        weights = torch.einsum("nd,dkio->nkio", node_embeddings[1], self.weights_pool)
        bias = torch.matmul(node_embeddings[1], self.bias_pool)
        return torch.einsum("bnki,nkio->bno", x_g.permute(0, 2, 1, 3), weights) + bias

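    # normalize=True returns D^-1/2 A D^-1/2; otherwise A + I is used with the same
    # degree matrix computed from A alone.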
    @staticmethod
    def get_laplacian(graph, I, normalize=True):
        D_inv = torch.diag_embed(torch.sum(graph, -1) ** (-0.5))
        return (
            torch.matmul(torch.matmul(D_inv, graph), D_inv)
            if normalize
            else torch.matmul(torch.matmul(D_inv, graph + I), D_inv)
        )
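
# Minimal smoke-test sketch (not part of the original pipeline). The hyperparameters
# below are illustrative assumptions; the explicit initialization loop stands in for
# whatever the training code normally does, since several parameter pools above are
# created uninitialized.
if __name__ == "__main__":
    args = {
        "num_nodes": 5,
        "input_dim": 1,
        "rnn_units": 16,
        "output_dim": 1,
        "horizon": 12,
        "num_layers": 1,
        "embed_dim": 8,
        "cheb_order": 2,  # DGCN stacks exactly two supports, so cheb_k must be 2 here
        "use_day": True,
        "use_week": True,
        "patch_size": 3,
    }
    model = EXPB(args)
    # Give every parameter a sane starting value (T_i_D_emb, D_i_W_emb and the DGCN
    # pools would otherwise contain uninitialized memory).
    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
        else:
            nn.init.uniform_(p, -0.1, 0.1)
    B, T, N = 2, 12, args["num_nodes"]
    obs = torch.randn(B, T, N, 1)                              # channel 0: observation
    tod = torch.randint(0, 288, (B, T, N, 1)).float() / 288.0  # channel 1: time of day in [0, 1)
    dow = torch.randint(0, 7, (B, T, N, 1)).float()            # channel 2: day-of-week index
    source = torch.cat([obs, tod, dow], dim=-1)
    model.eval()
    with torch.no_grad():
        out = model(source)
    print(out.shape)  # expected: (2, 12, 5, 1) == (B, horizon, N, output_dim)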