import torch
import torch.nn as nn
import torch.nn.functional as F


class LinearConv2d(nn.Module):
    def __init__(self, c_in, c_out):
        super().__init__()
        self.mlp = nn.Conv2d(c_in, c_out, kernel_size=(1, 1), padding=(0, 0), stride=(1, 1), bias=True)

    def forward(self, x):
        return self.mlp(x)


class PointwiseConv2d(nn.Module):
    def __init__(self, c_in, c_out):
        super().__init__()
        self.mlp = nn.Conv2d(c_in, c_out, kernel_size=(1, 1), padding=(0, 0), stride=(1, 1), bias=True)

    def forward(self, x):
        return self.mlp(x)


class GraphAttention(nn.Module):
    def __init__(self, c_in, c_out, dropout, d=16, emb_length=0, aptonly=False, noapt=False):
        super().__init__()
        self.d = d
        self.aptonly = aptonly
        self.noapt = noapt
        self.mlp = LinearConv2d(c_in * 2, c_out)
        self.dropout = dropout
        self.emb_length = emb_length

        if aptonly:
            # queries/keys from the learnable node embedding only
            self.qm = PointwiseConv2d(self.emb_length, d)
            self.km = PointwiseConv2d(self.emb_length, d)
        elif noapt:
            # queries/keys from the input features only
            self.qm = PointwiseConv2d(c_in, d)
            self.km = PointwiseConv2d(c_in, d)
        else:
            # queries/keys from input features concatenated with the embedding
            self.qm = PointwiseConv2d(c_in + self.emb_length, d)
            self.km = PointwiseConv2d(c_in + self.emb_length, d)

    def forward(self, x, embedding):
        # x: [B, C, N, T]
        # embedding: [emb_length, N] (a learnable parameter), broadcast to [B, emb_length, N, T]
        out = [x]

        embedding = embedding.repeat((x.shape[0], x.shape[-1], 1, 1))  # [B, T, emb, N]
        embedding = embedding.permute(0, 2, 3, 1).contiguous()  # [B, emb, N, T]

        # select what the queries/keys are computed from
        if self.aptonly:
            x_embedding = embedding
        elif self.noapt:
            x_embedding = x
        else:
            x_embedding = torch.cat([x, embedding], dim=1)  # [B, C+emb, N, T]

        # scaled dot-product attention over the node dimension
        query = self.qm(x_embedding).permute(0, 3, 2, 1)  # [B, T, N, d]
        key = self.km(x_embedding).permute(0, 3, 2, 1)  # [B, T, N, d]
        attention = torch.matmul(query, key.permute(0, 1, 3, 2))  # [B, T, N, N]
        attention = attention / (self.d ** 0.5)
        attention = F.softmax(attention, dim=-1)

        # apply attention over nodes: [B, C, N, T] -> [B, T, C, N] @ [B, T, N, N] -> [B, T, C, N] -> [B, C, N, T]
        x = torch.matmul(x.permute(0, 3, 1, 2), attention).permute(0, 2, 3, 1)
        out.append(x)

        h = torch.cat(out, dim=1)
        h = self.mlp(h)
        h = F.dropout(h, self.dropout, training=self.training)
        return h


class STAWnetCore(nn.Module):
    def __init__(
        self,
        device,
        num_nodes,
        dropout=0.3,
        gat_bool=True,
        addaptadj=True,
        aptonly=False,
        noapt=False,
        in_dim=2,
        out_dim=12,
        residual_channels=32,
        dilation_channels=32,
        skip_channels=256,
        end_channels=512,
        kernel_size=2,
        blocks=4,
        layers=2,
        emb_length=16,
    ):
        super().__init__()

        self.dropout = dropout
        self.blocks = blocks
        self.layers = layers
        self.gat_bool = gat_bool
        self.aptonly = aptonly
        self.noapt = noapt
        self.addaptadj = addaptadj
        self.emb_length = emb_length

        self.filter_convs = nn.ModuleList()
        self.gate_convs = nn.ModuleList()
        self.residual_convs = nn.ModuleList()
        self.skip_convs = nn.ModuleList()
        self.bn = nn.ModuleList()
        self.gat = nn.ModuleList()

        self.start_conv = nn.Conv2d(in_channels=in_dim, out_channels=residual_channels, kernel_size=(1, 1))

        self.supports = None
        receptive_field = 1

        if gat_bool and addaptadj:
            # learnable node embeddings: [emb_length, N]
            self.embedding = nn.Parameter(torch.randn(self.emb_length, num_nodes, device=device), requires_grad=True)

        for _ in range(blocks):
            additional_scope = kernel_size - 1
            new_dilation = 1
            for _ in range(layers):
                # dilated temporal convs
                self.filter_convs.append(
                    nn.Conv2d(in_channels=residual_channels, out_channels=dilation_channels, kernel_size=(1, kernel_size), dilation=new_dilation)
                )
                self.gate_convs.append(
                    nn.Conv2d(in_channels=residual_channels, out_channels=dilation_channels, kernel_size=(1, kernel_size), dilation=new_dilation)
                )

                # 1x1 residual/skip
                self.residual_convs.append(nn.Conv2d(in_channels=dilation_channels, out_channels=residual_channels, kernel_size=(1, 1)))
                self.skip_convs.append(nn.Conv2d(in_channels=dilation_channels, out_channels=skip_channels, kernel_size=(1, 1)))

                self.bn.append(nn.BatchNorm2d(residual_channels))

                new_dilation *= 2
                receptive_field += additional_scope
                additional_scope *= 2

                if self.gat_bool:
                    self.gat.append(GraphAttention(dilation_channels, residual_channels, dropout, emb_length=emb_length, aptonly=aptonly, noapt=noapt))

        self.end_conv_1 = nn.Conv2d(in_channels=skip_channels, out_channels=end_channels, kernel_size=(1, 1), bias=True)
        self.end_conv_2 = nn.Conv2d(in_channels=end_channels, out_channels=out_dim, kernel_size=(1, 1), bias=True)

        self.receptive_field = receptive_field

    def forward(self, input):
        # input: [B, C_in, N, T]
        in_len = input.size(3)
        if in_len < self.receptive_field:
            x = F.pad(input, (self.receptive_field - in_len, 0, 0, 0))
        else:
            x = input

        x = self.start_conv(x)
        skip = 0

        for i in range(self.blocks * self.layers):
            residual = x
            # gated temporal conv
            filt = torch.tanh(self.filter_convs[i](residual))
            gate = torch.sigmoid(self.gate_convs[i](residual))
            x = filt * gate

            # skip connection accumulation (align time length)
            s = self.skip_convs[i](x)
            if isinstance(skip, torch.Tensor):
                skip = skip[:, :, :, -s.size(3):]
            else:
                skip = 0
            skip = s + skip

            # spatial attention or residual conv
            if self.gat_bool and hasattr(self, 'embedding'):
                x = self.gat[i](x, self.embedding)
            else:
                x = self.residual_convs[i](x)

            # residual connection and BN
            x = x + residual[:, :, :, -x.size(3):]
            x = self.bn[i](x)

        x = F.relu(skip)
        x = F.relu(self.end_conv_1(x))
        x = self.end_conv_2(x)
        # shape: [B, horizon(out_dim), N, T_reduced(usually 1)]
        return x


class STAWnet(nn.Module):
    """
    Project-adapted STAWnet wrapper that matches the common interface:
    - Input:  [B, T, N, C_total]
    - Output: [B, horizon, N, output_dim]
    (See the __main__ smoke test at the bottom of this file for a shape check.)
    """

    def __init__(self, args):
        super().__init__()
        self.args = args

        device = args.get('device', 'cpu')
        num_nodes = args['num_nodes']

        # Model IO configs
        in_dim = args.get('in_dim', 2)  # how many covariates to feed into the model
        horizon = args.get('horizon', 12)
        output_dim = args.get('output_dim', 1)

        self.use_channels = in_dim

        self.core = STAWnetCore(
            device=device,
            num_nodes=num_nodes,
            dropout=args.get('dropout', 0.3),
            gat_bool=args.get('gat_bool', True),
            addaptadj=args.get('addaptadj', True),
            aptonly=args.get('aptonly', False),
            noapt=args.get('noapt', False),
            in_dim=in_dim,
            out_dim=horizon,  # output channels represent horizon steps
            residual_channels=args.get('residual_channels', 32),
            dilation_channels=args.get('dilation_channels', 32),
            skip_channels=args.get('skip_channels', 256),
            end_channels=args.get('end_channels', 512),
            kernel_size=args.get('kernel_size', 2),
            blocks=args.get('blocks', 4),
            layers=args.get('layers', 2),
            emb_length=args.get('emb_length', 16),
        )

    def forward(self, x):
        # x: [B, T, N, C_total] -> keep the first self.use_channels covariates, then reorder to [B, C, N, T]
        x = x[..., :self.use_channels].transpose(1, 3)  # [B, C, N, T]
        y = self.core(x)  # [B, horizon, N, T_reduced(=1)]
        # Keep horizon on the channel dimension to match the project convention (like GWN)
        return y
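

# ---------------------------------------------------------------------------
# Minimal smoke test: an illustrative sketch, not part of the model definition.
# The args keys below mirror the .get() defaults used by the STAWnet wrapper
# above; the batch size, node count, and window length are arbitrary example
# values chosen only to exercise the [B, T, N, C_total] -> [B, horizon, N, 1]
# shape contract described in the class docstring (final time dim is 1 because
# the 12-step window is padded up to the receptive field of 13).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    args = {
        'device': 'cpu',
        'num_nodes': 20,
        'in_dim': 2,
        'horizon': 12,
        'output_dim': 1,
    }
    model = STAWnet(args)
    model.eval()  # avoid updating BatchNorm running stats during the check

    dummy = torch.randn(4, 12, 20, 2)  # [B=4, T=12, N=20, C_total=2]
    with torch.no_grad():
        out = model(dummy)
    print(out.shape)  # expected: torch.Size([4, 12, 20, 1])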