import torch
import torch.nn as nn
import torch.nn.functional as F

from models.STGODE.odegcn import ODEG
from models.STGODE.adj import get_A_hat


class Chomp1d(nn.Module):
    """Remove the trailing padding so the temporal convolution stays causal."""

    def __init__(self, chomp_size):
        super(Chomp1d, self).__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        return x[:, :, :, :-self.chomp_size].contiguous()


class TemporalConvNet(nn.Module):
    """Dilated causal temporal convolution stack applied per node.

    Input and output layout: (batch, num_nodes, time, channels).
    """

    def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.2):
        super(TemporalConvNet, self).__init__()
        layers = []
        num_levels = len(num_channels)
        for i in range(num_levels):
            dilation_size = 2 ** i
            in_channels = num_inputs if i == 0 else num_channels[i - 1]
            out_channels = num_channels[i]
            padding = (kernel_size - 1) * dilation_size
            conv = nn.Conv2d(in_channels, out_channels, (1, kernel_size),
                             dilation=(1, dilation_size), padding=(0, padding))
            conv.weight.data.normal_(0, 0.01)
            layers += [nn.Sequential(conv, Chomp1d(padding), nn.ReLU(), nn.Dropout(dropout))]
        self.network = nn.Sequential(*layers)
        # 1x1 convolution for the residual path when channel counts differ.
        self.downsample = nn.Conv2d(num_inputs, num_channels[-1], (1, 1)) if num_inputs != num_channels[-1] else None
        if self.downsample is not None:
            self.downsample.weight.data.normal_(0, 0.01)

    def forward(self, x):
        # (B, N, T, C) -> (B, C, N, T) for Conv2d, then back.
        y = x.permute(0, 3, 1, 2)
        residual = self.downsample(y) if self.downsample is not None else y
        y = F.relu(self.network(y) + residual)
        y = y.permute(0, 2, 3, 1)
        return y


class STGCNBlock(nn.Module):
    """Temporal conv -> graph ODE -> temporal conv, with batch norm over the node axis."""

    def __init__(self, in_channels, out_channels, num_nodes, A_hat):
        super(STGCNBlock, self).__init__()
        self.A_hat = A_hat
        self.temporal1 = TemporalConvNet(num_inputs=in_channels, num_channels=out_channels)
        self.odeg = ODEG(out_channels[-1], 12, A_hat, time=6)
        self.temporal2 = TemporalConvNet(num_inputs=out_channels[-1], num_channels=out_channels)
        self.batch_norm = nn.BatchNorm2d(num_nodes)

    def forward(self, X):
        t = self.temporal1(X)
        t = self.odeg(t)
        t = self.temporal2(F.relu(t))
        return self.batch_norm(t)


class GPT2BackboneHF(nn.Module):
    """Hugging Face GPT-2 backbone driven by precomputed input embeddings."""

    def __init__(self, model_name: str | None = None, gradient_checkpointing: bool = False,
                 freeze: bool = False, local_dir: str | None = None):
        super().__init__()
        from transformers import GPT2Model

        if local_dir is not None and len(local_dir) > 0:
            # Load from a local checkpoint directory without hitting the network.
            self.model = GPT2Model.from_pretrained(local_dir, local_files_only=True, use_cache=False)
        else:
            if model_name is None:
                model_name = 'gpt2'
            self.model = GPT2Model.from_pretrained(model_name, use_cache=False)

        if gradient_checkpointing:
            self.model.gradient_checkpointing_enable()
        self.hidden_size = self.model.config.hidden_size

        if freeze:
            for p in self.model.parameters():
                p.requires_grad = False

    def forward(self, inputs_embeds: torch.Tensor) -> torch.Tensor:
        outputs = self.model(inputs_embeds=inputs_embeds)
        return outputs.last_hidden_state


class ODEGCN_LLM_GPT2(nn.Module):
    """STGODE spatial/semantic ODE-GCN blocks followed by a GPT-2 temporal backbone."""

    def __init__(self, config):
        super(ODEGCN_LLM_GPT2, self).__init__()
        args = config['model']
        num_nodes = config['data']['num_nodes']
        num_features = args['num_features']
        self.history = args['history']
        self.horizon = args['horizon']

        A_sp_hat, A_se_hat = get_A_hat(config)

        # Three parallel stacks on the spatial graph and three on the semantic graph.
        self.sp_blocks = nn.ModuleList([
            nn.Sequential(
                STGCNBlock(in_channels=num_features, out_channels=[64, 32, 64],
                           num_nodes=num_nodes, A_hat=A_sp_hat),
                STGCNBlock(in_channels=64, out_channels=[64, 32, 64],
                           num_nodes=num_nodes, A_hat=A_sp_hat))
            for _ in range(3)
        ])
        self.se_blocks = nn.ModuleList([
            nn.Sequential(
                STGCNBlock(in_channels=num_features, out_channels=[64, 32, 64],
                           num_nodes=num_nodes, A_hat=A_se_hat),
                STGCNBlock(in_channels=64, out_channels=[64, 32, 64],
                           num_nodes=num_nodes, A_hat=A_se_hat))
            for _ in range(3)
        ])

        # HF GPT-2
        gpt2_name = args.get('gpt2_name', 'gpt2')
        grad_ckpt = bool(args.get('gpt2_grad_ckpt', False))
        gpt2_freeze = bool(args.get('gpt2_freeze', False))
        gpt2_local_dir = args.get('gpt2_local_dir', None)
        self.gpt2 = GPT2BackboneHF(gpt2_name, gradient_checkpointing=grad_ckpt,
                                   freeze=gpt2_freeze, local_dir=gpt2_local_dir)

        # Project ODE features to GPT-2 hidden size
        self.to_llm_embed = nn.Linear(64, self.gpt2.hidden_size)

        # Prediction head
        self.proj_head = nn.Sequential(
            nn.Linear(self.gpt2.hidden_size, self.gpt2.hidden_size),
            nn.ReLU(),
            nn.Linear(self.gpt2.hidden_size, self.horizon)
        )

    def forward(self, x):
        # (B, T, N, F) -> keep the first feature -> (B, N, T, 1)
        x = x[..., 0:1].permute(0, 2, 1, 3)

        outs = []
        for blk in self.sp_blocks:
            outs.append(blk(x))
        for blk in self.se_blocks:
            outs.append(blk(x))
        outs = torch.stack(outs)
        x = torch.max(outs, dim=0)[0]  # (B, N, T, 64)

        B, N, T, C = x.shape
        # Each node becomes one sequence of T token embeddings for GPT-2.
        x = self.to_llm_embed(x).view(B * N, T, -1)
        llm_hidden = self.gpt2(inputs_embeds=x)
        last_state = llm_hidden[:, -1, :]

        y = self.proj_head(last_state)
        y = y.view(B, N, self.horizon).permute(0, 2, 1).unsqueeze(-1)  # (B, horizon, N, 1)
        return y
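

# Illustrative smoke-test sketch, not part of the model: it exercises only the GPT-2 path
# used at the end of ODEGCN_LLM_GPT2.forward, so it runs without the STGODE adjacency data.
# The shapes are assumptions for a dry run (B=2 graphs, N=3 nodes, T=12 steps, C=64
# ODE-block channels, horizon=12); running it loads the 'gpt2' checkpoint from the
# Hugging Face cache or downloads it.
if __name__ == '__main__':
    B, N, T, C, horizon = 2, 3, 12, 64, 12
    backbone = GPT2BackboneHF('gpt2', freeze=True)
    to_llm_embed = nn.Linear(C, backbone.hidden_size)
    proj_head = nn.Linear(backbone.hidden_size, horizon)

    feats = torch.randn(B, N, T, C)                      # stand-in for the ODE-GCN output
    tokens = to_llm_embed(feats).view(B * N, T, -1)      # one token sequence per node
    with torch.no_grad():
        hidden = backbone(inputs_embeds=tokens)          # (B*N, T, hidden_size)
        y = proj_head(hidden[:, -1, :]).view(B, N, horizon)  # forecast from the last token
    print(y.shape)                                       # torch.Size([2, 3, 12])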