308 lines
14 KiB
Python
308 lines
14 KiB
Python
import torch
|
||
import torch.nn as nn
|
||
import torch.nn.functional as F
|
||
from torch.nn import init
|
||
from data.get_adj import get_adj
|
||
import numbers
|
||
|
||
|
||
# --- Basic operators ---
|
||
class NConv(nn.Module):
    """Static graph aggregation step.

    Contracts ``x`` (einsum indices ``n, c, w, l``) with a shared adjacency
    ``adj`` (indices ``v, w``) over ``w``. The result is a contiguous tensor
    with indices ``n, c, v, l``.
    """

    def forward(self, x, adj):
        # out[n, c, v, l] = sum_w x[n, c, w, l] * adj[v, w]
        aggregated = torch.einsum('ncwl,vw->ncvl', x, adj)
        return aggregated.contiguous()
|
||
|
||
|
||
class DyNconv(nn.Module):
    """Graph aggregation with a per-sample adjacency.

    Contracts ``x`` (einsum indices ``n, c, v, l``) with ``adj`` (indices
    ``n, v, w, l``) over ``v``. The result is a contiguous tensor with indices
    ``n, c, w, l``.
    """

    def forward(self, x, adj):
        # out[n, c, w, l] = sum_v x[n, c, v, l] * adj[n, v, w, l]
        mixed = torch.einsum('ncvl,nvwl->ncwl', x, adj)
        return mixed.contiguous()
|
||
|
||
|
||
class Linear(nn.Module):
    """Channel-mixing projection implemented as a 1x1 ``Conv2d``.

    Args:
        c_in: number of input channels (dim 1 of the input).
        c_out: number of output channels.
        bias: whether the convolution carries an additive bias.
    """

    def __init__(self, c_in, c_out, bias=True):
        super().__init__()
        # A 1x1 kernel mixes channels independently at every spatial position.
        self.mlp = nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=1, bias=bias)

    def forward(self, x):
        projection = self.mlp
        return projection(x)
|
||
|
||
|
||
class Prop(nn.Module):
    """Graph propagation with input retention, followed by a 1x1 projection.

    Repeats ``h <- alpha * x + (1 - alpha) * A_hat h`` for ``gdep`` steps.
    ``A_hat`` is ``adj`` plus the identity, with each row divided by its sum.
    The final ``h`` is then projected from ``c_in`` to ``c_out`` channels.

    Args:
        c_in: input channel count.
        c_out: output channel count.
        gdep: number of propagation steps.
        dropout: stored on the module; not used by ``forward``.
        alpha: weight given to the original input ``x`` at every step.
    """

    def __init__(self, c_in, c_out, gdep, dropout, alpha):
        super().__init__()
        self.nconv = NConv()
        self.mlp = Linear(c_in, c_out)
        self.gdep = gdep
        self.dropout = dropout
        self.alpha = alpha

    def forward(self, x, adj):
        with_loops = adj + torch.eye(adj.size(0), device=x.device)
        # Divide each row by its sum (degree including the self-loop).
        normalized = with_loops / with_loops.sum(1).view(-1, 1)
        state = x
        step = 0
        while step < self.gdep:
            state = self.alpha * x + (1 - self.alpha) * self.nconv(state, normalized)
            step += 1
        return self.mlp(state)
|
||
|
||
|
||
class MixProp(nn.Module):
    """Mix-hop graph propagation: keep every propagation depth, then project.

    Builds ``[x, h_1, ..., h_gdep]`` where
    ``h_k = alpha * x + (1 - alpha) * A_hat h_{k-1}`` (with ``h_0 = x``).
    ``A_hat`` is ``adj`` plus the identity, row-normalised. The list is
    concatenated along dim 1, and ``(gdep + 1) * c_in`` channels are projected
    to ``c_out`` with a 1x1 convolution.

    Args:
        c_in: input channel count.
        c_out: output channel count.
        gdep: number of propagation steps.
        dropout: stored on the module; not used by ``forward``.
        alpha: weight given to the original input ``x`` at every step.
    """

    def __init__(self, c_in, c_out, gdep, dropout, alpha):
        super().__init__()
        self.nconv = NConv()
        self.mlp = Linear(c_in * (gdep + 1), c_out)
        self.gdep = gdep
        self.dropout = dropout
        self.alpha = alpha

    def forward(self, x, adj):
        adj_self = adj + torch.eye(adj.size(0), device=x.device)
        degree = adj_self.sum(1)
        transition = adj_self / degree.view(-1, 1)
        hop_outputs = [x]
        for _ in range(self.gdep):
            # hop_outputs[-1] is the previous propagation state (x on the first step).
            hop_outputs.append(
                self.alpha * x + (1 - self.alpha) * self.nconv(hop_outputs[-1], transition)
            )
        return self.mlp(torch.cat(hop_outputs, dim=1))
|
||
|
||
|
||
class DyMixprop(nn.Module):
    """Mix-hop propagation over an adjacency computed from the input itself.

    Two tanh-activated 1x1 projections of ``x`` are combined with ``DyNconv``
    into a score tensor. From it two adjacencies are derived:
    - a softmax over dim 2;
    - a softmax over dim 2 of the scores transposed on dims 1 and 2.

    Mix-hop propagation runs on each adjacency. The two concatenated hop stacks
    are projected by ``mlp1`` and ``mlp2`` respectively, and the results are summed.

    Args:
        c_in: input channel count.
        c_out: output channel count.
        gdep: number of propagation steps per branch.
        dropout: stored on the module; not used by ``forward``.
        alpha: weight given to the original input ``x`` at every step.
    """

    def __init__(self, c_in, c_out, gdep, dropout, alpha):
        super().__init__()
        self.nconv = DyNconv()
        self.mlp1 = Linear((gdep + 1) * c_in, c_out)
        self.mlp2 = Linear((gdep + 1) * c_in, c_out)
        self.gdep = gdep
        self.dropout = dropout
        self.alpha = alpha
        self.lin1 = Linear(c_in, c_in)
        self.lin2 = Linear(c_in, c_in)

    def _mix_hops(self, x, adj_norm):
        """Return ``[x, h_1, ..., h_gdep]`` concatenated along dim 1."""
        collected = [x]
        h = x
        for _ in range(self.gdep):
            h = self.alpha * x + (1 - self.alpha) * self.nconv(h, adj_norm)
            collected.append(h)
        return torch.cat(collected, dim=1)

    def forward(self, x):
        query = torch.tanh(self.lin1(x))
        key = torch.tanh(self.lin2(x))
        # NOTE(review): assumes x is (batch, channels, nodes, time); the DyNconv
        # einsum then yields a (batch, nodes, nodes, time) score tensor — confirm.
        scores = self.nconv(query.transpose(2, 1), key)
        forward_adj = torch.softmax(scores, dim=2)
        backward_adj = torch.softmax(scores.transpose(2, 1), dim=2)
        first = self.mlp1(self._mix_hops(x, forward_adj))
        second = self.mlp2(self._mix_hops(x, backward_adj))
        return first + second
|
||
|
||
|
||
class DilatedInception(nn.Module):
    """Inception-style convolution with several dilated kernel widths on the last axis.

    There is one ``Conv2d`` branch per kernel width in ``self.kernels``
    (``(1, 2)``, ``(1, 3)``, ``(1, 6)``, ``(1, 7)``). All branches share
    ``dilation=(1, dilation_factor)``, and each produces
    ``int(cout / len(self.kernels))`` channels. Every branch output is cropped
    to the shortest output length by keeping its last positions on axis 3.
    The cropped outputs are concatenated along dim 1.

    Args:
        cin: input channel count.
        cout: requested total output channels. The actual count is
            ``len(self.kernels) * int(cout / len(self.kernels))``, so values not
            divisible by the number of kernels are rounded down.
        dilation_factor: dilation applied along the last axis.
    """

    def __init__(self, cin, cout, dilation_factor=2):
        super().__init__()
        self.kernels = [2, 3, 6, 7]
        cout_each = int(cout / len(self.kernels))
        self.convs = nn.ModuleList([
            nn.Conv2d(cin, cout_each, kernel_size=(1, k), dilation=(1, dilation_factor))
            for k in self.kernels
        ])

    def forward(self, x):
        # Run every branch exactly once and reuse its output. Recomputing the
        # widest branch just to read its length would cost one extra
        # convolution per branch on every call.
        outs = [conv(x) for conv in self.convs]
        # The widest kernel gives the shortest output; crop all branches to it.
        target_len = min(out.size(3) for out in outs)
        return torch.cat([out[..., -target_len:] for out in outs], dim=1)
|
||
|
||
|
||
class GraphConstructor(nn.Module):
    """Learn a sparse adjacency matrix over the nodes selected by ``idx``.

    Node vectors come from one of two sources:
    - two embedding tables, when ``static_feat`` is None;
    - the ``static_feat`` rows selected by ``idx``, otherwise.

    They pass through ``lin1`` / ``lin2`` and ``tanh(alpha * .)``. The score
    matrix is ``M1 M2^T - M2 M1^T``, which is antisymmetric. After
    ``relu(tanh(alpha * .))``, at most one of ``adj[i, j]`` and ``adj[j, i]``
    is positive. Only the ``k`` largest entries of each row are kept.

    Args:
        nnodes: number of embedding rows (used when ``static_feat`` is None).
        k: number of entries kept per row.
        dim: embedding / projection width.
        device: stored on the module; ``forward`` follows the device of its
            intermediate tensors.
        alpha: saturation factor inside the ``tanh`` activations.
        static_feat: optional 2-D tensor of node features (rows indexed by ``idx``).
    """

    def __init__(self, nnodes, k, dim, device, alpha=3, static_feat=None):
        super().__init__()
        self.nnodes, self.k, self.dim, self.alpha, self.device = nnodes, k, dim, alpha, device
        self.static_feat = static_feat
        if static_feat is not None:
            xd = static_feat.shape[1]
            self.lin1, self.lin2 = nn.Linear(xd, dim), nn.Linear(xd, dim)
        else:
            self.emb1 = nn.Embedding(nnodes, dim)
            self.emb2 = nn.Embedding(nnodes, dim)
            self.lin1, self.lin2 = nn.Linear(dim, dim), nn.Linear(dim, dim)

    def forward(self, idx):
        """Return the masked ``(len(idx), len(idx))`` adjacency for the nodes in ``idx``."""
        if self.static_feat is None:
            vec1, vec2 = self.emb1(idx), self.emb2(idx)
        else:
            vec1 = vec2 = self.static_feat[idx, :]
        vec1 = torch.tanh(self.alpha * self.lin1(vec1))
        vec2 = torch.tanh(self.alpha * self.lin2(vec2))
        a = torch.mm(vec1, vec2.transpose(1, 0)) - torch.mm(vec2, vec1.transpose(1, 0))
        adj = F.relu(torch.tanh(self.alpha * a))
        # Build the mask from adj itself so it shares adj's device and dtype
        # (robust to moving the module with .to() after construction).
        # Clamp k so a node subset smaller than k does not make topk raise.
        k = min(self.k, adj.size(1))
        mask = torch.zeros_like(adj)
        _, top_idx = adj.topk(k, 1)
        mask.scatter_(1, top_idx, 1.0)
        return adj * mask
|
||
|
||
|
||
class LayerNorm(nn.Module):
    """Layer normalisation over all non-batch dims, with node-indexable affine params.

    ``forward`` normalises over ``x.shape[1:]``. When ``elementwise_affine`` is
    set, the weight and bias are sliced with ``[:, idx, :]``, so only the
    parameters of the selected nodes are applied. This requires a 3-D
    ``normalized_shape``.

    Args:
        normalized_shape: int or sequence giving the parameter shape.
        eps: value added to the variance for numerical stability.
        elementwise_affine: whether to learn a per-element weight and bias.
    """

    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine']

    def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
        super().__init__()
        shape = (normalized_shape,) if isinstance(normalized_shape, numbers.Integral) else normalized_shape
        self.normalized_shape = tuple(shape)
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if not elementwise_affine:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
            return
        self.weight = nn.Parameter(torch.Tensor(*shape))
        self.bias = nn.Parameter(torch.Tensor(*shape))
        init.ones_(self.weight)
        init.zeros_(self.bias)

    def forward(self, x, idx):
        norm_dims = tuple(x.shape[1:])
        if not self.elementwise_affine:
            return F.layer_norm(x, norm_dims, self.weight, self.bias, self.eps)
        # Select the affine parameters of the nodes present in x (axis 1 of the params).
        return F.layer_norm(x, norm_dims, self.weight[:, idx, :], self.bias[:, idx, :], self.eps)

    def extra_repr(self):
        return f'{self.normalized_shape}, eps={self.eps}, elementwise_affine={self.elementwise_affine}'
|
||
|
||
|
||
# --- Merged model class supporting both the 'teacher' and 'stmlp' branches ---
|
||
class STMLP(nn.Module):
    """MTGNN-style spatio-temporal network with a 'teacher' or an 'stmlp' head.

    The shared backbone has four parts:
    - a 1x1 ``start_conv``;
    - ``layers`` gated blocks, each built from two ``DilatedInception`` convs
      combined as tanh * sigmoid, plus dropout;
    - per-layer skip convolutions;
    - ``end_conv_1`` / ``end_conv_2``, which produce ``output_window`` channels.

    Inside each block the spatial step depends on ``gcn_true``:
    - if set, it is two ``MixProp`` graph convolutions over the adjacency and
      its transpose;
    - otherwise, it is a per-layer MLP over the last axis (``stu_mlp``).

    Either way it is followed by a residual add and a ``LayerNorm``.

    The heads depend on ``args['model_type']``:
    - ``'teacher'`` builds two auxiliary heads: ``tt_*`` on the last gated
      output and ``ss_*`` on the last normalised block output;
    - any other value builds a single ``out_*`` head on the last gated output.

    Args:
        args: dict-like config; every key read in ``__init__`` must be present.
    """

    def __init__(self, args):
        super().__init__()
        # All hyper-parameters are read from the args dict.
        self.adj_mx = get_adj(args)
        self.num_nodes = args['num_nodes']
        self.feature_dim = args['input_dim']

        self.input_window = args['input_window']
        self.output_window = args['output_window']
        self.output_dim = args['output_dim']
        self.device = args['device']

        self.gcn_true = args['gcn_true']
        self.buildA_true = args['buildA_true']
        self.gcn_depth = args['gcn_depth']
        self.dropout = args['dropout']
        self.subgraph_size = args['subgraph_size']
        self.node_dim = args['node_dim']
        self.dilation_exponential = args['dilation_exponential']

        self.conv_channels = args['conv_channels']
        self.residual_channels = args['residual_channels']
        self.skip_channels = args['skip_channels']
        self.end_channels = args['end_channels']

        self.layers = args['layers']
        self.propalpha = args['propalpha']
        self.tanhalpha = args['tanhalpha']
        self.layer_norm_affline = args['layer_norm_affline']

        self.model_type = args['model_type']  # 'teacher' or 'stmlp'
        self.idx = torch.arange(self.num_nodes).to(self.device)
        # Predefined adjacency with the identity subtracted.
        # NOTE(review): torch.tensor keeps adj_mx's dtype (e.g. float64 for a
        # float64 numpy array) — confirm it matches the activations' dtype when
        # buildA_true is False.
        self.predefined_A = None if self.adj_mx is None else (torch.tensor(self.adj_mx) - torch.eye(self.num_nodes)).to(
            self.device)
        self.static_feat = None

        # Transformer encoder kept from the original structure. It is not used in
        # forward(), but its parameters are registered and appear in the state_dict.
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=12, nhead=4, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=3)

        # Build the layers.
        self.start_conv = nn.Conv2d(self.feature_dim, self.residual_channels, kernel_size=1)
        self.gc = GraphConstructor(self.num_nodes, self.subgraph_size, self.node_dim, self.device, alpha=self.tanhalpha,
                                   static_feat=self.static_feat)
        # Receptive field of the stacked temporal convolutions; 7 is the widest
        # DilatedInception kernel width.
        kernel_size = 7
        if self.dilation_exponential > 1:
            self.receptive_field = int(
                self.output_dim + (kernel_size - 1) * (self.dilation_exponential ** self.layers - 1) / (
                        self.dilation_exponential - 1))
        else:
            self.receptive_field = self.layers * (kernel_size - 1) + self.output_dim

        self.filter_convs = nn.ModuleList()
        self.gate_convs = nn.ModuleList()
        # NOTE(review): residual_convs are built but never used in forward().
        self.residual_convs = nn.ModuleList()
        self.skip_convs = nn.ModuleList()
        self.norm = nn.ModuleList()
        # Per-layer MLPs over the last axis, used instead of graph convolution when
        # gcn_true is False. NOTE(review): the widths 13, 7, 1 are hard-coded. They
        # equal the per-layer lengths only for layers=3, dilation_exponential=1
        # and receptive_field=19 (e.g. output_dim=1 with input_window <= 19).
        self.stu_mlp = nn.ModuleList([nn.Sequential(nn.Linear(c, c), nn.Linear(c, c), nn.Linear(c, c))
                                      for c in [13, 7, 1]])
        if self.gcn_true:
            self.gconv1 = nn.ModuleList()
            self.gconv2 = nn.ModuleList()

        new_dilation = 1
        for i in range(1):
            rf_size_i = int(1 + i * (kernel_size - 1) * (self.dilation_exponential ** self.layers - 1) / (
                    self.dilation_exponential - 1)) if self.dilation_exponential > 1 else i * self.layers * (
                    kernel_size - 1) + 1
            for j in range(1, self.layers + 1):
                # Receptive field after j blocks; determines the remaining length.
                rf_size_j = int(rf_size_i + (kernel_size - 1) * (self.dilation_exponential ** j - 1) / (
                        self.dilation_exponential - 1)) if self.dilation_exponential > 1 else rf_size_i + j * (
                        kernel_size - 1)
                self.filter_convs.append(
                    DilatedInception(self.residual_channels, self.conv_channels, dilation_factor=new_dilation))
                self.gate_convs.append(
                    DilatedInception(self.residual_channels, self.conv_channels, dilation_factor=new_dilation))
                self.residual_convs.append(nn.Conv2d(self.conv_channels, self.residual_channels, kernel_size=1))
                # Skip conv kernel spans the whole remaining length after block j.
                k_size = (1, self.input_window - rf_size_j + 1) if self.input_window > self.receptive_field else (
                    1, self.receptive_field - rf_size_j + 1)
                self.skip_convs.append(nn.Conv2d(self.conv_channels, self.skip_channels, kernel_size=k_size))
                if self.gcn_true:
                    self.gconv1.append(MixProp(self.conv_channels, self.residual_channels, self.gcn_depth, self.dropout,
                                               self.propalpha))
                    self.gconv2.append(MixProp(self.conv_channels, self.residual_channels, self.gcn_depth, self.dropout,
                                               self.propalpha))
                norm_size = (self.residual_channels, self.num_nodes,
                             self.input_window - rf_size_j + 1) if self.input_window > self.receptive_field else (
                    self.residual_channels, self.num_nodes, self.receptive_field - rf_size_j + 1)
                self.norm.append(LayerNorm(norm_size, elementwise_affine=self.layer_norm_affline))
                new_dilation *= self.dilation_exponential

        self.end_conv_1 = nn.Conv2d(self.skip_channels, self.end_channels, kernel_size=1, bias=True)
        self.end_conv_2 = nn.Conv2d(self.end_channels, self.output_window, kernel_size=1, bias=True)
        k0 = (1, self.input_window) if self.input_window > self.receptive_field else (1, self.receptive_field)
        self.skip0 = nn.Conv2d(self.feature_dim, self.skip_channels, kernel_size=k0, bias=True)
        kE = (1, self.input_window - self.receptive_field + 1) if self.input_window > self.receptive_field else (1, 1)
        self.skipE = nn.Conv2d(self.residual_channels, self.skip_channels, kernel_size=kE, bias=True)
        # Output heads, chosen by model type.
        # tt_linear1 / out_linear1 consume tout[-1], the gated DilatedInception
        # output, which has conv_channels channels. Using residual_channels here
        # made forward() crash whenever the two differ.
        # ss_linear1 consumes sout[-1], the normalised block output, which has
        # residual_channels channels.
        # NOTE(review): the *_linear2 layers map a last axis of size 1 to 32, so
        # they require a final block length of 1.
        if self.model_type == 'teacher':
            self.tt_linear1 = nn.Linear(self.conv_channels, self.input_window)
            self.tt_linear2 = nn.Linear(1, 32)
            self.ss_linear1 = nn.Linear(self.residual_channels, self.input_window)
            self.ss_linear2 = nn.Linear(1, 32)
        else:  # stmlp
            self.out_linear1 = nn.Linear(self.conv_channels, self.input_window)
            self.out_linear2 = nn.Linear(1, 32)

    def forward(self, source, idx=None):
        """Run the network.

        Args:
            source: input tensor. Only index 0 of the last axis is used. After
                ``transpose(1, 3)``, its axis 3 must equal ``input_window``
                (presumably (batch, time, nodes, features) — confirm with the caller).
                NOTE(review): ``start_conv`` and ``skip0`` expect ``feature_dim``
                channels, but only one feature is kept, so this assumes input_dim == 1.
            idx: optional node indices for the learned graph and the LayerNorm
                parameters; defaults to all nodes.

        Returns:
            A 3-tuple, depending on the model type:
            - 'teacher': ``(prediction, tt_head(tout[-1]), ss_head(sout[-1]))``;
            - otherwise: ``(prediction, out_head(tout[-1]), prediction)``.
        """
        source = source[..., 0:1]
        sout, tout = [], []
        inputs = source.transpose(1, 3)
        assert inputs.size(3) == self.input_window, 'input sequence length mismatch'
        if self.input_window < self.receptive_field:
            # Zero-pad the left side of the last axis up to the receptive field.
            inputs = F.pad(inputs, (self.receptive_field - self.input_window, 0, 0, 0))
        if self.gcn_true:
            adp = self.gc(self.idx if idx is None else idx) if self.buildA_true else self.predefined_A
        x = self.start_conv(inputs)
        skip = self.skip0(F.dropout(inputs, self.dropout, training=self.training))
        for i in range(self.layers):
            residual = x
            # Gated temporal convolution.
            filters = torch.tanh(self.filter_convs[i](x))
            gate = torch.sigmoid(self.gate_convs[i](x))
            x = F.dropout(filters * gate, self.dropout, training=self.training)
            tout.append(x)
            s = self.skip_convs[i](x)
            skip = s + skip
            if self.gcn_true:
                # Propagate over the adjacency and its transpose.
                x = self.gconv1[i](x, adp) + self.gconv2[i](x, adp.transpose(1, 0))
            else:
                x = self.stu_mlp[i](x)
            # The residual is longer on the last axis; keep its last positions.
            x = x + residual[:, :, :, -x.size(3):]
            x = self.norm[i](x, self.idx if idx is None else idx)
            sout.append(x)
        skip = self.skipE(x) + skip
        x = F.relu(skip)
        x = F.relu(self.end_conv_1(x))
        x = self.end_conv_2(x)
        if self.model_type == 'teacher':
            ttout = self.tt_linear2(self.tt_linear1(tout[-1].transpose(1, 3)).transpose(1, 3))
            ssout = self.ss_linear2(self.ss_linear1(sout[-1].transpose(1, 3)).transpose(1, 3))
            return x, ttout, ssout
        else:
            x_ = self.out_linear2(self.out_linear1(tout[-1].transpose(1, 3)).transpose(1, 3))
            return x, x_, x
|