Add STGODE_LLM and STGODE_LLM_GPT2 model implementations; update config files and README
parent f0d3460c89
commit e9e3da03d3
@@ -5,14 +5,10 @@
  </component>
  <component name="ChangeListManager">
    <list default="true" id="8b1aea27-342c-41a7-b776-2aba4fceda0d" name="更改" comment="">
      <change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
      <change beforePath="$PROJECT_DIR$/STDEN" beforeDir="false" afterPath="$PROJECT_DIR$/STDEN" afterDir="false" />
      <change beforePath="$PROJECT_DIR$/STDEN/lib/utils.py" beforeDir="false" />
      <change beforePath="$PROJECT_DIR$/STDEN/stden_eval.py" beforeDir="false" afterPath="$PROJECT_DIR$/STDEN/stden_eval.py" afterDir="false" />
      <change beforePath="$PROJECT_DIR$/STDEN/stden_train.py" beforeDir="false" afterPath="$PROJECT_DIR$/STDEN/stden_train.py" afterDir="false" />
      <change beforePath="$PROJECT_DIR$/models/STGODE/STGODE.py" beforeDir="false" afterPath="$PROJECT_DIR$/models/STGODE/STGODE.py" afterDir="false" />
      <change beforePath="$PROJECT_DIR$/models/STGODE/adj.py" beforeDir="false" afterPath="$PROJECT_DIR$/models/STGODE/adj.py" afterDir="false" />
      <change beforePath="$PROJECT_DIR$/models/model_selector.py" beforeDir="false" afterPath="$PROJECT_DIR$/models/model_selector.py" afterDir="false" />
    </list>
    <option name="SHOW_DIALOG" value="false" />
    <option name="HIGHLIGHT_CONFLICTS" value="true" />
@@ -66,7 +62,7 @@
    "Python.main.executor": "Run",
    "RunOnceActivity.OpenProjectViewOnStart": "true",
    "RunOnceActivity.ShowReadmeOnStart": "true",
    "git-widget-placeholder": "STGODE",
    "git-widget-placeholder": "main",
    "last_opened_file_path": "/home/czzhangheng/code/Project-I/main.py",
    "node.js.detected.package.eslint": "true",
    "node.js.detected.package.tslint": "true",
@@ -93,11 +89,12 @@
    <window_info anchor="bottom" id="TODO" show_stripe_button="false" />
    <window_info anchor="bottom" id="File Transfer" show_stripe_button="false" />
    <window_info active="true" anchor="bottom" id="Run" visible="true" weight="0.32989067" />
    <window_info anchor="bottom" id="Find" />
    <window_info anchor="bottom" id="Version Control" order="0" />
    <window_info anchor="bottom" id="Problems" order="1" />
    <window_info anchor="bottom" id="Problems View" order="2" weight="0.33686176" />
    <window_info anchor="bottom" id="Terminal" order="3" weight="0.32989067" />
    <window_info anchor="bottom" id="Services" order="4" />
    <window_info active="true" anchor="bottom" id="Services" order="4" visible="true" weight="0.32989067" />
    <window_info anchor="bottom" id="Python Packages" order="5" weight="0.1" />
    <window_info anchor="bottom" id="Python Console" order="6" weight="0.1" />
    <window_info anchor="right" id="Endpoints" show_stripe_button="false" />
@@ -211,6 +208,7 @@
  </component>
  <component name="com.intellij.coverage.CoverageDataManagerImpl">
    <SUITE FILE_PATH="coverage/Project_I$main.coverage" NAME="STDEN 覆盖结果" MODIFIED="1756832980407" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="" />
    <SUITE FILE_PATH="coverage/Project_I$STGODE.coverage" NAME="STGODE 覆盖结果" MODIFIED="1756864828915" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="" />
    <SUITE FILE_PATH="coverage/Project_I$STGODE_LLM.coverage" NAME="STGODE-LLM 覆盖结果" MODIFIED="1756950739801" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="" />
    <SUITE FILE_PATH="coverage/Project_I$STGODE.coverage" NAME="STGODE 覆盖结果" MODIFIED="1756885209907" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="" />
  </component>
</project>
@@ -1,3 +1,9 @@
# Project-I

Secret Project

mkdir -p models/gpt2

Download config.json & pytorch_model.bin from https://huggingface.co/openai-community/gpt2/tree/main

Use PyTorch >= 2.6 to load the model.
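A quick way to check that the downloaded weights are usable is to load them offline and run a dummy forward pass. This is only an illustrative sketch: it assumes the two files were placed in models/gpt2 as described above, and it uses the same `GPT2Model.from_pretrained(..., local_files_only=True)` call that the GPT-2 backbone added in this commit relies on.

```python
import torch
from transformers import GPT2Model

# Load the locally downloaded GPT-2 weights without touching the network.
gpt2 = GPT2Model.from_pretrained("models/gpt2", local_files_only=True)
gpt2.eval()

# GPT-2 can be driven with precomputed embeddings instead of token ids,
# which is how the STGODE-LLM backbones feed it spatio-temporal features.
dummy = torch.randn(1, 12, gpt2.config.hidden_size)  # (batch, seq_len, 768)
with torch.no_grad():
    out = gpt2(inputs_embeds=dummy).last_hidden_state
print(out.shape)  # torch.Size([1, 12, 768])
```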
@@ -0,0 +1,65 @@
basic:
  device: cuda:0
  dataset: PEMS08
  model: STGODE-LLM
  mode: test
  seed: 2025

data:
  dataset_dir: data/PEMS08
  val_batch_size: 32
  graph_pkl_filename: data/PEMS08/PEMS08_spatial_distance.npy
  num_nodes: 170
  batch_size: 64
  input_dim: 1
  lag: 12
  horizon: 12
  val_ratio: 0.2
  test_ratio: 0.2
  tod: False
  normalizer: std
  column_wise: False
  default_graph: True
  add_time_in_day: True
  add_day_in_week: True
  steps_per_day: 24
  days_per_week: 7

model:
  input_dim: 1
  output_dim: 1
  history: 12
  horizon: 12
  num_features: 1
  rnn_units: 64
  sigma1: 0.1
  sigma2: 10
  thres1: 0.6
  thres2: 0.5
  # LLM backbone settings
  llm_hidden: 128
  llm_layers: 4
  llm_heads: 4
  llm_pretrained: True

train:
  loss: mae
  batch_size: 64
  epochs: 100
  lr_init: 0.003
  mape_thresh: 0.001
  mae_thresh: None
  debug: False
  output_dim: 1
  weight_decay: 0
  lr_decay: False
  lr_decay_rate: 0.3
  lr_decay_step: "5,20,40,70"
  early_stop: True
  early_stop_patience: 15
  grad_norm: False
  max_grad_norm: 5
  real_value: True
  log_step: 3000
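The four llm_* keys above are read by the ODEGCN_LLM constructor further down in this commit through dict `.get(...)` calls with the same defaults, so they may be omitted. A minimal sketch of that access pattern, assuming PyYAML is installed; the inline YAML string is a trimmed copy of the file above, not a path from the repository.

```python
import yaml

# Trimmed copy of the model section above; illustrative only.
snippet = """
model:
  llm_hidden: 128
  llm_layers: 4
  llm_heads: 4
  llm_pretrained: True
"""
args = yaml.safe_load(snippet)["model"]

# Same access pattern and fallback defaults as ODEGCN_LLM.__init__.
hidden_size = int(args.get("llm_hidden", 128))
llm_layers = int(args.get("llm_layers", 4))
llm_heads = int(args.get("llm_heads", 4))
use_pretrained = bool(args.get("llm_pretrained", True))
print(hidden_size, llm_layers, llm_heads, use_pretrained)  # 128 4 4 True
```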
@@ -0,0 +1,66 @@
basic:
  device: cuda:0
  dataset: PEMS08
  model: STGODE-LLM-GPT2
  mode: train
  seed: 2025

data:
  dataset_dir: data/PEMS08
  val_batch_size: 16
  graph_pkl_filename: data/PEMS08/PEMS08_spatial_distance.npy
  num_nodes: 170
  batch_size: 32
  input_dim: 1
  lag: 12
  horizon: 12
  val_ratio: 0.2
  test_ratio: 0.2
  tod: False
  normalizer: std
  column_wise: False
  default_graph: True
  add_time_in_day: True
  add_day_in_week: True
  steps_per_day: 24
  days_per_week: 7

model:
  input_dim: 1
  output_dim: 1
  history: 12
  horizon: 12
  num_features: 1
  rnn_units: 64
  sigma1: 0.1
  sigma2: 10
  thres1: 0.6
  thres2: 0.5
  # HF GPT-2 settings
  gpt2_name: gpt2
  gpt2_grad_ckpt: True
  gpt2_freeze: True
  gpt2_local_dir: ./models/gpt2

train:
  loss: mae
  batch_size: 32
  epochs: 100
  lr_init: 0.0003
  mape_thresh: 0.001
  mae_thresh: None
  debug: False
  output_dim: 1
  weight_decay: 0
  lr_decay: False
  lr_decay_rate: 0.3
  lr_decay_step: "10,30,60,90"
  early_stop: True
  early_stop_patience: 15
  grad_norm: False
  max_grad_norm: 5
  real_value: True
  log_step: 3000
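The gpt2_* keys map directly onto the GPT2BackboneHF constructor introduced later in this commit: gpt2_local_dir selects offline loading, gpt2_grad_ckpt enables gradient checkpointing, and gpt2_freeze disables gradients on every GPT-2 parameter so only the projection layers train. A rough sketch of the freeze effect, assuming the local weights from the README step; the parameter count is approximate.

```python
from transformers import GPT2Model

gpt2 = GPT2Model.from_pretrained("./models/gpt2", local_files_only=True)
gpt2.gradient_checkpointing_enable()          # gpt2_grad_ckpt: True

for p in gpt2.parameters():                   # gpt2_freeze: True
    p.requires_grad = False

frozen = sum(p.numel() for p in gpt2.parameters() if not p.requires_grad)
trainable = sum(p.numel() for p in gpt2.parameters() if p.requires_grad)
print(f"frozen={frozen:,} trainable={trainable:,}")  # roughly 124M frozen, 0 trainable
```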
@@ -0,0 +1,152 @@
import torch
import math
import torch.nn as nn
import torch.nn.functional as F

from models.STGODE.odegcn import ODEG
from models.STGODE.adj import get_A_hat


class Chomp1d(nn.Module):
    """Trim the trailing padding introduced by the causal temporal convolution."""

    def __init__(self, chomp_size):
        super(Chomp1d, self).__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        return x[:, :, :, :-self.chomp_size].contiguous()


class TemporalConvNet(nn.Module):
    def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.2):
        super(TemporalConvNet, self).__init__()
        layers = []
        num_levels = len(num_channels)
        for i in range(num_levels):
            dilation_size = 2 ** i
            in_channels = num_inputs if i == 0 else num_channels[i - 1]
            out_channels = num_channels[i]
            padding = (kernel_size - 1) * dilation_size
            self.conv = nn.Conv2d(in_channels, out_channels, (1, kernel_size),
                                  dilation=(1, dilation_size), padding=(0, padding))
            self.conv.weight.data.normal_(0, 0.01)
            self.chomp = Chomp1d(padding)
            self.relu = nn.ReLU()
            self.dropout = nn.Dropout(dropout)
            layers += [nn.Sequential(self.conv, self.chomp, self.relu, self.dropout)]

        self.network = nn.Sequential(*layers)
        self.downsample = nn.Conv2d(num_inputs, num_channels[-1], (1, 1)) if num_inputs != num_channels[-1] else None
        if self.downsample:
            self.downsample.weight.data.normal_(0, 0.01)

    def forward(self, x):
        # (B, N, T, C) -> (B, C, N, T) so the 2-D convolutions run over the time axis
        y = x.permute(0, 3, 1, 2)
        # Residual connection; the identity branch passes through the 1x1 downsample
        # only when the input and output channel counts differ.
        y = F.relu(self.network(y) + (self.downsample(y) if self.downsample else y))
        y = y.permute(0, 2, 3, 1)
        return y


class STGCNBlock(nn.Module):
    def __init__(self, in_channels, out_channels, num_nodes, A_hat):
        super(STGCNBlock, self).__init__()
        self.A_hat = A_hat
        self.temporal1 = TemporalConvNet(num_inputs=in_channels, num_channels=out_channels)
        self.odeg = ODEG(out_channels[-1], 12, A_hat, time=6)
        self.temporal2 = TemporalConvNet(num_inputs=out_channels[-1], num_channels=out_channels)
        self.batch_norm = nn.BatchNorm2d(num_nodes)

    def forward(self, X):
        t = self.temporal1(X)
        t = self.odeg(t)
        t = self.temporal2(F.relu(t))
        return self.batch_norm(t)


class GPT2Backbone(nn.Module):
    """GPT-2-style transformer backbone driven by precomputed embeddings.

    If `transformers` is available, a GPT2Model is built from a config sized by
    n_embd/n_layer/n_head (randomly initialised); otherwise a plain
    nn.TransformerEncoder of the same width is used as a fallback.
    """

    def __init__(self, hidden_size: int, n_layer: int = 4, n_head: int = 4,
                 n_embd: int | None = None, use_pretrained: bool = True):
        super().__init__()
        self.hidden_size = hidden_size
        self.use_transformers = False
        self.model = None
        if n_embd is None:
            n_embd = hidden_size
        if use_pretrained:
            try:
                from transformers import GPT2Model, GPT2Config
                config = GPT2Config(n_embd=n_embd, n_layer=n_layer, n_head=n_head,
                                    n_positions=1024, n_ctx=1024, vocab_size=1)
                self.model = GPT2Model(config)
                self.use_transformers = True
            except Exception:
                self.use_transformers = False
        if not self.use_transformers:
            encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_size, nhead=n_head, batch_first=True)
            self.model = nn.TransformerEncoder(encoder_layer, num_layers=n_layer)

    def forward(self, inputs_embeds: torch.Tensor) -> torch.Tensor:
        if self.use_transformers:
            outputs = self.model(inputs_embeds=inputs_embeds)
            return outputs.last_hidden_state
        else:
            return self.model(inputs_embeds)


class ODEGCN_LLM(nn.Module):
    def __init__(self, config):
        super(ODEGCN_LLM, self).__init__()
        args = config['model']
        num_nodes = config['data']['num_nodes']
        num_features = args['num_features']
        num_timesteps_input = args['history']
        num_timesteps_output = args['horizon']
        A_sp_hat, A_se_hat = get_A_hat(config)

        self.sp_blocks = nn.ModuleList(
            [nn.Sequential(
                STGCNBlock(in_channels=num_features, out_channels=[64, 32, 64], num_nodes=num_nodes, A_hat=A_sp_hat),
                STGCNBlock(in_channels=64, out_channels=[64, 32, 64], num_nodes=num_nodes, A_hat=A_sp_hat)) for _ in range(3)
             ])

        self.se_blocks = nn.ModuleList(
            [nn.Sequential(
                STGCNBlock(in_channels=num_features, out_channels=[64, 32, 64], num_nodes=num_nodes, A_hat=A_se_hat),
                STGCNBlock(in_channels=64, out_channels=[64, 32, 64], num_nodes=num_nodes, A_hat=A_se_hat)) for _ in range(3)
             ])

        self.history = num_timesteps_input
        self.horizon = num_timesteps_output

        hidden_size = int(args.get('llm_hidden', 128))
        llm_layers = int(args.get('llm_layers', 4))
        llm_heads = int(args.get('llm_heads', 4))
        use_pretrained = bool(args.get('llm_pretrained', True))

        self.to_llm_embed = nn.Linear(64, hidden_size)
        self.gpt2 = GPT2Backbone(hidden_size=hidden_size, n_layer=llm_layers, n_head=llm_heads, use_pretrained=use_pretrained)
        self.proj_head = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, self.horizon)
        )

    def forward(self, x):
        # (B, T, N, C): keep the first feature and reorder to (B, N, T, 1)
        x = x[..., 0:1].permute(0, 2, 1, 3)
        outs = []
        for blk in self.sp_blocks:
            outs.append(blk(x))
        for blk in self.se_blocks:
            outs.append(blk(x))
        outs = torch.stack(outs)
        # Element-wise max over the spatial and semantic branches
        x = torch.max(outs, dim=0)[0]

        # x: (B, N, T, 64) physical quantities after ODE-based transform
        B, N, T, C = x.shape
        x = self.to_llm_embed(x)                      # (B, N, T, H)
        x = x.contiguous().view(B * N, T, -1)         # (B*N, T, H)

        llm_hidden = self.gpt2(inputs_embeds=x)       # (B*N, T, H)
        last_state = llm_hidden[:, -1, :]             # (B*N, H)
        y = self.proj_head(last_state)                # (B*N, horizon)
        y = y.view(B, N, self.horizon).permute(0, 2, 1).unsqueeze(-1)  # (B, horizon, N, 1)
        return y
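The forward pass folds the node axis into the batch dimension before the sequence reaches the LLM backbone, then unfolds it into the (B, horizon, N, 1) output. A standalone sketch of just that reshaping, using random tensors and a stand-in linear head (shapes only, not the trained model):

```python
import torch
import torch.nn as nn

B, N, T, H, horizon = 2, 170, 12, 128, 12       # batch, nodes, history, llm_hidden, horizon

feats = torch.randn(B, N, T, H)                  # per-node features after to_llm_embed
seq = feats.contiguous().view(B * N, T, H)       # nodes folded into the batch for the LLM
last = seq[:, -1, :]                             # last hidden state of each node sequence
head = nn.Linear(H, horizon)                     # stand-in for proj_head
y = head(last).view(B, N, horizon).permute(0, 2, 1).unsqueeze(-1)
print(y.shape)                                   # torch.Size([2, 12, 170, 1])
```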
@@ -0,0 +1,4 @@
from .STGODE_LLM import ODEGCN_LLM
@@ -0,0 +1,145 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from models.STGODE.odegcn import ODEG
from models.STGODE.adj import get_A_hat


class Chomp1d(nn.Module):
    def __init__(self, chomp_size):
        super(Chomp1d, self).__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        return x[:, :, :, :-self.chomp_size].contiguous()


class TemporalConvNet(nn.Module):
    def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.2):
        super(TemporalConvNet, self).__init__()
        layers = []
        num_levels = len(num_channels)
        for i in range(num_levels):
            dilation_size = 2 ** i
            in_channels = num_inputs if i == 0 else num_channels[i - 1]
            out_channels = num_channels[i]
            padding = (kernel_size - 1) * dilation_size
            self.conv = nn.Conv2d(in_channels, out_channels, (1, kernel_size),
                                  dilation=(1, dilation_size), padding=(0, padding))
            self.conv.weight.data.normal_(0, 0.01)
            self.chomp = Chomp1d(padding)
            self.relu = nn.ReLU()
            self.dropout = nn.Dropout(dropout)
            layers += [nn.Sequential(self.conv, self.chomp, self.relu, self.dropout)]

        self.network = nn.Sequential(*layers)
        self.downsample = nn.Conv2d(num_inputs, num_channels[-1], (1, 1)) if num_inputs != num_channels[-1] else None
        if self.downsample:
            self.downsample.weight.data.normal_(0, 0.01)

    def forward(self, x):
        y = x.permute(0, 3, 1, 2)
        # Residual connection; the identity branch is downsampled only when channel counts differ.
        y = F.relu(self.network(y) + (self.downsample(y) if self.downsample else y))
        y = y.permute(0, 2, 3, 1)
        return y


class STGCNBlock(nn.Module):
    def __init__(self, in_channels, out_channels, num_nodes, A_hat):
        super(STGCNBlock, self).__init__()
        self.A_hat = A_hat
        self.temporal1 = TemporalConvNet(num_inputs=in_channels, num_channels=out_channels)
        self.odeg = ODEG(out_channels[-1], 12, A_hat, time=6)
        self.temporal2 = TemporalConvNet(num_inputs=out_channels[-1], num_channels=out_channels)
        self.batch_norm = nn.BatchNorm2d(num_nodes)

    def forward(self, X):
        t = self.temporal1(X)
        t = self.odeg(t)
        t = self.temporal2(F.relu(t))
        return self.batch_norm(t)


class GPT2BackboneHF(nn.Module):
    """Pretrained Hugging Face GPT-2 used as an (optionally frozen) sequence backbone."""

    def __init__(self, model_name: str | None = None, gradient_checkpointing: bool = False,
                 freeze: bool = False, local_dir: str | None = None):
        super().__init__()
        from transformers import GPT2Model
        if local_dir is not None and len(local_dir) > 0:
            # Offline load from a local directory (see README: models/gpt2)
            self.model = GPT2Model.from_pretrained(local_dir, local_files_only=True)
        else:
            if model_name is None:
                model_name = 'gpt2'
            self.model = GPT2Model.from_pretrained(model_name)
        if gradient_checkpointing:
            self.model.gradient_checkpointing_enable()
        self.hidden_size = self.model.config.hidden_size
        if freeze:
            for p in self.model.parameters():
                p.requires_grad = False

    def forward(self, inputs_embeds: torch.Tensor) -> torch.Tensor:
        outputs = self.model(inputs_embeds=inputs_embeds)
        return outputs.last_hidden_state


class ODEGCN_LLM_GPT2(nn.Module):
    def __init__(self, config):
        super(ODEGCN_LLM_GPT2, self).__init__()
        args = config['model']
        num_nodes = config['data']['num_nodes']
        num_features = args['num_features']
        self.history = args['history']
        self.horizon = args['horizon']
        A_sp_hat, A_se_hat = get_A_hat(config)

        self.sp_blocks = nn.ModuleList(
            [nn.Sequential(
                STGCNBlock(in_channels=num_features, out_channels=[64, 32, 64], num_nodes=num_nodes, A_hat=A_sp_hat),
                STGCNBlock(in_channels=64, out_channels=[64, 32, 64], num_nodes=num_nodes, A_hat=A_sp_hat)) for _ in range(3)
             ])

        self.se_blocks = nn.ModuleList(
            [nn.Sequential(
                STGCNBlock(in_channels=num_features, out_channels=[64, 32, 64], num_nodes=num_nodes, A_hat=A_se_hat),
                STGCNBlock(in_channels=64, out_channels=[64, 32, 64], num_nodes=num_nodes, A_hat=A_se_hat)) for _ in range(3)
             ])

        # HF GPT-2
        gpt2_name = args.get('gpt2_name', 'gpt2')
        grad_ckpt = bool(args.get('gpt2_grad_ckpt', False))
        gpt2_freeze = bool(args.get('gpt2_freeze', False))
        gpt2_local_dir = args.get('gpt2_local_dir', None)
        self.gpt2 = GPT2BackboneHF(gpt2_name, gradient_checkpointing=grad_ckpt, freeze=gpt2_freeze, local_dir=gpt2_local_dir)

        # Project ODE features to GPT-2 hidden size
        self.to_llm_embed = nn.Linear(64, self.gpt2.hidden_size)

        # Prediction head
        self.proj_head = nn.Sequential(
            nn.Linear(self.gpt2.hidden_size, self.gpt2.hidden_size),
            nn.ReLU(),
            nn.Linear(self.gpt2.hidden_size, self.horizon)
        )

    def forward(self, x):
        # (B, T, N, C): keep the first feature and reorder to (B, N, T, 1)
        x = x[..., 0:1].permute(0, 2, 1, 3)
        outs = []
        for blk in self.sp_blocks:
            outs.append(blk(x))
        for blk in self.se_blocks:
            outs.append(blk(x))
        outs = torch.stack(outs)
        x = torch.max(outs, dim=0)[0]  # (B, N, T, 64)

        B, N, T, C = x.shape
        x = self.to_llm_embed(x).view(B * N, T, -1)   # (B*N, T, hidden)

        llm_hidden = self.gpt2(inputs_embeds=x)       # (B*N, T, hidden)
        last_state = llm_hidden[:, -1, :]             # (B*N, hidden)
        y = self.proj_head(last_state)                # (B*N, horizon)
        y = y.view(B, N, self.horizon).permute(0, 2, 1).unsqueeze(-1)  # (B, horizon, N, 1)
        return y
@@ -0,0 +1,4 @@
from .STGODE_LLM_GPT2 import ODEGCN_LLM_GPT2
BIN  test_spatial.npy (binary file not shown)
BIN  test_spatial.npy (binary file not shown)