"""Graph/temporal building blocks and the STMLP model.

``STMLP`` builds a ``"teacher"`` output head when ``args["model_type"]`` is
``"teacher"`` and the ``"stmlp"`` student head otherwise (see
``STMLP.forward`` for what each returns).
"""

import numbers

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init

from utils.get_adj import get_adj

# --- Basic operators ---


def _normalize_with_self_loops(adj, device):
    """Return ``adj`` with self-loops added and every row divided by its sum.

    Shared by :class:`Prop` and :class:`MixProp`.

    Args:
        adj: square ``(V, V)`` adjacency matrix.
        device: device on which the identity matrix is allocated.
    """
    adj = adj + torch.eye(adj.size(0), device=device)
    return adj / adj.sum(1).view(-1, 1)


class NConv(nn.Module):
    """Aggregate node features with a static adjacency matrix.

    ``out[n, c, v, l] = sum_w x[n, c, w, l] * adj[v, w]``: node ``v`` receives
    the features of nodes ``w`` weighted by row ``v`` of ``adj``.
    """

    def forward(self, x, adj):
        return torch.einsum("ncwl,vw->ncvl", (x, adj)).contiguous()


class DyNconv(nn.Module):
    """Aggregate node features with a per-sample, per-time-step adjacency.

    ``out[n, c, w, l] = sum_v x[n, c, v, l] * adj[n, v, w, l]``.
    """

    def forward(self, x, adj):
        return torch.einsum("ncvl,nvwl->ncwl", (x, adj)).contiguous()


class Linear(nn.Module):
    """Position-wise linear map over channel dim 1, implemented as a 1x1 conv."""

    def __init__(self, c_in, c_out, bias=True):
        super().__init__()
        self.mlp = nn.Conv2d(c_in, c_out, kernel_size=1, bias=bias)

    def forward(self, x):
        return self.mlp(x)


class Prop(nn.Module):
    """Graph propagation that keeps only the final hop, then a 1x1 conv.

    Each of ``gdep`` steps computes ``h = alpha * x + (1 - alpha) * A @ h``
    where ``A`` is the self-looped, row-normalized adjacency.

    NOTE(review): ``dropout`` is stored but never applied.
    """

    def __init__(self, c_in, c_out, gdep, dropout, alpha):
        super().__init__()
        self.nconv = NConv()
        self.mlp = Linear(c_in, c_out)
        self.gdep, self.dropout, self.alpha = gdep, dropout, alpha

    def forward(self, x, adj):
        a = _normalize_with_self_loops(adj, x.device)
        h = x
        for _ in range(self.gdep):
            h = self.alpha * x + (1 - self.alpha) * self.nconv(h, a)
        return self.mlp(h)


class MixProp(nn.Module):
    """Graph propagation that concatenates every hop, then a 1x1 conv.

    Same update rule as :class:`Prop`, but ``[x, h_1, ..., h_gdep]`` are
    concatenated along the channel dim, so the conv sees
    ``(gdep + 1) * c_in`` input channels.

    NOTE(review): ``dropout`` is stored but never applied.
    """

    def __init__(self, c_in, c_out, gdep, dropout, alpha):
        super().__init__()
        self.nconv = NConv()
        self.mlp = Linear((gdep + 1) * c_in, c_out)
        self.gdep, self.dropout, self.alpha = gdep, dropout, alpha

    def forward(self, x, adj):
        a = _normalize_with_self_loops(adj, x.device)
        hops = [x]
        h = x
        for _ in range(self.gdep):
            h = self.alpha * x + (1 - self.alpha) * self.nconv(h, a)
            hops.append(h)
        return self.mlp(torch.cat(hops, dim=1))


class DyMixprop(nn.Module):
    """Mix-hop propagation over an adjacency computed from the input itself.

    The adjacency is the (per-sample, per-time-step) similarity of two
    tanh-projected copies of ``x``. It is softmax-normalized in both
    directions, and each direction gets its own propagation branch and output
    conv; the two branch outputs are summed.
    """

    def __init__(self, c_in, c_out, gdep, dropout, alpha):
        super().__init__()
        self.nconv = DyNconv()
        self.mlp1 = Linear((gdep + 1) * c_in, c_out)
        self.mlp2 = Linear((gdep + 1) * c_in, c_out)
        self.gdep, self.dropout, self.alpha = gdep, dropout, alpha
        self.lin1, self.lin2 = Linear(c_in, c_in), Linear(c_in, c_in)

    def _propagate(self, x, adj):
        """Return ``cat([x, h_1, ..., h_gdep], dim=1)`` propagated with ``adj``."""
        hops = [x]
        h = x
        for _ in range(self.gdep):
            h = self.alpha * x + (1 - self.alpha) * self.nconv(h, adj)
            hops.append(h)
        return torch.cat(hops, dim=1)

    def forward(self, x):
        x1 = torch.tanh(self.lin1(x))
        x2 = torch.tanh(self.lin2(x))
        adj = self.nconv(x1.transpose(2, 1), x2)
        adj0 = torch.softmax(adj, dim=2)
        adj1 = torch.softmax(adj.transpose(2, 1), dim=2)
        # Two branches, one per edge direction.
        return self.mlp1(self._propagate(x, adj0)) + self.mlp2(
            self._propagate(x, adj1)
        )


class DilatedInception(nn.Module):
    """Parallel dilated temporal convolutions with kernel widths 2, 3, 6 and 7.

    Each branch produces ``int(cout / 4)`` channels; every branch output is
    cropped (keeping the most recent steps) to the shortest branch output
    and the results are concatenated along the channel dim.

    NOTE(review): if ``cout`` is not divisible by 4 the output has fewer than
    ``cout`` channels.
    """

    def __init__(self, cin, cout, dilation_factor=2):
        super().__init__()
        self.kernels = [2, 3, 6, 7]
        cout_each = int(cout / len(self.kernels))
        self.convs = nn.ModuleList(
            [
                nn.Conv2d(
                    cin, cout_each, kernel_size=(1, k), dilation=(1, dilation_factor)
                )
                for k in self.kernels
            ]
        )

    def forward(self, x):
        # Run each branch exactly once (previously the widest conv was re-run
        # for every branch just to read its output length).
        outs = [conv(x) for conv in self.convs]
        t_len = min(out.size(3) for out in outs)
        return torch.cat([out[..., -t_len:] for out in outs], dim=1)


class GraphConstructor(nn.Module):
    """Learn a sparse, asymmetric adjacency among nodes.

    Node vectors come from two embedding tables (or from rows of
    ``static_feat`` when given) and pass through ``tanh(alpha * Linear(.))``.
    The adjacency is ``relu(tanh(alpha * (M1 M2^T - M2 M1^T)))``, keeping
    only the ``k`` largest entries in each row.

    Args:
        nnodes: number of nodes (embedding table size).
        k: number of entries kept per row.
        dim: node vector size.
        device: stored as ``self.device``.
        alpha: saturation factor used inside the tanh calls.
        static_feat: optional per-node feature matrix used instead of
            embeddings.
    """

    def __init__(self, nnodes, k, dim, device, alpha=3, static_feat=None):
        super().__init__()
        self.nnodes = nnodes
        self.k = k
        self.dim = dim
        self.alpha = alpha
        self.device = device
        self.static_feat = static_feat
        if static_feat is not None:
            xd = static_feat.shape[1]
            self.lin1, self.lin2 = nn.Linear(xd, dim), nn.Linear(xd, dim)
        else:
            self.emb1 = nn.Embedding(nnodes, dim)
            self.emb2 = nn.Embedding(nnodes, dim)
            self.lin1, self.lin2 = nn.Linear(dim, dim), nn.Linear(dim, dim)

    def forward(self, idx):
        """Return the ``(len(idx), len(idx))`` adjacency for the nodes in ``idx``."""
        if self.static_feat is None:
            vec1, vec2 = self.emb1(idx), self.emb2(idx)
        else:
            vec1 = vec2 = self.static_feat[idx, :]
        vec1 = torch.tanh(self.alpha * self.lin1(vec1))
        vec2 = torch.tanh(self.alpha * self.lin2(vec2))
        a = torch.mm(vec1, vec2.transpose(1, 0)) - torch.mm(vec2, vec1.transpose(1, 0))
        adj = F.relu(torch.tanh(self.alpha * a))
        # Keep only the top-k entries per row. The mask is derived from
        # ``adj`` so it always matches adj's device and dtype.
        mask = torch.zeros_like(adj)
        top_vals, top_idx = adj.topk(self.k, 1)
        mask.scatter_(1, top_idx, top_vals.new_ones(top_vals.size()))
        return adj * mask


class LayerNorm(nn.Module):
    """Layer norm over all dims except dim 0, with node-indexed affine params.

    When ``elementwise_affine`` is set, ``forward`` slices
    ``weight[:, idx, :]`` / ``bias[:, idx, :]``, so ``normalized_shape`` must
    be 3-D; STMLP builds it as ``(channels, num_nodes, time)``.
    """

    __constants__ = ["normalized_shape", "eps", "elementwise_affine"]

    def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
        super().__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = tuple(normalized_shape)
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if elementwise_affine:
            self.weight = nn.Parameter(torch.empty(*normalized_shape))
            self.bias = nn.Parameter(torch.empty(*normalized_shape))
            init.ones_(self.weight)
            init.zeros_(self.bias)
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)

    def forward(self, x, idx):
        if self.elementwise_affine:
            return F.layer_norm(
                x,
                tuple(x.shape[1:]),
                self.weight[:, idx, :],
                self.bias[:, idx, :],
                self.eps,
            )
        return F.layer_norm(x, tuple(x.shape[1:]), self.weight, self.bias, self.eps)

    def extra_repr(self):
        return f"{self.normalized_shape}, eps={self.eps}, elementwise_affine={self.elementwise_affine}"


# --- Merged model class supporting both the teacher and the stmlp branch ---


class STMLP(nn.Module):
    """Spatio-temporal model built from a gated dilated-convolution stack.

    Each of ``layers`` layers applies a tanh/sigmoid-gated
    :class:`DilatedInception` temporal conv, then either two
    :class:`MixProp` graph convolutions (``gcn_true``) or a per-layer MLP
    over the time axis (``stu_mlp``), a residual connection and a
    node-indexed :class:`LayerNorm`. Per-layer skip convolutions are summed
    and decoded by two 1x1 convolutions.

    Args:
        args: configuration dict. Keys read here: ``num_nodes``,
            ``input_dim``, ``input_window``, ``output_window``, ``output_dim``,
            ``device``, ``gcn_true``, ``buildA_true``, ``gcn_depth``,
            ``dropout``, ``subgraph_size``, ``node_dim``,
            ``dilation_exponential``, ``conv_channels``,
            ``residual_channels``, ``skip_channels``, ``end_channels``,
            ``layers``, ``propalpha``, ``tanhalpha``, ``layer_norm_affline``,
            ``model_type``. The dict is also passed to ``get_adj`` to obtain
            the predefined adjacency.
    """

    def __init__(self, args):
        super().__init__()
        # Hyper-parameters are read from the ``args`` dict.
        self.adj_mx = get_adj(args)
        self.num_nodes = args["num_nodes"]
        self.feature_dim = args["input_dim"]
        self.input_window = args["input_window"]
        self.output_window = args["output_window"]
        self.output_dim = args["output_dim"]
        self.device = args["device"]
        self.gcn_true = args["gcn_true"]
        self.buildA_true = args["buildA_true"]
        self.gcn_depth = args["gcn_depth"]
        self.dropout = args["dropout"]
        self.subgraph_size = args["subgraph_size"]
        self.node_dim = args["node_dim"]
        self.dilation_exponential = args["dilation_exponential"]
        self.conv_channels = args["conv_channels"]
        self.residual_channels = args["residual_channels"]
        self.skip_channels = args["skip_channels"]
        self.end_channels = args["end_channels"]
        self.layers = args["layers"]
        self.propalpha = args["propalpha"]
        self.tanhalpha = args["tanhalpha"]
        # NOTE: the config key is spelled "layer_norm_affline".
        self.layer_norm_affline = args["layer_norm_affline"]
        self.model_type = args["model_type"]  # 'teacher' or 'stmlp'

        # NOTE(review): these are plain tensor attributes, not registered
        # buffers, so ``module.to(...)`` does not move them; they are placed
        # on ``args["device"]`` here.
        self.idx = torch.arange(self.num_nodes).to(self.device)
        # Predefined adjacency with the identity subtracted; MixProp adds
        # self-loops back before normalizing.
        self.predefined_A = (
            None
            if self.adj_mx is None
            else (torch.tensor(self.adj_mx) - torch.eye(self.num_nodes)).to(
                self.device
            )
        )
        self.static_feat = None

        # Transformer encoder (original structure kept).
        # NOTE(review): never called in ``forward``; its parameters are still
        # created and therefore appear in ``parameters()``/``state_dict``.
        self.encoder_layer = nn.TransformerEncoderLayer(
            d_model=12, nhead=4, batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(
            self.encoder_layer, num_layers=3
        )

        # --- Build the layers. Creation order is kept as-is so parameter
        # order and seeded initialization do not change. ---
        self.start_conv = nn.Conv2d(
            self.feature_dim, self.residual_channels, kernel_size=1
        )
        self.gc = GraphConstructor(
            self.num_nodes,
            self.subgraph_size,
            self.node_dim,
            self.device,
            alpha=self.tanhalpha,
            static_feat=self.static_feat,
        )

        # Receptive field of the temporal stack (widest inception kernel: 7).
        kernel_size = 7
        dil_exp = self.dilation_exponential
        if dil_exp > 1:
            self.receptive_field = int(
                self.output_dim
                + (kernel_size - 1) * (dil_exp**self.layers - 1) / (dil_exp - 1)
            )
        else:
            self.receptive_field = self.layers * (kernel_size - 1) + self.output_dim
        # Time length the stack operates on: ``forward`` left-pads inputs that
        # are shorter than ``receptive_field``.
        seq_len = max(self.input_window, self.receptive_field)

        self.filter_convs = nn.ModuleList()
        self.gate_convs = nn.ModuleList()
        self.residual_convs = nn.ModuleList()  # NOTE(review): unused in forward
        self.skip_convs = nn.ModuleList()
        self.norm = nn.ModuleList()
        # Per-layer MLPs over the time axis, used instead of graph convolution
        # when ``gcn_true`` is False.
        # NOTE(review): widths [13, 7, 1] equal the per-layer time lengths only
        # for layers=3, dilation_exponential=1, output_dim=1 and
        # input_window <= 19; other configs break the non-GCN branch.
        self.stu_mlp = nn.ModuleList(
            [
                nn.Sequential(nn.Linear(c, c), nn.Linear(c, c), nn.Linear(c, c))
                for c in [13, 7, 1]
            ]
        )
        if self.gcn_true:
            self.gconv1 = nn.ModuleList()
            self.gconv2 = nn.ModuleList()

        new_dilation = 1
        for j in range(1, self.layers + 1):
            # Receptive field (in time steps) covered after layer j.
            if dil_exp > 1:
                rf_size_j = int(
                    1 + (kernel_size - 1) * (dil_exp**j - 1) / (dil_exp - 1)
                )
            else:
                rf_size_j = 1 + j * (kernel_size - 1)
            # Time steps remaining after layer j.
            t_len = seq_len - rf_size_j + 1

            self.filter_convs.append(
                DilatedInception(
                    self.residual_channels,
                    self.conv_channels,
                    dilation_factor=new_dilation,
                )
            )
            self.gate_convs.append(
                DilatedInception(
                    self.residual_channels,
                    self.conv_channels,
                    dilation_factor=new_dilation,
                )
            )
            self.residual_convs.append(
                nn.Conv2d(self.conv_channels, self.residual_channels, kernel_size=1)
            )
            # Kernel spans all remaining steps, collapsing time to length 1.
            self.skip_convs.append(
                nn.Conv2d(self.conv_channels, self.skip_channels, kernel_size=(1, t_len))
            )
            if self.gcn_true:
                self.gconv1.append(
                    MixProp(
                        self.conv_channels,
                        self.residual_channels,
                        self.gcn_depth,
                        self.dropout,
                        self.propalpha,
                    )
                )
                self.gconv2.append(
                    MixProp(
                        self.conv_channels,
                        self.residual_channels,
                        self.gcn_depth,
                        self.dropout,
                        self.propalpha,
                    )
                )
            self.norm.append(
                LayerNorm(
                    (self.residual_channels, self.num_nodes, t_len),
                    elementwise_affine=self.layer_norm_affline,
                )
            )
            new_dilation *= dil_exp

        self.end_conv_1 = nn.Conv2d(
            self.skip_channels, self.end_channels, kernel_size=1, bias=True
        )
        self.end_conv_2 = nn.Conv2d(
            self.end_channels, self.output_window, kernel_size=1, bias=True
        )
        self.skip0 = nn.Conv2d(
            self.feature_dim, self.skip_channels, kernel_size=(1, seq_len), bias=True
        )
        self.skipE = nn.Conv2d(
            self.residual_channels,
            self.skip_channels,
            kernel_size=(1, seq_len - self.receptive_field + 1),
            bias=True,
        )

        # Final output heads, chosen by model type.
        # NOTE(review): tt_linear1/out_linear1 are applied to the last
        # temporal-conv output, which has conv_channels channels, so they
        # require conv_channels == residual_channels; the *_linear2 layers
        # (in_features=1) require the last layer to leave one time step.
        if self.model_type == "teacher":
            self.tt_linear1 = nn.Linear(self.residual_channels, self.input_window)
            self.tt_linear2 = nn.Linear(1, 32)
            self.ss_linear1 = nn.Linear(self.residual_channels, self.input_window)
            self.ss_linear2 = nn.Linear(1, 32)
        else:  # stmlp
            self.out_linear1 = nn.Linear(self.residual_channels, self.input_window)
            self.out_linear2 = nn.Linear(1, 32)

    @staticmethod
    def _project(lin_feat, lin_time, feat):
        """Apply ``lin_feat`` over dim 1 of ``feat``, then ``lin_time`` over dim 3."""
        return lin_time(lin_feat(feat.transpose(1, 3)).transpose(1, 3))

    def forward(self, source, idx=None):
        """Run the model.

        Args:
            source: input tensor. Dim 1 must have length ``input_window`` and
                only feature 0 of the last dim is used; presumably
                (batch, input_window, num_nodes, features) — TODO confirm.
            idx: optional node indices for the learned graph and LayerNorm;
                defaults to all nodes.

        Returns:
            ``(x, aux1, aux2)``, where ``x`` is the ``end_conv_2`` output
            (``output_window`` channels). For ``model_type == "teacher"``,
            ``aux1``/``aux2`` project the last temporal-conv output and the
            last normalized output; otherwise ``aux1`` projects the last
            temporal-conv output and ``aux2`` is ``x`` again.
        """
        # Only the first input feature is used.
        source = source[..., 0:1]
        inputs = source.transpose(1, 3)
        assert inputs.size(3) == self.input_window, "input sequence length mismatch"
        if self.input_window < self.receptive_field:
            # Left-pad the time axis so the stack sees receptive_field steps.
            inputs = F.pad(inputs, (self.receptive_field - self.input_window, 0, 0, 0))

        node_idx = self.idx if idx is None else idx
        if self.gcn_true:
            adp = self.gc(node_idx) if self.buildA_true else self.predefined_A

        x = self.start_conv(inputs)
        skip = self.skip0(F.dropout(inputs, self.dropout, training=self.training))
        tout, sout = [], []
        for i in range(self.layers):
            residual = x
            filters = torch.tanh(self.filter_convs[i](x))
            gate = torch.sigmoid(self.gate_convs[i](x))
            x = F.dropout(filters * gate, self.dropout, training=self.training)
            tout.append(x)
            skip = self.skip_convs[i](x) + skip
            if self.gcn_true:
                # Propagate along both edge directions.
                x = self.gconv1[i](x, adp) + self.gconv2[i](x, adp.transpose(1, 0))
            else:
                x = self.stu_mlp[i](x)
            # Crop the residual to the (shorter) current time length.
            x = x + residual[:, :, :, -x.size(3) :]
            x = self.norm[i](x, node_idx)
            sout.append(x)

        skip = self.skipE(x) + skip
        x = F.relu(skip)
        x = F.relu(self.end_conv_1(x))
        x = self.end_conv_2(x)

        if self.model_type == "teacher":
            ttout = self._project(self.tt_linear1, self.tt_linear2, tout[-1])
            ssout = self._project(self.ss_linear1, self.ss_linear2, sout[-1])
            return x, ttout, ssout
        x_ = self._project(self.out_linear1, self.out_linear2, tout[-1])
        return x, x_, x