import torch
import torch.nn as nn
import torch.nn.functional as F


class LinearConv2d(nn.Module):
    def __init__(self, c_in, c_out):
        super().__init__()
        self.mlp = nn.Conv2d(c_in, c_out, kernel_size=(1, 1), padding=(0, 0), stride=(1, 1), bias=True)

    def forward(self, x):
        return self.mlp(x)


class PointwiseConv2d(nn.Module):
    def __init__(self, c_in, c_out):
        super().__init__()
        self.mlp = nn.Conv2d(c_in, c_out, kernel_size=(1, 1), padding=(0, 0), stride=(1, 1), bias=True)

    def forward(self, x):
        return self.mlp(x)


class GraphAttention(nn.Module):
    def __init__(self, c_in, c_out, dropout, d=16, emb_length=0, aptonly=False, noapt=False):
        super().__init__()
        self.d = d
        self.aptonly = aptonly
        self.noapt = noapt
        self.mlp = LinearConv2d(c_in * 2, c_out)
        self.dropout = dropout
        self.emb_length = emb_length
        if aptonly:
            # attention computed from the node embedding only
            self.qm = PointwiseConv2d(self.emb_length, d)
            self.km = PointwiseConv2d(self.emb_length, d)
        elif noapt:
            # attention computed from the input features only
            self.qm = PointwiseConv2d(c_in, d)
            self.km = PointwiseConv2d(c_in, d)
        else:
            # attention computed from input features concatenated with the embedding
            self.qm = PointwiseConv2d(c_in + self.emb_length, d)
            self.km = PointwiseConv2d(c_in + self.emb_length, d)

    def forward(self, x, embedding):
        # x: [B, C, N, T]
        # embedding: [emb_length, N] (a parameter), broadcast to [B, emb_length, N, T]
        out = [x]
        embedding = embedding.repeat((x.shape[0], x.shape[-1], 1, 1))  # [B, T, emb, N]
        embedding = embedding.permute(0, 2, 3, 1).contiguous()         # [B, emb, N, T]

        if self.aptonly:
            x_embedding = embedding
        elif self.noapt:
            x_embedding = x
        else:
            x_embedding = torch.cat([x, embedding], dim=1)             # [B, C+emb, N, T]

        query = self.qm(x_embedding).permute(0, 3, 2, 1)               # [B, T, N, d]
        key = self.km(x_embedding).permute(0, 3, 2, 1)                 # [B, T, N, d]
        attention = torch.matmul(query, key.permute(0, 1, 3, 2))       # [B, T, N, N]
        attention = attention / (self.d ** 0.5)
        attention = F.softmax(attention, dim=-1)

        # apply attention over nodes: [B, C, N, T] -> [B, T, C, N] x [B, T, N, N] -> [B, T, C, N]
        x = torch.matmul(x.permute(0, 3, 1, 2), attention).permute(0, 2, 3, 1)
        out.append(x)

        h = torch.cat(out, dim=1)
        h = self.mlp(h)
        h = F.dropout(h, self.dropout, training=self.training)
        return h
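

# Usage sketch (illustrative, not part of the original module): with
#   x         of shape [B=4, C=32, N=20, T=13]
#   embedding of shape [emb_length=16, N=20]
# the default branch (aptonly=False, noapt=False) behaves as
#   gat = GraphAttention(c_in=32, c_out=32, dropout=0.3, emb_length=16)
#   h = gat(x, embedding)   # -> [4, 32, 20, 13]
# i.e. attention mixes information across the node axis N at each time step,
# and the 1x1 mlp maps the concatenated [x, attended x] (2 * c_in) back to c_out.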


class STAWnetCore(nn.Module):
    def __init__(
        self,
        device,
        num_nodes,
        dropout=0.3,
        gat_bool=True,
        addaptadj=True,
        aptonly=False,
        noapt=False,
        in_dim=2,
        out_dim=12,
        residual_channels=32,
        dilation_channels=32,
        skip_channels=256,
        end_channels=512,
        kernel_size=2,
        blocks=4,
        layers=2,
        emb_length=16,
    ):
        super().__init__()
        self.dropout = dropout
        self.blocks = blocks
        self.layers = layers
        self.gat_bool = gat_bool
        self.aptonly = aptonly
        self.noapt = noapt
        self.addaptadj = addaptadj
        self.emb_length = emb_length

        self.filter_convs = nn.ModuleList()
        self.gate_convs = nn.ModuleList()
        self.residual_convs = nn.ModuleList()
        self.skip_convs = nn.ModuleList()
        self.bn = nn.ModuleList()
        self.gat = nn.ModuleList()

        self.start_conv = nn.Conv2d(in_channels=in_dim, out_channels=residual_channels, kernel_size=(1, 1))
        self.supports = None

        receptive_field = 1
        if gat_bool and addaptadj:
            # learnable node embeddings: [emb_length, N]
            self.embedding = nn.Parameter(
                torch.randn(self.emb_length, num_nodes, device=device), requires_grad=True
            )

        for _ in range(blocks):
            additional_scope = kernel_size - 1
            new_dilation = 1
            for _ in range(layers):
                # dilated temporal convs
                self.filter_convs.append(
                    nn.Conv2d(in_channels=residual_channels, out_channels=dilation_channels,
                              kernel_size=(1, kernel_size), dilation=new_dilation)
                )
                self.gate_convs.append(
                    nn.Conv2d(in_channels=residual_channels, out_channels=dilation_channels,
                              kernel_size=(1, kernel_size), dilation=new_dilation)
                )
                # 1x1 residual/skip
                self.residual_convs.append(
                    nn.Conv2d(in_channels=dilation_channels, out_channels=residual_channels, kernel_size=(1, 1))
                )
                self.skip_convs.append(
                    nn.Conv2d(in_channels=dilation_channels, out_channels=skip_channels, kernel_size=(1, 1))
                )
                self.bn.append(nn.BatchNorm2d(residual_channels))
                new_dilation *= 2
                receptive_field += additional_scope
                additional_scope *= 2
                if self.gat_bool:
                    self.gat.append(
                        GraphAttention(dilation_channels, residual_channels, dropout,
                                       emb_length=emb_length, aptonly=aptonly, noapt=noapt)
                    )

        self.end_conv_1 = nn.Conv2d(in_channels=skip_channels, out_channels=end_channels,
                                    kernel_size=(1, 1), bias=True)
        self.end_conv_2 = nn.Conv2d(in_channels=end_channels, out_channels=out_dim,
                                    kernel_size=(1, 1), bias=True)
        self.receptive_field = receptive_field

    def forward(self, input):
        # input: [B, C_in, N, T]
        in_len = input.size(3)
        if in_len < self.receptive_field:
            x = F.pad(input, (self.receptive_field - in_len, 0, 0, 0))
        else:
            x = input
        x = self.start_conv(x)
        skip = 0

        for i in range(self.blocks * self.layers):
            residual = x
            # gated temporal conv
            filt = torch.tanh(self.filter_convs[i](residual))
            gate = torch.sigmoid(self.gate_convs[i](residual))
            x = filt * gate

            # skip connection accumulation (align time length)
            s = self.skip_convs[i](x)
            if isinstance(skip, torch.Tensor):
                skip = skip[:, :, :, -s.size(3):]
            else:
                skip = 0
            skip = s + skip

            # spatial attention or residual conv
            if self.gat_bool and hasattr(self, 'embedding'):
                x = self.gat[i](x, self.embedding)
            else:
                x = self.residual_convs[i](x)

            # residual connection and BN
            x = x + residual[:, :, :, -x.size(3):]
            x = self.bn[i](x)

        x = F.relu(skip)
        x = F.relu(self.end_conv_1(x))
        x = self.end_conv_2(x)
        # shape: [B, horizon(out_dim), N, T_reduced(usually 1)]
        return x
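

# Note on the temporal receptive field (derived from the constructor above):
# each block resets additional_scope to kernel_size - 1 and doubles it per layer,
# so receptive_field = 1 + blocks * (kernel_size - 1) * (2**layers - 1).
# With the defaults (kernel_size=2, blocks=4, layers=2) this gives 1 + 4 * 1 * 3 = 13,
# which is why a 12-step input window is left-padded by one step in forward().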


class STAWnet(nn.Module):
    """
    Project-adapted STAWnet wrapper that matches the common interface:
      - Input:  [B, T, N, C_total]
      - Output: [B, horizon, N, output_dim]
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        device = args.get('device', 'cpu')
        num_nodes = args['num_nodes']

        # Model IO configs
        in_dim = args.get('in_dim', 2)  # how many covariates to feed into the model
        horizon = args.get('horizon', 12)
        output_dim = args.get('output_dim', 1)
        self.use_channels = in_dim

        self.core = STAWnetCore(
            device=device,
            num_nodes=num_nodes,
            dropout=args.get('dropout', 0.3),
            gat_bool=args.get('gat_bool', True),
            addaptadj=args.get('addaptadj', True),
            aptonly=args.get('aptonly', False),
            noapt=args.get('noapt', False),
            in_dim=in_dim,
            out_dim=horizon,  # output channels represent horizon steps
            residual_channels=args.get('residual_channels', 32),
            dilation_channels=args.get('dilation_channels', 32),
            skip_channels=args.get('skip_channels', 256),
            end_channels=args.get('end_channels', 512),
            kernel_size=args.get('kernel_size', 2),
            blocks=args.get('blocks', 4),
            layers=args.get('layers', 2),
            emb_length=args.get('emb_length', 16),
        )

    def forward(self, x):
        # x: [B, T, N, C_total] -> pick first self.use_channels, then to [B, C, N, T]
        x = x[..., :self.use_channels].transpose(1, 3)  # [B, C, N, T]
        y = self.core(x)  # [B, horizon, N, T_reduced(=1)]
        # Keep horizon on the channel dimension to match the project convention (like GWN)
        return y
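

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the original module). The args keys
# below are the ones the wrapper reads; the concrete sizes (batch 8, 207 nodes,
# 12-step window) are illustrative assumptions, not project defaults.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    args = {
        'device': 'cpu',
        'num_nodes': 207,
        'in_dim': 2,
        'horizon': 12,
        'output_dim': 1,
    }
    model = STAWnet(args)
    x = torch.randn(8, 12, args['num_nodes'], 2)  # [B, T, N, C_total]
    y = model(x)
    print(y.shape)  # expected: torch.Size([8, 12, 207, 1])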