import numbers

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init

from data.get_adj import get_adj


# --- Basic operators ---


def _normalize_adj(adj, device):
    """Add self-loops to a square ``adj`` and divide every row by its row sum."""
    adj = adj + torch.eye(adj.size(0), device=device)
    return adj / adj.sum(1).view(-1, 1)


def _receptive_size(base, depth, kernel_size, dilation_exponential):
    """Return ``base`` plus the temporal span added by ``depth`` dilated layers.

    With ``dilation_exponential > 1`` the per-layer dilation grows
    geometrically, so the span is a geometric sum (truncated with ``int``);
    otherwise every layer adds ``kernel_size - 1`` steps.
    """
    if dilation_exponential > 1:
        return int(
            base
            + (kernel_size - 1) * (dilation_exponential ** depth - 1)
            / (dilation_exponential - 1)
        )
    return base + depth * (kernel_size - 1)


class NConv(nn.Module):
    """Graph propagation with one adjacency shared by the whole batch."""

    def forward(self, x, adj):
        # x: (n, c, w, l), adj: (v, w) -> (n, c, v, l), per the einsum spec.
        return torch.einsum('ncwl,vw->ncvl', (x, adj)).contiguous()


class DyNconv(nn.Module):
    """Graph propagation with a separate adjacency per sample and per step."""

    def forward(self, x, adj):
        # x: (n, c, v, l), adj: (n, v, w, l) -> (n, c, w, l), per the einsum spec.
        return torch.einsum('ncvl,nvwl->ncwl', (x, adj)).contiguous()


class Linear(nn.Module):
    """Channel-wise linear map implemented as a 1x1 ``Conv2d``."""

    def __init__(self, c_in, c_out, bias=True):
        super().__init__()
        self.mlp = nn.Conv2d(c_in, c_out, kernel_size=1, bias=bias)

    def forward(self, x):
        return self.mlp(x)


class Prop(nn.Module):
    """``gdep`` propagation steps with retention ``alpha``, then a 1x1 projection.

    ``dropout`` is stored but not applied anywhere in this module.
    """

    def __init__(self, c_in, c_out, gdep, dropout, alpha):
        super().__init__()
        self.nconv = NConv()
        self.mlp = Linear(c_in, c_out)
        self.gdep, self.dropout, self.alpha = gdep, dropout, alpha

    def forward(self, x, adj):
        a = _normalize_adj(adj, x.device)
        h = x
        for _ in range(self.gdep):
            # Blend the original signal back in at every hop.
            h = self.alpha * x + (1 - self.alpha) * self.nconv(h, a)
        return self.mlp(h)


class MixProp(nn.Module):
    """Like :class:`Prop`, but concatenates every hop before projecting.

    The projection therefore takes ``(gdep + 1) * c_in`` input channels.
    ``dropout`` is stored but not applied anywhere in this module.
    """

    def __init__(self, c_in, c_out, gdep, dropout, alpha):
        super().__init__()
        self.nconv = NConv()
        self.mlp = Linear((gdep + 1) * c_in, c_out)
        self.gdep, self.dropout, self.alpha = gdep, dropout, alpha

    def forward(self, x, adj):
        a = _normalize_adj(adj, x.device)
        hops = [x]
        h = x
        for _ in range(self.gdep):
            h = self.alpha * x + (1 - self.alpha) * self.nconv(h, a)
            hops.append(h)
        return self.mlp(torch.cat(hops, dim=1))


class DyMixprop(nn.Module):
    """Mix-hop propagation over adjacencies computed from the input itself."""

    def __init__(self, c_in, c_out, gdep, dropout, alpha):
        super().__init__()
        self.nconv = DyNconv()
        self.mlp1 = Linear((gdep + 1) * c_in, c_out)
        self.mlp2 = Linear((gdep + 1) * c_in, c_out)
        self.gdep, self.dropout, self.alpha = gdep, dropout, alpha
        self.lin1, self.lin2 = Linear(c_in, c_in), Linear(c_in, c_in)

    def _mix_hops(self, x, adj):
        """Run ``gdep`` hops over ``adj`` and concatenate all hops on dim 1."""
        hops = [x]
        h = x
        for _ in range(self.gdep):
            h = self.alpha * x + (1 - self.alpha) * self.nconv(h, adj)
            hops.append(h)
        return torch.cat(hops, dim=1)

    def forward(self, x):
        x1 = torch.tanh(self.lin1(x))
        x2 = torch.tanh(self.lin2(x))
        # Contract the channel axis of the two projections to obtain a
        # pairwise score tensor, then normalise it in both directions.
        adj = self.nconv(x1.transpose(2, 1), x2)
        adj0 = torch.softmax(adj, dim=2)
        adj1 = torch.softmax(adj.transpose(2, 1), dim=2)

        # Two branches: one per adjacency direction.
        return self.mlp1(self._mix_hops(x, adj0)) + self.mlp2(self._mix_hops(x, adj1))


class DilatedInception(nn.Module):
    """Parallel dilated temporal convolutions with kernel widths 2, 3, 6 and 7.

    Each branch produces ``int(cout / 4)`` channels; outputs are cropped to the
    length of the widest-kernel (last) branch and concatenated on the channel
    axis.
    """

    def __init__(self, cin, cout, dilation_factor=2):
        super().__init__()
        self.kernels = [2, 3, 6, 7]
        cout_each = int(cout / len(self.kernels))
        self.convs = nn.ModuleList([
            nn.Conv2d(cin, cout_each, kernel_size=(1, k), dilation=(1, dilation_factor))
            for k in self.kernels
        ])

    def forward(self, x):
        # Evaluate every branch exactly once; the last branch has the widest
        # kernel and hence the shortest output, which sets the crop length.
        branch_outs = [conv(x) for conv in self.convs]
        t_out = branch_outs[-1].size(3)
        return torch.cat([out[..., -t_out:] for out in branch_outs], dim=1)


class GraphConstructor(nn.Module):
    """Learn a sparse directed adjacency, keeping the top-``k`` entries per row.

    Node vectors come from two embeddings, or from ``static_feat`` when it is
    given.
    """

    def __init__(self, nnodes, k, dim, device, alpha=3, static_feat=None):
        super().__init__()
        self.nnodes, self.k, self.dim, self.alpha, self.device = nnodes, k, dim, alpha, device
        self.static_feat = static_feat
        if static_feat is not None:
            xd = static_feat.shape[1]
            self.lin1, self.lin2 = nn.Linear(xd, dim), nn.Linear(xd, dim)
        else:
            self.emb1 = nn.Embedding(nnodes, dim)
            self.emb2 = nn.Embedding(nnodes, dim)
            self.lin1, self.lin2 = nn.Linear(dim, dim), nn.Linear(dim, dim)

    def forward(self, idx):
        """Return the masked adjacency for the nodes selected by ``idx``."""
        if self.static_feat is None:
            vec1, vec2 = self.emb1(idx), self.emb2(idx)
        else:
            vec1 = vec2 = self.static_feat[idx, :]
        vec1 = torch.tanh(self.alpha * self.lin1(vec1))
        vec2 = torch.tanh(self.alpha * self.lin2(vec2))
        # Antisymmetric score: after the ReLU at most one of (i, j) / (j, i)
        # can be non-zero.
        a = torch.mm(vec1, vec2.transpose(1, 0)) - torch.mm(vec2, vec1.transpose(1, 0))
        adj = F.relu(torch.tanh(self.alpha * a))
        # Allocate the mask on adj's own device so it cannot diverge from a
        # stale ``self.device`` value.
        mask = torch.zeros(idx.size(0), idx.size(0), device=adj.device)
        s1, t1 = adj.topk(self.k, 1)
        mask.scatter_(1, t1, s1.new_ones(s1.size()))
        return adj * mask


class LayerNorm(nn.Module):
    """Layer normalisation over all non-batch dims with node-indexable affine.

    When ``elementwise_affine`` is true, ``forward`` slices ``weight`` and
    ``bias`` as ``[:, idx, :]``, so ``normalized_shape`` must be 3-dimensional
    in that mode.
    """

    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine']

    def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
        super().__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = tuple(normalized_shape)
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if elementwise_affine:
            self.weight = nn.Parameter(torch.empty(self.normalized_shape))
            self.bias = nn.Parameter(torch.empty(self.normalized_shape))
            init.ones_(self.weight)
            init.zeros_(self.bias)
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)

    def forward(self, x, idx):
        # Normalise over everything except the batch dimension.
        norm_dims = tuple(x.shape[1:])
        if self.elementwise_affine:
            return F.layer_norm(x, norm_dims, self.weight[:, idx, :], self.bias[:, idx, :], self.eps)
        return F.layer_norm(x, norm_dims, self.weight, self.bias, self.eps)

    def extra_repr(self):
        return f'{self.normalized_shape}, eps={self.eps}, elementwise_affine={self.elementwise_affine}'


# --- Merged model class supporting both the 'teacher' and 'stmlp' branches ---


class STMLP(nn.Module):
    """Spatio-temporal model with gated dilated temporal convolutions.

    Each layer mixes information across nodes either with graph convolutions
    (``gcn_true``) or with the per-layer MLPs in ``stu_mlp``.
    ``model_type`` ('teacher' or anything else, treated as 'stmlp') selects
    which auxiliary output heads are created and returned.

    NOTE(review): ``stu_mlp`` is hard-coded to three layers whose temporal
    widths are 13, 7 and 1, so the ``gcn_true=False`` path only works for
    configurations that produce exactly those lengths - confirm against the
    configs in use.
    """

    def __init__(self, args):
        super().__init__()
        # Hyper-parameters are read from the ``args`` dict.
        self.adj_mx = get_adj(args)
        self.num_nodes = args['num_nodes']
        self.feature_dim = args['input_dim']
        self.input_window = args['input_window']
        self.output_window = args['output_window']
        self.output_dim = args['output_dim']
        self.device = args['device']
        self.gcn_true = args['gcn_true']
        self.buildA_true = args['buildA_true']
        self.gcn_depth = args['gcn_depth']
        self.dropout = args['dropout']
        self.subgraph_size = args['subgraph_size']
        self.node_dim = args['node_dim']
        self.dilation_exponential = args['dilation_exponential']
        self.conv_channels = args['conv_channels']
        self.residual_channels = args['residual_channels']
        self.skip_channels = args['skip_channels']
        self.end_channels = args['end_channels']
        self.layers = args['layers']
        self.propalpha = args['propalpha']
        self.tanhalpha = args['tanhalpha']
        self.layer_norm_affline = args['layer_norm_affline']
        self.model_type = args['model_type']  # 'teacher' or 'stmlp'

        self.idx = torch.arange(self.num_nodes).to(self.device)
        if self.adj_mx is None:
            self.predefined_A = None
        else:
            # Remove the diagonal from the predefined adjacency.
            self.predefined_A = (torch.tensor(self.adj_mx) - torch.eye(self.num_nodes)).to(self.device)
        self.static_feat = None

        # Transformer (kept from the original structure; not used in forward).
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=12, nhead=4, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=3)

        # Build the layers.
        self.start_conv = nn.Conv2d(self.feature_dim, self.residual_channels, kernel_size=1)
        self.gc = GraphConstructor(self.num_nodes, self.subgraph_size, self.node_dim, self.device,
                                   alpha=self.tanhalpha, static_feat=self.static_feat)

        # Compute the receptive field.
        kernel_size = 7
        self.receptive_field = _receptive_size(
            self.output_dim, self.layers, kernel_size, self.dilation_exponential)
        # Temporal length seen by the network: the input is left-padded up to
        # the receptive field in forward() when it is shorter.
        seq_len = max(self.input_window, self.receptive_field)

        self.filter_convs = nn.ModuleList()
        self.gate_convs = nn.ModuleList()
        self.residual_convs = nn.ModuleList()
        self.skip_convs = nn.ModuleList()
        self.norm = nn.ModuleList()
        self.stu_mlp = nn.ModuleList([
            nn.Sequential(nn.Linear(c, c), nn.Linear(c, c), nn.Linear(c, c))
            for c in [13, 7, 1]
        ])
        if self.gcn_true:
            self.gconv1 = nn.ModuleList()
            self.gconv2 = nn.ModuleList()

        new_dilation = 1
        # Single stack of layers, so its receptive field starts at 1.
        rf_size_i = 1
        for j in range(1, self.layers + 1):
            rf_size_j = _receptive_size(rf_size_i, j, kernel_size, self.dilation_exponential)
            # Temporal length remaining after layer j.
            t_len = seq_len - rf_size_j + 1

            self.filter_convs.append(
                DilatedInception(self.residual_channels, self.conv_channels, dilation_factor=new_dilation))
            self.gate_convs.append(
                DilatedInception(self.residual_channels, self.conv_channels, dilation_factor=new_dilation))
            self.residual_convs.append(
                nn.Conv2d(self.conv_channels, self.residual_channels, kernel_size=1))
            # The skip conv collapses the remaining temporal axis to length 1.
            self.skip_convs.append(
                nn.Conv2d(self.conv_channels, self.skip_channels, kernel_size=(1, t_len)))
            if self.gcn_true:
                self.gconv1.append(MixProp(self.conv_channels, self.residual_channels, self.gcn_depth,
                                           self.dropout, self.propalpha))
                self.gconv2.append(MixProp(self.conv_channels, self.residual_channels, self.gcn_depth,
                                           self.dropout, self.propalpha))
            self.norm.append(LayerNorm((self.residual_channels, self.num_nodes, t_len),
                                       elementwise_affine=self.layer_norm_affline))
            new_dilation *= self.dilation_exponential

        self.end_conv_1 = nn.Conv2d(self.skip_channels, self.end_channels, kernel_size=1, bias=True)
        self.end_conv_2 = nn.Conv2d(self.end_channels, self.output_window, kernel_size=1, bias=True)
        self.skip0 = nn.Conv2d(self.feature_dim, self.skip_channels,
                               kernel_size=(1, seq_len), bias=True)
        self.skipE = nn.Conv2d(self.residual_channels, self.skip_channels,
                               kernel_size=(1, seq_len - self.receptive_field + 1), bias=True)

        # Final output branch: pick the heads according to the model type.
        if self.model_type == 'teacher':
            self.tt_linear1 = nn.Linear(self.residual_channels, self.input_window)
            self.tt_linear2 = nn.Linear(1, 32)
            self.ss_linear1 = nn.Linear(self.residual_channels, self.input_window)
            self.ss_linear2 = nn.Linear(1, 32)
        else:  # stmlp
            self.out_linear1 = nn.Linear(self.residual_channels, self.input_window)
            self.out_linear2 = nn.Linear(1, 32)

    def forward(self, source, idx=None):
        """Run the model on ``source``.

        Only feature channel 0 of ``source`` is used. Dim 1 of ``source`` must
        have length ``input_window`` (checked by the assert below).
        # assumes source is (batch, time, nodes, features) - TODO confirm

        Args:
            source: 4-D input tensor.
            idx: optional node indices; defaults to all nodes (``self.idx``).

        Returns:
            A 3-tuple. For ``model_type == 'teacher'``:
            ``(prediction, temporal_head_output, spatial_head_output)``;
            otherwise ``(prediction, temporal_head_output, prediction)``.
        """
        source = source[..., 0:1]
        node_idx = self.idx if idx is None else idx
        sout, tout = [], []

        inputs = source.transpose(1, 3)
        assert inputs.size(3) == self.input_window, 'input sequence length mismatch'
        if self.input_window < self.receptive_field:
            # Left-pad the temporal axis up to the receptive field.
            inputs = F.pad(inputs, (self.receptive_field - self.input_window, 0, 0, 0))

        if self.gcn_true:
            adp = self.gc(node_idx) if self.buildA_true else self.predefined_A

        x = self.start_conv(inputs)
        skip = self.skip0(F.dropout(inputs, self.dropout, training=self.training))

        for i in range(self.layers):
            residual = x
            # Gated temporal convolution.
            filters = torch.tanh(self.filter_convs[i](x))
            gate = torch.sigmoid(self.gate_convs[i](x))
            x = F.dropout(filters * gate, self.dropout, training=self.training)
            tout.append(x)  # features after the temporal stage

            skip = self.skip_convs[i](x) + skip

            # Mixing stage: graph convolution in both edge directions, or the
            # per-layer MLP over the last (temporal) axis.
            if self.gcn_true:
                x = self.gconv1[i](x, adp) + self.gconv2[i](x, adp.transpose(1, 0))
            else:
                x = self.stu_mlp[i](x)

            # The residual is cropped to the (shorter) current temporal length.
            x = x + residual[:, :, :, -x.size(3):]
            x = self.norm[i](x, node_idx)
            sout.append(x)  # features after mixing + normalisation

        skip = self.skipE(x) + skip
        x = F.relu(skip)
        x = F.relu(self.end_conv_1(x))
        x = self.end_conv_2(x)

        # The heads apply Linear(residual_channels, input_window) on the channel
        # axis, then Linear(1, 32) on the temporal axis of the last layer.
        if self.model_type == 'teacher':
            ttout = self.tt_linear2(self.tt_linear1(tout[-1].transpose(1, 3)).transpose(1, 3))
            ssout = self.ss_linear2(self.ss_linear1(sout[-1].transpose(1, 3)).transpose(1, 3))
            return x, ttout, ssout

        x_ = self.out_linear2(self.out_linear1(tout[-1].transpose(1, 3)).transpose(1, 3))
        return x, x_, x