import torch
import torch.nn as nn
import torch.nn.functional as F

from models.STGODE.odegcn import ODEG
from models.STGODE.adj import get_A_hat


class Chomp1d(nn.Module):
    """Drop the trailing `chomp_size` time steps introduced by causal padding,
    so each output step only depends on current and past inputs."""

    def __init__(self, chomp_size):
        super(Chomp1d, self).__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        # x: (batch, channels, nodes, time + chomp_size) -> (batch, channels, nodes, time)
        return x[:, :, :, :-self.chomp_size].contiguous()


class TemporalConvNet(nn.Module):
    """Stack of dilated causal temporal convolutions applied along the time axis
    of a (batch, nodes, time, features) tensor, with a residual connection."""

    def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.2):
        super(TemporalConvNet, self).__init__()
        layers = []
        num_levels = len(num_channels)
        for i in range(num_levels):
            dilation_size = 2 ** i
            in_channels = num_inputs if i == 0 else num_channels[i - 1]
            out_channels = num_channels[i]
            # Pad the time axis so that, after chomping, the output length equals the input length.
            padding = (kernel_size - 1) * dilation_size
            conv = nn.Conv2d(in_channels, out_channels, (1, kernel_size),
                             dilation=(1, dilation_size), padding=(0, padding))
            conv.weight.data.normal_(0, 0.01)
            layers.append(nn.Sequential(conv, Chomp1d(padding), nn.ReLU(), nn.Dropout(dropout)))

        self.network = nn.Sequential(*layers)
        # 1x1 convolution that matches channel counts for the residual connection.
        self.downsample = nn.Conv2d(num_inputs, num_channels[-1], (1, 1)) if num_inputs != num_channels[-1] else None
        if self.downsample is not None:
            self.downsample.weight.data.normal_(0, 0.01)

    def forward(self, x):
        # x: (batch, nodes, time, features) -> (batch, features, nodes, time) for Conv2d.
        y = x.permute(0, 3, 1, 2)
        # Residual connection around the convolution stack.
        residual = self.downsample(y) if self.downsample is not None else y
        y = F.relu(self.network(y) + residual)
        # Back to (batch, nodes, time, channels).
        y = y.permute(0, 2, 3, 1)
        return y


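# The helper below is an illustrative sketch added alongside the model code, not
# part of the original pipeline: it pushes a random (batch, nodes, time, features)
# tensor through TemporalConvNet so the expected shapes are easy to verify. The
# sizes are arbitrary assumptions chosen only for the example.
def _example_tcn_shapes():
    batch, nodes, time, feats = 2, 10, 12, 1
    tcn = TemporalConvNet(num_inputs=feats, num_channels=[64, 32, 64])
    y = tcn(torch.randn(batch, nodes, time, feats))
    # Node count and time length are preserved; features become num_channels[-1].
    assert y.shape == (batch, nodes, time, 64)
    return y.shape

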
class STGCNBlock(nn.Module):
    """Spatio-temporal block: temporal convolution -> ODE-based graph convolution
    (ODEG) -> temporal convolution -> batch normalisation over the node dimension."""

    def __init__(self, in_channels, out_channels, num_nodes, A_hat):
        super(STGCNBlock, self).__init__()
        self.A_hat = A_hat
        self.temporal1 = TemporalConvNet(num_inputs=in_channels, num_channels=out_channels)
        self.odeg = ODEG(out_channels[-1], 12, A_hat, time=6)
        self.temporal2 = TemporalConvNet(num_inputs=out_channels[-1], num_channels=out_channels)
        # BatchNorm2d over num_nodes: the node axis is treated as the channel axis here.
        self.batch_norm = nn.BatchNorm2d(num_nodes)

    def forward(self, X):
        # X: (batch, nodes, time, in_channels) -> (batch, nodes, time, out_channels[-1])
        t = self.temporal1(X)
        t = self.odeg(t)
        t = self.temporal2(F.relu(t))
        return self.batch_norm(t)


class GPT2Backbone(nn.Module):
    """Sequence backbone that consumes pre-computed embeddings.

    If the `transformers` package is importable and `use_pretrained` is True, a
    GPT-2 architecture is built from a fresh config (the weights are randomly
    initialised; no pretrained checkpoint is loaded). Otherwise a plain
    nn.TransformerEncoder of the same width serves as a fallback."""

    def __init__(self, hidden_size: int, n_layer: int = 4, n_head: int = 4,
                 n_embd: int | None = None, use_pretrained: bool = True):
        super().__init__()
        self.hidden_size = hidden_size
        self.use_transformers = False
        self.model = None
        if n_embd is None:
            n_embd = hidden_size
        if use_pretrained:
            try:
                from transformers import GPT2Model, GPT2Config
                # vocab_size=1 because the model only ever receives inputs_embeds, never token ids.
                config = GPT2Config(n_embd=n_embd, n_layer=n_layer, n_head=n_head,
                                    n_positions=1024, n_ctx=1024, vocab_size=1)
                self.model = GPT2Model(config)
                self.use_transformers = True
            except Exception:
                self.use_transformers = False
        if not self.use_transformers:
            encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_size, nhead=n_head, batch_first=True)
            self.model = nn.TransformerEncoder(encoder_layer, num_layers=n_layer)

    def forward(self, inputs_embeds: torch.Tensor) -> torch.Tensor:
        # inputs_embeds: (batch, time, hidden) -> (batch, time, hidden)
        if self.use_transformers:
            outputs = self.model(inputs_embeds=inputs_embeds)
            return outputs.last_hidden_state
        else:
            return self.model(inputs_embeds)


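# Illustrative sketch, not part of the original pipeline: runs random embeddings
# through GPT2Backbone. Whichever path is taken (GPT-2 or the TransformerEncoder
# fallback, depending on whether `transformers` imports), the (batch, time, hidden)
# shape is preserved. The sizes below are arbitrary assumptions for the example.
def _example_gpt2_backbone_shapes():
    batch, time, hidden = 4, 12, 128
    backbone = GPT2Backbone(hidden_size=hidden, n_layer=2, n_head=4, use_pretrained=True)
    h = backbone(torch.randn(batch, time, hidden))
    assert h.shape == (batch, time, hidden)
    return h.shape

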
class ODEGCN_LLM(nn.Module):
    """STGODE-style spatio-temporal encoder followed by a GPT-2-style sequence
    backbone and a projection head that predicts the full horizon per node."""

    def __init__(self, config):
        super(ODEGCN_LLM, self).__init__()
        args = config['model']
        num_nodes = config['data']['num_nodes']
        num_features = args['num_features']
        num_timesteps_input = args['history']
        num_timesteps_output = args['horizon']
        # Normalised adjacency matrices for the spatial and the semantic graph.
        A_sp_hat, A_se_hat = get_A_hat(config)

        # Three parallel two-block branches over the spatial graph ...
        self.sp_blocks = nn.ModuleList(
            [nn.Sequential(
                STGCNBlock(in_channels=num_features, out_channels=[64, 32, 64], num_nodes=num_nodes, A_hat=A_sp_hat),
                STGCNBlock(in_channels=64, out_channels=[64, 32, 64], num_nodes=num_nodes, A_hat=A_sp_hat)) for _ in range(3)
             ])
        # ... and three over the semantic graph.
        self.se_blocks = nn.ModuleList(
            [nn.Sequential(
                STGCNBlock(in_channels=num_features, out_channels=[64, 32, 64], num_nodes=num_nodes, A_hat=A_se_hat),
                STGCNBlock(in_channels=64, out_channels=[64, 32, 64], num_nodes=num_nodes, A_hat=A_se_hat)) for _ in range(3)
             ])

        self.history = num_timesteps_input
        self.horizon = num_timesteps_output

        hidden_size = int(args.get('llm_hidden', 128))
        llm_layers = int(args.get('llm_layers', 4))
        llm_heads = int(args.get('llm_heads', 4))
        use_pretrained = bool(args.get('llm_pretrained', True))

        # Project the 64-dim spatio-temporal features into the backbone's hidden size.
        self.to_llm_embed = nn.Linear(64, hidden_size)
        self.gpt2 = GPT2Backbone(hidden_size=hidden_size, n_layer=llm_layers, n_head=llm_heads, use_pretrained=use_pretrained)
        self.proj_head = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, self.horizon)
        )

    def forward(self, x):
        # x: (batch, time, nodes, features); keep only the first feature and move the
        # node axis ahead of time -> (batch, nodes, time, 1).
        x = x[..., 0:1].permute(0, 2, 1, 3)
        outs = []
        for blk in self.sp_blocks:
            outs.append(blk(x))
        for blk in self.se_blocks:
            outs.append(blk(x))
        # Element-wise max over the six branch outputs.
        outs = torch.stack(outs)
        x = torch.max(outs, dim=0)[0]

        # x: (B, N, T, 64) spatio-temporal features after the ODE-based blocks.
        B, N, T, C = x.shape
        x = self.to_llm_embed(x)                 # (B, N, T, H)
        x = x.contiguous().view(B * N, T, -1)    # (B*N, T, H): one sequence per node

        llm_hidden = self.gpt2(inputs_embeds=x)  # (B*N, T, H)
        last_state = llm_hidden[:, -1, :]        # (B*N, H): last time step summarises the sequence
        y = self.proj_head(last_state)           # (B*N, horizon)
        y = y.view(B, N, self.horizon).permute(0, 2, 1).unsqueeze(-1)  # (B, horizon, N, 1)
        return y


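if __name__ == "__main__":
    # Minimal smoke test for the standalone sub-modules; an illustrative sketch,
    # not part of the original training pipeline. ODEGCN_LLM itself is not built
    # here because get_A_hat(config) requires the dataset's adjacency files; its
    # forward maps (batch, history, nodes, features) inputs to
    # (batch, horizon, nodes, 1) predictions.
    print("TemporalConvNet output shape:", _example_tcn_shapes())
    print("GPT2Backbone output shape:", _example_gpt2_backbone_shapes())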