import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
from data.get_adj import get_adj
import numbers


# --- Basic operators ---
class NConv(nn.Module):
    """Graph convolution with a static adjacency matrix."""

    def forward(self, x, adj):
        # x: (batch, channels, nodes, time); adj: (nodes, nodes)
        return torch.einsum("ncwl,vw->ncvl", (x, adj)).contiguous()


class DyNconv(nn.Module):
    """Graph convolution with a dynamic, per-sample adjacency tensor."""

    def forward(self, x, adj):
        # x: (batch, channels, nodes, time); adj: (batch, nodes, nodes, time)
        return torch.einsum("ncvl,nvwl->ncwl", (x, adj)).contiguous()


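# Shape sketch for the two einsum contractions above (hypothetical sizes, not
# taken from any config in this repo): with x of shape (4, 16, 207, 12) and a
# static adj of shape (207, 207), NConv returns (4, 16, 207, 12); with a
# dynamic adj of shape (4, 207, 207, 12), DyNconv also returns (4, 16, 207, 12).
# Both contract over the node dimension, mixing node features along graph edges
# while leaving batch, channel and time untouched.

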
class Linear(nn.Module):
    """Per-node linear projection implemented as a 1x1 convolution."""

    def __init__(self, c_in, c_out, bias=True):
        super().__init__()
        self.mlp = nn.Conv2d(c_in, c_out, kernel_size=1, bias=bias)

    def forward(self, x):
        return self.mlp(x)


class Prop(nn.Module):
    """Propagation layer: K diffusion steps followed by a 1x1 MLP on the final state."""

    def __init__(self, c_in, c_out, gdep, dropout, alpha):
        super().__init__()
        self.nconv = NConv()
        self.mlp = Linear(c_in, c_out)
        # Note: dropout is stored but not applied inside this layer.
        self.gdep, self.dropout, self.alpha = gdep, dropout, alpha

    def forward(self, x, adj):
        adj = adj + torch.eye(adj.size(0), device=x.device)
        d = adj.sum(1)
        a = adj / d.view(-1, 1)
        h = x
        for _ in range(self.gdep):
            h = self.alpha * x + (1 - self.alpha) * self.nconv(h, a)
        return self.mlp(h)


class MixProp(nn.Module):
    """Mix-hop propagation: concatenates all K diffusion steps and mixes them with a 1x1 MLP."""

    def __init__(self, c_in, c_out, gdep, dropout, alpha):
        super().__init__()
        self.nconv = NConv()
        self.mlp = Linear((gdep + 1) * c_in, c_out)
        self.gdep, self.dropout, self.alpha = gdep, dropout, alpha

    def forward(self, x, adj):
        adj = adj + torch.eye(adj.size(0), device=x.device)
        d = adj.sum(1)
        a = adj / d.view(-1, 1)
        out = [x]
        h = x
        for _ in range(self.gdep):
            h = self.alpha * x + (1 - self.alpha) * self.nconv(h, a)
            out.append(h)
        return self.mlp(torch.cat(out, dim=1))


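# The loop above implements the mix-hop recurrence
#     H^(0) = X,    H^(k) = alpha * X + (1 - alpha) * A_hat @ H^(k-1),
# where A_hat is the row-normalised adjacency with self-loops, and the output
# is MLP([H^(0) || H^(1) || ... || H^(gdep)]) concatenated along the channel
# dimension -- hence the (gdep + 1) * c_in input width of self.mlp.

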
class DyMixprop(nn.Module):
    """Mix-hop propagation over a dynamic adjacency inferred from the input itself."""

    def __init__(self, c_in, c_out, gdep, dropout, alpha):
        super().__init__()
        self.nconv = DyNconv()
        self.mlp1 = Linear((gdep + 1) * c_in, c_out)
        self.mlp2 = Linear((gdep + 1) * c_in, c_out)
        self.gdep, self.dropout, self.alpha = gdep, dropout, alpha
        self.lin1, self.lin2 = Linear(c_in, c_in), Linear(c_in, c_in)

    def forward(self, x):
        x1 = torch.tanh(self.lin1(x))
        x2 = torch.tanh(self.lin2(x))
        adj = self.nconv(x1.transpose(2, 1), x2)
        adj0 = torch.softmax(adj, dim=2)
        adj1 = torch.softmax(adj.transpose(2, 1), dim=2)
        # Two branches: one per propagation direction of the dynamic adjacency
        out1, out2 = [x], [x]
        h = x
        for _ in range(self.gdep):
            h = self.alpha * x + (1 - self.alpha) * self.nconv(h, adj0)
            out1.append(h)
        h = x
        for _ in range(self.gdep):
            h = self.alpha * x + (1 - self.alpha) * self.nconv(h, adj1)
            out2.append(h)
        return self.mlp1(torch.cat(out1, dim=1)) + self.mlp2(torch.cat(out2, dim=1))


class DilatedInception(nn.Module):
    """Inception-style temporal convolution with kernel widths 2, 3, 6 and 7."""

    def __init__(self, cin, cout, dilation_factor=2):
        super().__init__()
        self.kernels = [2, 3, 6, 7]
        cout_each = int(cout / len(self.kernels))
        self.convs = nn.ModuleList(
            [
                nn.Conv2d(
                    cin, cout_each, kernel_size=(1, k), dilation=(1, dilation_factor)
                )
                for k in self.kernels
            ]
        )

    def forward(self, x):
        # Truncate every branch to the length of the largest-kernel branch
        # before concatenating along the channel dimension.
        outs = [conv(x) for conv in self.convs]
        target_len = outs[-1].size(3)
        outs = [out[..., -target_len:] for out in outs]
        return torch.cat(outs, dim=1)


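# Output-length arithmetic for the layer above: a (1, k) convolution with
# temporal dilation d shortens the time axis by d * (k - 1), so the branches
# yield L - d, L - 2d, L - 5d and L - 6d steps and everything is cropped to the
# kernel-7 branch, i.e. L_out = L_in - 6 * d. For example (hypothetical sizes),
# with dilation_factor=1 and a 19-step input the output has 13 time steps.

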
class GraphConstructor(nn.Module):
    """Learns a sparse, directed adjacency matrix from node embeddings (or static node features)."""

    def __init__(self, nnodes, k, dim, device, alpha=3, static_feat=None):
        super().__init__()
        self.nnodes, self.k, self.dim, self.alpha, self.device = (
            nnodes,
            k,
            dim,
            alpha,
            device,
        )
        self.static_feat = static_feat
        if static_feat is not None:
            xd = static_feat.shape[1]
            self.lin1, self.lin2 = nn.Linear(xd, dim), nn.Linear(xd, dim)
        else:
            self.emb1 = nn.Embedding(nnodes, dim)
            self.emb2 = nn.Embedding(nnodes, dim)
            self.lin1, self.lin2 = nn.Linear(dim, dim), nn.Linear(dim, dim)

    def forward(self, idx):
        if self.static_feat is None:
            vec1, vec2 = self.emb1(idx), self.emb2(idx)
        else:
            vec1 = vec2 = self.static_feat[idx, :]
        vec1 = torch.tanh(self.alpha * self.lin1(vec1))
        vec2 = torch.tanh(self.alpha * self.lin2(vec2))
        a = torch.mm(vec1, vec2.transpose(1, 0)) - torch.mm(vec2, vec1.transpose(1, 0))
        adj = F.relu(torch.tanh(self.alpha * a))
        # Keep only the k largest weights in each row.
        mask = torch.zeros(idx.size(0), idx.size(0), device=self.device)
        s1, t1 = adj.topk(self.k, 1)
        mask.scatter_(1, t1, s1.new_ones(s1.size()))
        return adj * mask


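# Note on the construction above: M1 @ M2.T - M2 @ M1.T is antisymmetric, so
# after ReLU(tanh(...)) at most one of adj[i, j] and adj[j, i] is non-zero,
# which makes the learned graph directed. topk + scatter_ then keeps only the
# k strongest outgoing edges per node; all other entries are masked to zero.

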
class LayerNorm(nn.Module):
    """LayerNorm whose affine parameters are indexed by node id, so each node keeps its own scale and bias."""

    __constants__ = ["normalized_shape", "eps", "elementwise_affine"]

    def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
        super().__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        self.normalized_shape, self.eps, self.elementwise_affine = (
            tuple(normalized_shape),
            eps,
            elementwise_affine,
        )
        if elementwise_affine:
            self.weight = nn.Parameter(torch.Tensor(*normalized_shape))
            self.bias = nn.Parameter(torch.Tensor(*normalized_shape))
            init.ones_(self.weight)
            init.zeros_(self.bias)
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)

    def forward(self, x, idx):
        if self.elementwise_affine:
            return F.layer_norm(
                x,
                tuple(x.shape[1:]),
                self.weight[:, idx, :],
                self.bias[:, idx, :],
                self.eps,
            )
        else:
            return F.layer_norm(x, tuple(x.shape[1:]), self.weight, self.bias, self.eps)

    def extra_repr(self):
        return f"{self.normalized_shape}, eps={self.eps}, elementwise_affine={self.elementwise_affine}"


# --- Merged model class, supporting both the teacher and stmlp branches ---
class STMLP(nn.Module):
    """Spatio-temporal forecasting model. `gcn_true` switches the spatial module between
    graph convolutions and per-layer MLPs (`stu_mlp`); `model_type` ('teacher' or 'stmlp')
    selects the auxiliary output heads."""

    def __init__(self, args):
        super().__init__()
        # Hyperparameters are read from the args dict
        self.adj_mx = get_adj(args)
        self.num_nodes = args["num_nodes"]
        self.feature_dim = args["input_dim"]

        self.input_window = args["input_window"]
        self.output_window = args["output_window"]
        self.output_dim = args["output_dim"]
        self.device = args["device"]

        self.gcn_true = args["gcn_true"]
        self.buildA_true = args["buildA_true"]
        self.gcn_depth = args["gcn_depth"]
        self.dropout = args["dropout"]
        self.subgraph_size = args["subgraph_size"]
        self.node_dim = args["node_dim"]
        self.dilation_exponential = args["dilation_exponential"]

        self.conv_channels = args["conv_channels"]
        self.residual_channels = args["residual_channels"]
        self.skip_channels = args["skip_channels"]
        self.end_channels = args["end_channels"]

        self.layers = args["layers"]
        self.propalpha = args["propalpha"]
        self.tanhalpha = args["tanhalpha"]
        self.layer_norm_affline = args["layer_norm_affline"]

        self.model_type = args["model_type"]  # 'teacher' or 'stmlp'
        self.idx = torch.arange(self.num_nodes).to(self.device)
        self.predefined_A = (
            None
            if self.adj_mx is None
            else (torch.tensor(self.adj_mx) - torch.eye(self.num_nodes)).to(self.device)
        )
        self.static_feat = None

        # Transformer encoder (kept from the original structure; not used in forward)
        self.encoder_layer = nn.TransformerEncoderLayer(
            d_model=12, nhead=4, batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(
            self.encoder_layer, num_layers=3
        )

        # Build the layers
        self.start_conv = nn.Conv2d(
            self.feature_dim, self.residual_channels, kernel_size=1
        )
        self.gc = GraphConstructor(
            self.num_nodes,
            self.subgraph_size,
            self.node_dim,
            self.device,
            alpha=self.tanhalpha,
            static_feat=self.static_feat,
        )
        # Compute the receptive field of the stacked dilated convolutions
        kernel_size = 7
        if self.dilation_exponential > 1:
            self.receptive_field = int(
                self.output_dim
                + (kernel_size - 1)
                * (self.dilation_exponential**self.layers - 1)
                / (self.dilation_exponential - 1)
            )
        else:
            self.receptive_field = self.layers * (kernel_size - 1) + self.output_dim

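        # Worked example under the configuration assumed by stu_mlp below
        # (output_dim=1, layers=3, dilation_exponential=1, input_window=12):
        # receptive_field = 3 * 6 + 1 = 19, the input is left-padded from 12 to
        # 19 steps in forward(), and each DilatedInception layer then shortens
        # the time axis by 6, giving per-layer lengths 13, 7 and 1 -- the sizes
        # hard-coded in stu_mlp.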
        self.filter_convs = nn.ModuleList()
        self.gate_convs = nn.ModuleList()
        self.residual_convs = nn.ModuleList()
        self.skip_convs = nn.ModuleList()
        self.norm = nn.ModuleList()
        # Student MLPs over the time axis; the widths are the per-layer temporal
        # lengths (see the receptive-field note above).
        self.stu_mlp = nn.ModuleList(
            [
                nn.Sequential(nn.Linear(c, c), nn.Linear(c, c), nn.Linear(c, c))
                for c in [13, 7, 1]
            ]
        )
        if self.gcn_true:
            self.gconv1 = nn.ModuleList()
            self.gconv2 = nn.ModuleList()

        new_dilation = 1
        for i in range(1):
            rf_size_i = (
                int(
                    1
                    + i
                    * (kernel_size - 1)
                    * (self.dilation_exponential**self.layers - 1)
                    / (self.dilation_exponential - 1)
                )
                if self.dilation_exponential > 1
                else i * self.layers * (kernel_size - 1) + 1
            )
            for j in range(1, self.layers + 1):
                rf_size_j = (
                    int(
                        rf_size_i
                        + (kernel_size - 1)
                        * (self.dilation_exponential**j - 1)
                        / (self.dilation_exponential - 1)
                    )
                    if self.dilation_exponential > 1
                    else rf_size_i + j * (kernel_size - 1)
                )
                self.filter_convs.append(
                    DilatedInception(
                        self.residual_channels,
                        self.conv_channels,
                        dilation_factor=new_dilation,
                    )
                )
                self.gate_convs.append(
                    DilatedInception(
                        self.residual_channels,
                        self.conv_channels,
                        dilation_factor=new_dilation,
                    )
                )
                self.residual_convs.append(
                    nn.Conv2d(self.conv_channels, self.residual_channels, kernel_size=1)
                )
                k_size = (
                    (1, self.input_window - rf_size_j + 1)
                    if self.input_window > self.receptive_field
                    else (1, self.receptive_field - rf_size_j + 1)
                )
                self.skip_convs.append(
                    nn.Conv2d(
                        self.conv_channels, self.skip_channels, kernel_size=k_size
                    )
                )
                if self.gcn_true:
                    self.gconv1.append(
                        MixProp(
                            self.conv_channels,
                            self.residual_channels,
                            self.gcn_depth,
                            self.dropout,
                            self.propalpha,
                        )
                    )
                    self.gconv2.append(
                        MixProp(
                            self.conv_channels,
                            self.residual_channels,
                            self.gcn_depth,
                            self.dropout,
                            self.propalpha,
                        )
                    )
                norm_size = (
                    (
                        self.residual_channels,
                        self.num_nodes,
                        self.input_window - rf_size_j + 1,
                    )
                    if self.input_window > self.receptive_field
                    else (
                        self.residual_channels,
                        self.num_nodes,
                        self.receptive_field - rf_size_j + 1,
                    )
                )
                self.norm.append(
                    LayerNorm(norm_size, elementwise_affine=self.layer_norm_affline)
                )
                new_dilation *= self.dilation_exponential

        self.end_conv_1 = nn.Conv2d(
            self.skip_channels, self.end_channels, kernel_size=1, bias=True
        )
        self.end_conv_2 = nn.Conv2d(
            self.end_channels, self.output_window, kernel_size=1, bias=True
        )
        k0 = (
            (1, self.input_window)
            if self.input_window > self.receptive_field
            else (1, self.receptive_field)
        )
        self.skip0 = nn.Conv2d(
            self.feature_dim, self.skip_channels, kernel_size=k0, bias=True
        )
        kE = (
            (1, self.input_window - self.receptive_field + 1)
            if self.input_window > self.receptive_field
            else (1, 1)
        )
        self.skipE = nn.Conv2d(
            self.residual_channels, self.skip_channels, kernel_size=kE, bias=True
        )
        # Output heads: which ones are built depends on the model type
        if self.model_type == "teacher":
            self.tt_linear1 = nn.Linear(self.residual_channels, self.input_window)
            self.tt_linear2 = nn.Linear(1, 32)
            self.ss_linear1 = nn.Linear(self.residual_channels, self.input_window)
            self.ss_linear2 = nn.Linear(1, 32)
        else:  # stmlp
            self.out_linear1 = nn.Linear(self.residual_channels, self.input_window)
            self.out_linear2 = nn.Linear(1, 32)

    def forward(self, source, idx=None):
        # Keep only the first input feature; source: (batch, time, nodes, features)
        source = source[..., 0:1]
        sout, tout = [], []
        inputs = source.transpose(1, 3)  # -> (batch, features, nodes, time)
        assert inputs.size(3) == self.input_window, "input sequence length mismatch"
        if self.input_window < self.receptive_field:
            inputs = F.pad(inputs, (self.receptive_field - self.input_window, 0, 0, 0))
        if self.gcn_true:
            adp = (
                self.gc(self.idx if idx is None else idx)
                if self.buildA_true
                else self.predefined_A
            )
        x = self.start_conv(inputs)
        skip = self.skip0(F.dropout(inputs, self.dropout, training=self.training))
        for i in range(self.layers):
            residual = x
            filters = torch.tanh(self.filter_convs[i](x))
            gate = torch.sigmoid(self.gate_convs[i](x))
            x = F.dropout(filters * gate, self.dropout, training=self.training)
            tout.append(x)
            s = self.skip_convs[i](x)
            skip = s + skip
            if self.gcn_true:
                x = self.gconv1[i](x, adp) + self.gconv2[i](x, adp.transpose(1, 0))
            else:
                x = self.stu_mlp[i](x)
            x = x + residual[:, :, :, -x.size(3) :]
            x = self.norm[i](x, self.idx if idx is None else idx)
            sout.append(x)
        skip = self.skipE(x) + skip
        x = F.relu(skip)
        x = F.relu(self.end_conv_1(x))
        x = self.end_conv_2(x)
        if self.model_type == "teacher":
            # The teacher additionally returns temporal (tt) and spatial (ss)
            # projections of the last-layer features.
            ttout = self.tt_linear2(
                self.tt_linear1(tout[-1].transpose(1, 3)).transpose(1, 3)
            )
            ssout = self.ss_linear2(
                self.ss_linear1(sout[-1].transpose(1, 3)).transpose(1, 3)
            )
            return x, ttout, ssout
        else:
            x_ = self.out_linear2(
                self.out_linear1(tout[-1].transpose(1, 3)).transpose(1, 3)
            )
            return x, x_, x
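

# A minimal, self-contained smoke test for the standalone building blocks.
# It uses made-up sizes and deliberately avoids `data.get_adj` and the full
# STMLP model, which need a dataset-specific args dict; it is a sketch of how
# the modules fit together, not part of the training pipeline.
if __name__ == "__main__":
    batch, channels, nodes, steps = 2, 16, 10, 19
    x = torch.randn(batch, channels, nodes, steps)

    # Learn a sparse directed adjacency over 10 nodes and propagate with MixProp.
    gc = GraphConstructor(nnodes=nodes, k=4, dim=8, device="cpu", alpha=3)
    adj = gc(torch.arange(nodes))
    print("adjacency:", adj.shape)                    # (10, 10)

    mix = MixProp(c_in=channels, c_out=channels, gdep=2, dropout=0.3, alpha=0.05)
    print("mixprop out:", mix(x, adj).shape)          # (2, 16, 10, 19)

    # DilatedInception shortens the time axis by 6 * dilation_factor.
    tconv = DilatedInception(channels, 32, dilation_factor=1)
    print("dilated inception out:", tconv(x).shape)   # (2, 32, 10, 13)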