Compare commits

..

7 Commits

Author SHA1 Message Date
czzhangheng af795043c8 Merge branch 'main' into dev
# Conflicts:
#	run.py
#	trainer/Trainer.py
2025-08-18 21:07:16 +08:00
czzhangheng bef30b9c2f Resolve merge conflicts; integrate changes from the dev and main branches 2025-05-14 13:13:11 +08:00
czzhangheng c15cf605be Add Markdown results 2025-04-26 15:37:40 +08:00
czzhangheng 5f8c31af2e Add STIDGCN 2025-04-23 23:24:43 +08:00
czzhangheng e826240a5e Merge remote-tracking branch 'origin/main'
# Conflicts:
#	model/model_selector.py
2025-04-23 23:24:05 +08:00
czzhangheng 97eb39073a Add STIDGCN 2025-04-23 23:22:50 +08:00
czzhangheng bc9a2667c2 Add STMLP with model distillation
The Trainer now saves a model checkpoint after every epoch
STMLP automatically saves the teacher model to pre-train
Pre-training or distillation mode is started depending on whether the teacher model exists
2025-04-07 17:05:59 +08:00
20 changed files with 1446 additions and 30 deletions

1
.gitignore vendored
View File

@@ -7,6 +7,7 @@ experiments/
*.pkl
data/
pretrain/
pre-train/
# ---> Python
# Byte-compiled / optimized / DLL files

19
Result.md Normal file
View File

@@ -0,0 +1,19 @@
| NO. | Baselines | PEMS03 MAE | PEMS03 RMSE | PEMS03 MAPE | PEMS04 MAE | PEMS04 RMSE | PEMS04 MAPE | PEMS07 MAE | PEMS07 RMSE | PEMS07 MAPE | PEMS08 MAE | PEMS08 RMSE | PEMS08 MAPE | Notes |
|-----|----------------|------------|-------------|-------------|------------|-------------|-------------|------------|-------------|-------------|------------|-------------|-------------|--------|
| 1 | HA | | | | | | | | | | | | | Not implemented |
| 2 | ARIMA | 30.99 | 48.28 | 28.66% | 39.7 | 59.12 | 27.57% | / | / | / | 32.51 | 48.5 | 19.94% | On the high side |
| 3 | VAR | | | | | | | | | | | | | Not implemented |
| 4 | FC-LSTM | | | | | | | | | | | | | Not implemented |
| 5 | TCN | 29.51 | 45.79 | 29.11% | 37.6 | 55.5 | 26.81% | 42.6 | 62.19 | 20.22% | 31.18 | 45.8 | 20.64% | On the high side |
| 6 | GRU-ED | | | | | | | | | | | | | Not implemented |
| 7 | DSANET | 21.26 | 34.44 | 21.18% | 27.77 | 43.89 | 18.88% | 31.59 | 49.42 | 13.93% | 22.38 | 35.48 | 14.26% | Reasonable |
| 8 | STGCN | 17.41 | 29.31 | 18.91% | 20.58 | 32.7 | 14.75% | 23.17 | 36.73 | 10.54% | 18.05 | 27.69 | 13.67% | Reasonable |
| 9 | DCRNN | 39.62 | 64.18 | 64.05% | 44.14 | 64.21 | 44.59% | 52.78 | 82.99 | 43.32% | 45.27 | 69.25 | 52.85% | On the high side |
| 10 | GraphWaveNet | 14.68 | 25.86 | 14.38% | 19.19 | 31.04 | 13.06% | 20.40 | 33.48 | 8.73% | 14.83 | 23.86 | 10.14% | On the low side |
| 11 | STSGCN | 18.41 | 30.77 | 19.28% | 21.4 | 35.04 | 14.28% | 24.47 | 38.96 | 10.77% | 17.58 | 27.19 | 12.00% | Reasonable |
| 12 | AGCRN | 15.21 | 26.52 | 14.71% | 19.28 | 31.35 | 12.98% | 20.46 | 33.79 | 8.70% | 15.76 | 25.23 | 10.25% | Reasonable |
| 13 | STFGNN | 17.29 | 29.56 | 17.48% | 23.06 | 36.23 | 15.52% | 24.67 | 38.93 | 10.89% | 16.87 | 27.48 | 11.16% | Reasonable |
| 14 | STGODE | 16.55 | 26.62 | 17.58% | 22.55 | 35.05 | 15.91% | 23.28 | 26.19 | 10.97% | 17.22 | 26.66 | 11.52% | Reasonable |
| 15 | STG-NCDE | 16.09 | 26.78 | 16.58% | 19.82 | 31.71 | 13.21% | 22.54 | 35.44 | 9.85% | 15.85 | 25.05 | 10.19% | Reasonable |
| 16 | DDGCRN | 14.51 | 24.83 | 14.51% | 18.34 | 30.77 | 12.17% | 19.68 | 33.40 | 8.23% | 14.39 | 23.75 | 9.42% | On the low side |
| 17 | TWDGCN | 14.65 | 24.84 | 14.66% | 18.54 | 30.53 | 12.29% | 20.01 | 33.62 | 8.50% | 14.65 | 24.19 | 9.51% | On the high side |
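
The MAE, RMSE and MAPE columns follow the usual point-forecast error definitions. For reference, a minimal sketch of those three metrics (not the repository's lib.loss_function.all_metrics; the near-zero masking via eps is an assumption):

import torch

def mae_rmse_mape(y_pred, y_true, eps=1e-4):
    # Mean Absolute Error
    mae = torch.mean(torch.abs(y_pred - y_true))
    # Root Mean Squared Error
    rmse = torch.sqrt(torch.mean((y_pred - y_true) ** 2))
    # Mean Absolute Percentage Error, ignoring (near-)zero targets
    mask = torch.abs(y_true) > eps
    mape = torch.mean(torch.abs((y_pred[mask] - y_true[mask]) / y_true[mask])) * 100
    return mae.item(), rmse.item(), mape.item()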

View File

@@ -0,0 +1,48 @@
data:
num_nodes: 358
lag: 12
horizon: 12
val_ratio: 0.2
test_ratio: 0.2
tod: False
normalizer: std
column_wise: False
default_graph: True
add_time_in_day: True
add_day_in_week: True
steps_per_day: 288
days_per_week: 7
model:
input_dim: 3
output_dim: 1
history: 12
horizon: 12
granularity: 288
dropout: 0.1
channels: 32
train:
loss_func: mae
seed: 10
batch_size: 64
epochs: 300
lr_init: 0.003
weight_decay: 0
lr_decay: False
lr_decay_rate: 0.3
lr_decay_step: "5,20,40,70"
early_stop: True
early_stop_patience: 15
grad_norm: False
max_grad_norm: 5
real_value: True
test:
mae_thresh: null
mape_thresh: 0.0
log:
log_step: 200
plot: False

View File

@@ -0,0 +1,48 @@
data:
num_nodes: 307
lag: 12
horizon: 12
val_ratio: 0.2
test_ratio: 0.2
tod: False
normalizer: std
column_wise: False
default_graph: True
add_time_in_day: True
add_day_in_week: True
steps_per_day: 288
days_per_week: 7
model:
input_dim: 3
output_dim: 1
history: 12
horizon: 12
granularity: 288
dropout: 0.1
channels: 64
train:
loss_func: mae
seed: 10
batch_size: 64
epochs: 300
lr_init: 0.003
weight_decay: 0
lr_decay: False
lr_decay_rate: 0.3
lr_decay_step: "5,20,40,70"
early_stop: True
early_stop_patience: 15
grad_norm: False
max_grad_norm: 5
real_value: True
test:
mae_thresh: null
mape_thresh: 0.0
log:
log_step: 200
plot: False

View File

@@ -0,0 +1,48 @@
data:
num_nodes: 883
lag: 12
horizon: 12
val_ratio: 0.2
test_ratio: 0.2
tod: False
normalizer: std
column_wise: False
default_graph: True
add_time_in_day: True
add_day_in_week: True
steps_per_day: 288
days_per_week: 7
model:
input_dim: 3
output_dim: 1
history: 12
horizon: 12
granularity: 288
dropout: 0.1
channels: 128
train:
loss_func: mae
seed: 10
batch_size: 16
epochs: 300
lr_init: 0.003
weight_decay: 0
lr_decay: False
lr_decay_rate: 0.3
lr_decay_step: "5,20,40,70"
early_stop: True
early_stop_patience: 15
grad_norm: False
max_grad_norm: 5
real_value: True
test:
mae_thresh: null
mape_thresh: 0.0
log:
log_step: 200
plot: False

View File

@@ -0,0 +1,48 @@
data:
num_nodes: 170
lag: 12
horizon: 12
val_ratio: 0.2
test_ratio: 0.2
tod: False
normalizer: std
column_wise: False
default_graph: True
add_time_in_day: True
add_day_in_week: True
steps_per_day: 288
days_per_week: 7
model:
input_dim: 3
output_dim: 1
history: 12
horizon: 12
granularity: 288
dropout: 0.1
channels: 96
train:
loss_func: mae
seed: 10
batch_size: 64
epochs: 300
lr_init: 0.003
weight_decay: 0
lr_decay: False
lr_decay_rate: 0.3
lr_decay_step: "5,20,40,70"
early_stop: True
early_stop_patience: 15
grad_norm: False
max_grad_norm: 5
real_value: True
test:
mae_thresh: null
mape_thresh: 0.0
log:
log_step: 200
plot: False

66
config/STMLP/PEMSD3.yaml Normal file
View File

@@ -0,0 +1,66 @@
data:
num_nodes: 358
lag: 12
horizon: 12
val_ratio: 0.2
test_ratio: 0.2
tod: False
normalizer: std
column_wise: False
default_graph: True
add_time_in_day: True
add_day_in_week: True
steps_per_day: 288
days_per_week: 7
model:
input_dim: 1
output_dim: 1
input_window: 12
output_window: 12
gcn_true: true
buildA_true: true
gcn_depth: 2
dropout: 0.3
subgraph_size: 20
node_dim: 40
dilation_exponential: 1
conv_channels: 32
residual_channels: 32
skip_channels: 64
end_channels: 128
layers: 3
propalpha: 0.05
tanhalpha: 3
layer_norm_affline: true
use_curriculum_learning: true
step_size1: 2500
task_level: 0
num_split: 1
step_size2: 100
model_type: stmlp
train:
loss_func: mae
seed: 10
batch_size: 64
epochs: 300
lr_init: 0.003
weight_decay: 0
lr_decay: False
lr_decay_rate: 0.3
lr_decay_step: "5,20,40,70"
early_stop: True
early_stop_patience: 15
grad_norm: False
max_grad_norm: 5
real_value: True
teacher_stu: True
test:
mae_thresh: null
mape_thresh: 0.0
log:
log_step: 2000
plot: False
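
How the train: block above is consumed is not shown in this diff; the following is only a sketch under the assumption of a standard torch setup (build_optimizer is a hypothetical helper). lr_init and weight_decay configure the optimizer, and lr_decay_step is a comma-separated string of milestone epochs for MultiStepLR, applied only when lr_decay is True.

import torch

def build_optimizer(model, train_cfg):
    # train_cfg is the parsed "train:" section of one of these YAML files
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=train_cfg['lr_init'],
                                 weight_decay=train_cfg['weight_decay'])
    scheduler = None
    if train_cfg['lr_decay']:
        # "5,20,40,70" -> [5, 20, 40, 70]
        milestones = [int(s) for s in str(train_cfg['lr_decay_step']).split(',')]
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=milestones, gamma=train_cfg['lr_decay_rate'])
    return optimizer, scheduler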

67
config/STMLP/PEMSD4.yaml Normal file
View File

@@ -0,0 +1,67 @@
data:
num_nodes: 307
lag: 12
horizon: 12
val_ratio: 0.2
test_ratio: 0.2
tod: False
normalizer: std
column_wise: False
default_graph: True
add_time_in_day: True
add_day_in_week: True
steps_per_day: 288
days_per_week: 7
model:
input_dim: 1
output_dim: 1
input_window: 12
output_window: 12
gcn_true: true
buildA_true: true
gcn_depth: 2
dropout: 0.3
subgraph_size: 20
node_dim: 40
dilation_exponential: 1
conv_channels: 32
residual_channels: 32
skip_channels: 64
end_channels: 128
layers: 3
propalpha: 0.05
tanhalpha: 3
layer_norm_affline: true
use_curriculum_learning: true
step_size1: 2500
task_level: 0
num_split: 1
step_size2: 100
model_type: stmlp
train:
loss_func: mae
seed: 10
batch_size: 64
epochs: 300
lr_init: 0.003
weight_decay: 0
lr_decay: False
lr_decay_rate: 0.3
lr_decay_step: "5,20,40,70"
early_stop: True
early_stop_patience: 15
grad_norm: False
max_grad_norm: 5
real_value: True
teacher: True
teacher_stu: True
test:
mae_thresh: null
mape_thresh: 0.0
log:
log_step: 2000
plot: False

66
config/STMLP/PEMSD7.yaml Normal file
View File

@@ -0,0 +1,66 @@
data:
num_nodes: 883
lag: 12
horizon: 12
val_ratio: 0.2
test_ratio: 0.2
tod: False
normalizer: std
column_wise: False
default_graph: True
add_time_in_day: True
add_day_in_week: True
steps_per_day: 288
days_per_week: 7
model:
input_dim: 1
output_dim: 1
input_window: 12
output_window: 12
gcn_true: true
buildA_true: true
gcn_depth: 2
dropout: 0.3
subgraph_size: 20
node_dim: 40
dilation_exponential: 1
conv_channels: 32
residual_channels: 32
skip_channels: 64
end_channels: 128
layers: 3
propalpha: 0.05
tanhalpha: 3
layer_norm_affline: true
use_curriculum_learning: true
step_size1: 2500
task_level: 0
num_split: 1
step_size2: 100
model_type: stmlp
train:
loss_func: mae
seed: 10
batch_size: 16
epochs: 300
lr_init: 0.003
weight_decay: 0
lr_decay: False
lr_decay_rate: 0.3
lr_decay_step: "5,20,40,70"
early_stop: True
early_stop_patience: 15
grad_norm: False
max_grad_norm: 5
real_value: True
teacher_stu: True
test:
mae_thresh: null
mape_thresh: 0.0
log:
log_step: 2000
plot: False

66
config/STMLP/PEMSD8.yaml Normal file
View File

@@ -0,0 +1,66 @@
data:
num_nodes: 170
lag: 12
horizon: 12
val_ratio: 0.2
test_ratio: 0.2
tod: False
normalizer: std
column_wise: False
default_graph: True
add_time_in_day: True
add_day_in_week: True
steps_per_day: 288
days_per_week: 7
model:
input_dim: 1
output_dim: 1
input_window: 12
output_window: 12
gcn_true: true
buildA_true: true
gcn_depth: 2
dropout: 0.3
subgraph_size: 20
node_dim: 40
dilation_exponential: 1
conv_channels: 32
residual_channels: 32
skip_channels: 64
end_channels: 128
layers: 3
propalpha: 0.05
tanhalpha: 3
layer_norm_affline: true
use_curriculum_learning: true
step_size1: 2500
task_level: 0
num_split: 1
step_size2: 100
model_type: stmlp
train:
loss_func: mae
seed: 10
batch_size: 64
epochs: 300
lr_init: 0.003
weight_decay: 0
lr_decay: False
lr_decay_rate: 0.3
lr_decay_step: "5,20,40,70"
early_stop: True
early_stop_patience: 15
grad_norm: False
max_grad_norm: 5
real_value: True
teacher_stu: True
test:
mae_thresh: null
mape_thresh: 0.0
log:
log_step: 2000
plot: False

View File

@@ -121,7 +121,7 @@ def download_kaggle_data(current_dir):
Overwrites conflicting files if the target folder already exists
"""
try:
print("Downloading the KaggleHub dataset...")
print("Downloading the PEMS dataset...")
path = kagglehub.dataset_download("elmahy/pems-dataset")
# print("Path to KaggleHub dataset files:", path)

368
model/STIDGCN/STIDGCN.py Normal file
View File

@@ -0,0 +1,368 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class GLU(nn.Module):
def __init__(self, features, dropout=0.1):
super(GLU, self).__init__()
self.conv1 = nn.Conv2d(features, features, (1, 1))
self.conv2 = nn.Conv2d(features, features, (1, 1))
self.conv3 = nn.Conv2d(features, features, (1, 1))
self.dropout = nn.Dropout(dropout)
def forward(self, x):
x1 = self.conv1(x)
x2 = self.conv2(x)
out = x1 * torch.sigmoid(x2)
out = self.dropout(out)
out = self.conv3(out)
return out
# class TemporalEmbedding(nn.Module):
# def __init__(self, time, features):
# super(TemporalEmbedding, self).__init__()
#
# self.time = time
# # self.time_day = nn.Parameter(torch.empty(time, features))
# # nn.init.xavier_uniform_(self.time_day)
# #
# # self.time_week = nn.Parameter(torch.empty(7, features))
# # nn.init.xavier_uniform_(self.time_week)
# self.time_day = nn.Embedding(time, features)
# self.time_week = nn.Embedding(7, features)
#
# def forward(self, x):
# day_emb = x[..., 1]
# # time_day = self.time_day[(day_emb[:, :, :] * self.time).type(torch.LongTensor)]
# # time_day = time_day.transpose(1, 2).contiguous()
#
# week_emb = x[..., 2]
# # time_week = self.time_week[(week_emb[:, :, :]).type(torch.LongTensor)]
# # time_week = time_week.transpose(1, 2).contiguous()
#
# t_idx = (day_emb[:, -1, :, ] * (self.time - 1)).long() # (B, N)
# d_idx = week_emb[:, -1, :, ].long() # (B, N)
# # time_emb = self.time_embedding(t_idx) # (B, N, hidden_dim)
# # day_emb = self.day_embedding(d_idx) # (B, N, hidden_dim)
#
# tem_emb = t_idx + d_idx
#
# # tem_emb = tem_emb.permute(0, 3, 1, 2)
#
# return tem_emb
class TemporalEmbedding(nn.Module):
def __init__(self, time, features):
super(TemporalEmbedding, self).__init__()
self.time = time
self.time_day = nn.Parameter(torch.empty(time, features))
nn.init.xavier_uniform_(self.time_day)
self.time_week = nn.Parameter(torch.empty(7, features))
nn.init.xavier_uniform_(self.time_week)
def forward(self, x):
day_emb = x[..., 1]
time_day = self.time_day[(day_emb[:, :, :] * self.time).type(torch.LongTensor)]
time_day = time_day.transpose(1, 2).contiguous()
week_emb = x[..., 2]
time_week = self.time_week[(week_emb[:, :, :]).type(torch.LongTensor)]
time_week = time_week.transpose(1, 2).contiguous()
tem_emb = time_day + time_week
tem_emb = tem_emb.permute(0,3,1,2)
return tem_emb
class Diffusion_GCN(nn.Module):
def __init__(self, channels=128, diffusion_step=1, dropout=0.1):
super().__init__()
self.diffusion_step = diffusion_step
self.conv = nn.Conv2d(diffusion_step * channels, channels, (1, 1))
self.dropout = nn.Dropout(dropout)
def forward(self, x, adj):
out = []
for i in range(0, self.diffusion_step):
if adj.dim() == 3:
x = torch.einsum("bcnt,bnm->bcmt", x, adj).contiguous()
out.append(x)
elif adj.dim() == 2:
x = torch.einsum("bcnt,nm->bcmt", x, adj).contiguous()
out.append(x)
x = torch.cat(out, dim=1)
x = self.conv(x)
output = self.dropout(x)
return output
class Graph_Generator(nn.Module):
def __init__(self, channels=128, num_nodes=170, diffusion_step=1, dropout=0.1):
super().__init__()
self.memory = nn.Parameter(torch.randn(channels, num_nodes))
nn.init.xavier_uniform_(self.memory)
self.fc = nn.Linear(2, 1)
def forward(self, x):
adj_dyn_1 = torch.softmax(
F.relu(
torch.einsum("bcnt, cm->bnm", x, self.memory).contiguous()
/ math.sqrt(x.shape[1])
),
-1,
)
adj_dyn_2 = torch.softmax(
F.relu(
torch.einsum("bcn, bcm->bnm", x.sum(-1), x.sum(-1)).contiguous()
/ math.sqrt(x.shape[1])
),
-1,
)
# adj_dyn = (adj_dyn_1 + adj_dyn_2 + adj)/2
adj_f = torch.cat([(adj_dyn_1).unsqueeze(-1)] + [(adj_dyn_2).unsqueeze(-1)], dim=-1)
adj_f = torch.softmax(self.fc(adj_f).squeeze(), -1)
topk_values, topk_indices = torch.topk(adj_f, k=int(adj_f.shape[1] * 0.8), dim=-1)
mask = torch.zeros_like(adj_f)
mask.scatter_(-1, topk_indices, 1)
adj_f = adj_f * mask
return adj_f
class DGCN(nn.Module):
def __init__(self, channels=128, num_nodes=170, diffusion_step=1, dropout=0.1, emb=None):
super().__init__()
self.conv = nn.Conv2d(channels, channels, (1, 1))
self.generator = Graph_Generator(channels, num_nodes, diffusion_step, dropout)
self.gcn = Diffusion_GCN(channels, diffusion_step, dropout)
self.emb = emb
def forward(self, x):
skip = x
x = self.conv(x)
adj_dyn = self.generator(x)
x = self.gcn(x, adj_dyn)
x = x * self.emb + skip
return x
class Splitting(nn.Module):
def __init__(self):
super(Splitting, self).__init__()
def even(self, x):
return x[:, :, :, ::2]
def odd(self, x):
return x[:, :, :, 1::2]
def forward(self, x):
return (self.even(x), self.odd(x))
class IDGCN(nn.Module):
def __init__(
self,
device,
channels=64,
diffusion_step=1,
splitting=True,
num_nodes=170,
dropout=0.2, emb=None
):
super(IDGCN, self).__init__()
device = device
self.dropout = dropout
self.num_nodes = num_nodes
self.splitting = splitting
self.split = Splitting()
Conv1 = []
Conv2 = []
Conv3 = []
Conv4 = []
pad_l = 3
pad_r = 3
k1 = 5
k2 = 3
Conv1 += [
nn.ReplicationPad2d((pad_l, pad_r, 0, 0)),
nn.Conv2d(channels, channels, kernel_size=(1, k1)),
nn.LeakyReLU(negative_slope=0.01, inplace=True),
nn.Dropout(self.dropout),
nn.Conv2d(channels, channels, kernel_size=(1, k2)),
nn.Tanh(),
]
Conv2 += [
nn.ReplicationPad2d((pad_l, pad_r, 0, 0)),
nn.Conv2d(channels, channels, kernel_size=(1, k1)),
nn.LeakyReLU(negative_slope=0.01, inplace=True),
nn.Dropout(self.dropout),
nn.Conv2d(channels, channels, kernel_size=(1, k2)),
nn.Tanh(),
]
Conv4 += [
nn.ReplicationPad2d((pad_l, pad_r, 0, 0)),
nn.Conv2d(channels, channels, kernel_size=(1, k1)),
nn.LeakyReLU(negative_slope=0.01, inplace=True),
nn.Dropout(self.dropout),
nn.Conv2d(channels, channels, kernel_size=(1, k2)),
nn.Tanh(),
]
Conv3 += [
nn.ReplicationPad2d((pad_l, pad_r, 0, 0)),
nn.Conv2d(channels, channels, kernel_size=(1, k1)),
nn.LeakyReLU(negative_slope=0.01, inplace=True),
nn.Dropout(self.dropout),
nn.Conv2d(channels, channels, kernel_size=(1, k2)),
nn.Tanh(),
]
self.conv1 = nn.Sequential(*Conv1)
self.conv2 = nn.Sequential(*Conv2)
self.conv3 = nn.Sequential(*Conv3)
self.conv4 = nn.Sequential(*Conv4)
self.dgcn = DGCN(channels, num_nodes, diffusion_step, dropout, emb)
def forward(self, x):
if self.splitting:
(x_even, x_odd) = self.split(x)
else:
(x_even, x_odd) = x
x1 = self.conv1(x_even)
x1 = self.dgcn(x1)
d = x_odd.mul(torch.tanh(x1))
x2 = self.conv2(x_odd)
x2 = self.dgcn(x2)
c = x_even.mul(torch.tanh(x2))
x3 = self.conv3(c)
x3 = self.dgcn(x3)
x_odd_update = d + x3
x4 = self.conv4(d)
x4 = self.dgcn(x4)
x_even_update = c + x4
return (x_even_update, x_odd_update)
class IDGCN_Tree(nn.Module):
def __init__(
self, device, channels=64, diffusion_step=1, num_nodes=170, dropout=0.1
):
super().__init__()
self.memory1 = nn.Parameter(torch.randn(channels, num_nodes, 6))
self.memory2 = nn.Parameter(torch.randn(channels, num_nodes, 3))
self.memory3 = nn.Parameter(torch.randn(channels, num_nodes, 3))
self.IDGCN1 = IDGCN(
device=device,
splitting=True,
channels=channels,
diffusion_step=diffusion_step,
num_nodes=num_nodes,
dropout=dropout, emb=self.memory1
)
self.IDGCN2 = IDGCN(
device=device,
splitting=True,
channels=channels,
diffusion_step=diffusion_step,
num_nodes=num_nodes,
dropout=dropout, emb=self.memory2
)
self.IDGCN3 = IDGCN(
device=device,
splitting=True,
channels=channels,
diffusion_step=diffusion_step,
num_nodes=num_nodes,
dropout=dropout, emb=self.memory2
)
def concat(self, even, odd):
even = even.permute(3, 1, 2, 0)
odd = odd.permute(3, 1, 2, 0)
len = even.shape[0]
_ = []
for i in range(len):
_.append(even[i].unsqueeze(0))
_.append(odd[i].unsqueeze(0))
return torch.cat(_, 0).permute(3, 1, 2, 0)
def forward(self, x):
x_even_update1, x_odd_update1 = self.IDGCN1(x)
x_even_update2, x_odd_update2 = self.IDGCN2(x_even_update1)
x_even_update3, x_odd_update3 = self.IDGCN3(x_odd_update1)
concat1 = self.concat(x_even_update2, x_odd_update2)
concat2 = self.concat(x_even_update3, x_odd_update3)
concat0 = self.concat(concat1, concat2)
output = concat0 + x
return output
class STIDGCN(nn.Module):
def __init__(self, args):
"""
device, input_dim, num_nodes, channels, granularity, dropout=0.1
"""
super().__init__()
device = args['device']
input_dim = args['input_dim']
self.num_nodes = args['num_nodes']
self.output_len = 12
channels = args['channels']
granularity = args['granularity']
dropout = args['dropout']
diffusion_step = 1
self.Temb = TemporalEmbedding(granularity, channels)
self.start_conv = nn.Conv2d(
in_channels=input_dim, out_channels=channels, kernel_size=(1, 1)
)
self.tree = IDGCN_Tree(
device=device,
channels=channels * 2,
diffusion_step=diffusion_step,
num_nodes=self.num_nodes,
dropout=dropout,
)
self.glu = GLU(channels * 2, dropout)
self.regression_layer = nn.Conv2d(
channels * 2, self.output_len, kernel_size=(1, self.output_len)
)
def param_num(self):
return sum([param.nelement() for param in self.parameters()])
def forward(self, input):
input = input.transpose(1, 3)
x = input
# Encoder
# Data Embedding
time_emb = self.Temb(input.permute(0, 3, 2, 1))
x = torch.cat([self.start_conv(x)] + [time_emb], dim=1)
# IDGCN_Tree
x = self.tree(x)
# Decoder
gcn = self.glu(x) + x
prediction = self.regression_layer(F.relu(gcn))
return prediction
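
A minimal usage sketch for STIDGCN as defined above, assuming CPU and PEMSD8-style settings (num_nodes=170, channels=96, granularity=288). The dummy input layout (batch, history, nodes, features), with time-of-day in feature 1 and day-of-week in feature 2, is inferred from TemporalEmbedding.forward and the add_time_in_day / add_day_in_week options in the configs:

import torch
from model.STIDGCN.STIDGCN import STIDGCN

args = {'device': 'cpu', 'input_dim': 3, 'num_nodes': 170,
        'channels': 96, 'granularity': 288, 'dropout': 0.1}
model = STIDGCN(args)
x = torch.zeros(2, 12, 170, 3)   # (batch, history, nodes, features)
x[..., 1] = 0.5                  # time-of-day fraction in [0, 1)
x[..., 2] = 3.0                  # day-of-week index in [0, 6]
y = model(x)                     # expected output shape: (2, 12, 170, 1)
print(model.param_num(), y.shape)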

307
model/STMLP/STMLP.py Normal file
View File

@@ -0,0 +1,307 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
from data.get_adj import get_adj
import numbers
# --- basic operators ---
class NConv(nn.Module):
def forward(self, x, adj):
return torch.einsum('ncwl,vw->ncvl', (x, adj)).contiguous()
class DyNconv(nn.Module):
def forward(self, x, adj):
return torch.einsum('ncvl,nvwl->ncwl', (x, adj)).contiguous()
class Linear(nn.Module):
def __init__(self, c_in, c_out, bias=True):
super().__init__()
self.mlp = nn.Conv2d(c_in, c_out, kernel_size=1, bias=bias)
def forward(self, x):
return self.mlp(x)
class Prop(nn.Module):
def __init__(self, c_in, c_out, gdep, dropout, alpha):
super().__init__()
self.nconv = NConv()
self.mlp = Linear(c_in, c_out)
self.gdep, self.dropout, self.alpha = gdep, dropout, alpha
def forward(self, x, adj):
adj = adj + torch.eye(adj.size(0), device=x.device)
d = adj.sum(1)
a = adj / d.view(-1, 1)
h = x
for _ in range(self.gdep):
h = self.alpha * x + (1 - self.alpha) * self.nconv(h, a)
return self.mlp(h)
class MixProp(nn.Module):
def __init__(self, c_in, c_out, gdep, dropout, alpha):
super().__init__()
self.nconv = NConv()
self.mlp = Linear((gdep + 1) * c_in, c_out)
self.gdep, self.dropout, self.alpha = gdep, dropout, alpha
def forward(self, x, adj):
adj = adj + torch.eye(adj.size(0), device=x.device)
d = adj.sum(1)
a = adj / d.view(-1, 1)
out = [x]
h = x
for _ in range(self.gdep):
h = self.alpha * x + (1 - self.alpha) * self.nconv(h, a)
out.append(h)
return self.mlp(torch.cat(out, dim=1))
class DyMixprop(nn.Module):
def __init__(self, c_in, c_out, gdep, dropout, alpha):
super().__init__()
self.nconv = DyNconv()
self.mlp1 = Linear((gdep + 1) * c_in, c_out)
self.mlp2 = Linear((gdep + 1) * c_in, c_out)
self.gdep, self.dropout, self.alpha = gdep, dropout, alpha
self.lin1, self.lin2 = Linear(c_in, c_in), Linear(c_in, c_in)
def forward(self, x):
x1 = torch.tanh(self.lin1(x))
x2 = torch.tanh(self.lin2(x))
adj = self.nconv(x1.transpose(2, 1), x2)
adj0 = torch.softmax(adj, dim=2)
adj1 = torch.softmax(adj.transpose(2, 1), dim=2)
# two propagation branches
out1, out2 = [x], [x]
h = x
for _ in range(self.gdep):
h = self.alpha * x + (1 - self.alpha) * self.nconv(h, adj0)
out1.append(h)
h = x
for _ in range(self.gdep):
h = self.alpha * x + (1 - self.alpha) * self.nconv(h, adj1)
out2.append(h)
return self.mlp1(torch.cat(out1, dim=1)) + self.mlp2(torch.cat(out2, dim=1))
class DilatedInception(nn.Module):
def __init__(self, cin, cout, dilation_factor=2):
super().__init__()
self.kernels = [2, 3, 6, 7]
cout_each = int(cout / len(self.kernels))
self.convs = nn.ModuleList([nn.Conv2d(cin, cout_each, kernel_size=(1, k), dilation=(1, dilation_factor))
for k in self.kernels])
def forward(self, x):
outs = [conv(x)[..., -self.convs[-1](x).size(3):] for conv in self.convs]
return torch.cat(outs, dim=1)
class GraphConstructor(nn.Module):
def __init__(self, nnodes, k, dim, device, alpha=3, static_feat=None):
super().__init__()
self.nnodes, self.k, self.dim, self.alpha, self.device = nnodes, k, dim, alpha, device
self.static_feat = static_feat
if static_feat is not None:
xd = static_feat.shape[1]
self.lin1, self.lin2 = nn.Linear(xd, dim), nn.Linear(xd, dim)
else:
self.emb1 = nn.Embedding(nnodes, dim)
self.emb2 = nn.Embedding(nnodes, dim)
self.lin1, self.lin2 = nn.Linear(dim, dim), nn.Linear(dim, dim)
def forward(self, idx):
if self.static_feat is None:
vec1, vec2 = self.emb1(idx), self.emb2(idx)
else:
vec1 = vec2 = self.static_feat[idx, :]
vec1 = torch.tanh(self.alpha * self.lin1(vec1))
vec2 = torch.tanh(self.alpha * self.lin2(vec2))
a = torch.mm(vec1, vec2.transpose(1, 0)) - torch.mm(vec2, vec1.transpose(1, 0))
adj = F.relu(torch.tanh(self.alpha * a))
mask = torch.zeros(idx.size(0), idx.size(0), device=self.device)
s1, t1 = adj.topk(self.k, 1)
mask.scatter_(1, t1, s1.new_ones(s1.size()))
return adj * mask
class LayerNorm(nn.Module):
__constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
super().__init__()
if isinstance(normalized_shape, numbers.Integral):
normalized_shape = (normalized_shape,)
self.normalized_shape, self.eps, self.elementwise_affine = tuple(normalized_shape), eps, elementwise_affine
if elementwise_affine:
self.weight = nn.Parameter(torch.Tensor(*normalized_shape))
self.bias = nn.Parameter(torch.Tensor(*normalized_shape))
init.ones_(self.weight);
init.zeros_(self.bias)
else:
self.register_parameter('weight', None)
self.register_parameter('bias', None)
def forward(self, x, idx):
if self.elementwise_affine:
return F.layer_norm(x, tuple(x.shape[1:]), self.weight[:, idx, :], self.bias[:, idx, :], self.eps)
else:
return F.layer_norm(x, tuple(x.shape[1:]), self.weight, self.bias, self.eps)
def extra_repr(self):
return f'{self.normalized_shape}, eps={self.eps}, elementwise_affine={self.elementwise_affine}'
# --- merged model class supporting both the 'teacher' and 'stmlp' branches ---
class STMLP(nn.Module):
def __init__(self, args):
super().__init__()
# read parameters from the args dict
self.adj_mx = get_adj(args)
self.num_nodes = args['num_nodes']
self.feature_dim = args['input_dim']
self.input_window = args['input_window']
self.output_window = args['output_window']
self.output_dim = args['output_dim']
self.device = args['device']
self.gcn_true = args['gcn_true']
self.buildA_true = args['buildA_true']
self.gcn_depth = args['gcn_depth']
self.dropout = args['dropout']
self.subgraph_size = args['subgraph_size']
self.node_dim = args['node_dim']
self.dilation_exponential = args['dilation_exponential']
self.conv_channels = args['conv_channels']
self.residual_channels = args['residual_channels']
self.skip_channels = args['skip_channels']
self.end_channels = args['end_channels']
self.layers = args['layers']
self.propalpha = args['propalpha']
self.tanhalpha = args['tanhalpha']
self.layer_norm_affline = args['layer_norm_affline']
self.model_type = args['model_type']  # 'teacher' or 'stmlp'
self.idx = torch.arange(self.num_nodes).to(self.device)
self.predefined_A = None if self.adj_mx is None else (torch.tensor(self.adj_mx) - torch.eye(self.num_nodes)).to(
self.device)
self.static_feat = None
# transformer: keep the original structure
self.encoder_layer = nn.TransformerEncoderLayer(d_model=12, nhead=4, batch_first=True)
self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=3)
# build the layers
self.start_conv = nn.Conv2d(self.feature_dim, self.residual_channels, kernel_size=1)
self.gc = GraphConstructor(self.num_nodes, self.subgraph_size, self.node_dim, self.device, alpha=self.tanhalpha,
static_feat=self.static_feat)
# compute the receptive_field
kernel_size = 7
if self.dilation_exponential > 1:
self.receptive_field = int(
self.output_dim + (kernel_size - 1) * (self.dilation_exponential ** self.layers - 1) / (
self.dilation_exponential - 1))
else:
self.receptive_field = self.layers * (kernel_size - 1) + self.output_dim
self.filter_convs = nn.ModuleList()
self.gate_convs = nn.ModuleList()
self.residual_convs = nn.ModuleList()
self.skip_convs = nn.ModuleList()
self.norm = nn.ModuleList()
self.stu_mlp = nn.ModuleList([nn.Sequential(nn.Linear(c, c), nn.Linear(c, c), nn.Linear(c, c))
for c in [13, 7, 1]])
if self.gcn_true:
self.gconv1 = nn.ModuleList()
self.gconv2 = nn.ModuleList()
new_dilation = 1
for i in range(1):
rf_size_i = int(1 + i * (kernel_size - 1) * (self.dilation_exponential ** self.layers - 1) / (
self.dilation_exponential - 1)) if self.dilation_exponential > 1 else i * self.layers * (
kernel_size - 1) + 1
for j in range(1, self.layers + 1):
rf_size_j = int(rf_size_i + (kernel_size - 1) * (self.dilation_exponential ** j - 1) / (
self.dilation_exponential - 1)) if self.dilation_exponential > 1 else rf_size_i + j * (
kernel_size - 1)
self.filter_convs.append(
DilatedInception(self.residual_channels, self.conv_channels, dilation_factor=new_dilation))
self.gate_convs.append(
DilatedInception(self.residual_channels, self.conv_channels, dilation_factor=new_dilation))
self.residual_convs.append(nn.Conv2d(self.conv_channels, self.residual_channels, kernel_size=1))
k_size = (1, self.input_window - rf_size_j + 1) if self.input_window > self.receptive_field else (
1, self.receptive_field - rf_size_j + 1)
self.skip_convs.append(nn.Conv2d(self.conv_channels, self.skip_channels, kernel_size=k_size))
if self.gcn_true:
self.gconv1.append(MixProp(self.conv_channels, self.residual_channels, self.gcn_depth, self.dropout,
self.propalpha))
self.gconv2.append(MixProp(self.conv_channels, self.residual_channels, self.gcn_depth, self.dropout,
self.propalpha))
norm_size = (self.residual_channels, self.num_nodes,
self.input_window - rf_size_j + 1) if self.input_window > self.receptive_field else (
self.residual_channels, self.num_nodes, self.receptive_field - rf_size_j + 1)
self.norm.append(LayerNorm(norm_size, elementwise_affine=self.layer_norm_affline))
new_dilation *= self.dilation_exponential
self.end_conv_1 = nn.Conv2d(self.skip_channels, self.end_channels, kernel_size=1, bias=True)
self.end_conv_2 = nn.Conv2d(self.end_channels, self.output_window, kernel_size=1, bias=True)
k0 = (1, self.input_window) if self.input_window > self.receptive_field else (1, self.receptive_field)
self.skip0 = nn.Conv2d(self.feature_dim, self.skip_channels, kernel_size=k0, bias=True)
kE = (1, self.input_window - self.receptive_field + 1) if self.input_window > self.receptive_field else (1, 1)
self.skipE = nn.Conv2d(self.residual_channels, self.skip_channels, kernel_size=kE, bias=True)
# final output branch: choose a head depending on the model type
if self.model_type == 'teacher':
self.tt_linear1 = nn.Linear(self.residual_channels, self.input_window)
self.tt_linear2 = nn.Linear(1, 32)
self.ss_linear1 = nn.Linear(self.residual_channels, self.input_window)
self.ss_linear2 = nn.Linear(1, 32)
else: # stmlp
self.out_linear1 = nn.Linear(self.residual_channels, self.input_window)
self.out_linear2 = nn.Linear(1, 32)
def forward(self, source, idx=None):
source = source[..., 0:1]
sout, tout = [], []
inputs = source.transpose(1, 3)
assert inputs.size(3) == self.input_window, 'input sequence length mismatch'
if self.input_window < self.receptive_field:
inputs = F.pad(inputs, (self.receptive_field - self.input_window, 0, 0, 0))
if self.gcn_true:
adp = self.gc(self.idx if idx is None else idx) if self.buildA_true else self.predefined_A
x = self.start_conv(inputs)
skip = self.skip0(F.dropout(inputs, self.dropout, training=self.training))
for i in range(self.layers):
residual = x
filters = torch.tanh(self.filter_convs[i](x))
gate = torch.sigmoid(self.gate_convs[i](x))
x = F.dropout(filters * gate, self.dropout, training=self.training)
tout.append(x)
s = self.skip_convs[i](x)
skip = s + skip
if self.gcn_true:
x = self.gconv1[i](x, adp) + self.gconv2[i](x, adp.transpose(1, 0))
else:
x = self.stu_mlp[i](x)
x = x + residual[:, :, :, -x.size(3):]
x = self.norm[i](x, self.idx if idx is None else idx)
sout.append(x)
skip = self.skipE(x) + skip
x = F.relu(skip)
x = F.relu(self.end_conv_1(x))
x = self.end_conv_2(x)
if self.model_type == 'teacher':
ttout = self.tt_linear2(self.tt_linear1(tout[-1].transpose(1, 3)).transpose(1, 3))
ssout = self.ss_linear2(self.ss_linear1(sout[-1].transpose(1, 3)).transpose(1, 3))
return x, ttout, ssout
else:
x_ = self.out_linear2(self.out_linear1(tout[-1].transpose(1, 3)).transpose(1, 3))
return x, x_, x
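
A worked check of the receptive-field arithmetic in __init__ above, plugging in the values from config/STMLP/PEMSD4.yaml (kernel_size is hard-coded to 7):

# dilation_exponential == 1, so the second branch of the formula applies
kernel_size, layers, output_dim, input_window = 7, 3, 1, 12
receptive_field = layers * (kernel_size - 1) + output_dim   # 3 * 6 + 1 = 19
padding = max(receptive_field - input_window, 0)            # forward() left-pads the time axis by 7
print(receptive_field, padding)                             # 19 7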

View File

@@ -13,6 +13,8 @@ from model.STFGNN.STFGNN import STFGNN
from model.STSGCN.STSGCN import STSGCN
from model.STGODE.STGODE import ODEGCN
from model.PDG2SEQ.PDG2Seqb import PDG2Seq
from model.STMLP.STMLP import STMLP
from model.STIDGCN.STIDGCN import STIDGCN
from model.STID.STID import STID
from model.STAEFormer.STAEFormer import STAEformer
from model.EXP.EXP32 import EXP as EXP
@@ -34,6 +36,8 @@ def model_selector(model):
case 'STSGCN': return STSGCN(model)
case 'STGODE': return ODEGCN(model)
case 'PDG2SEQ': return PDG2Seq(model)
case 'STMLP': return STMLP(model)
case 'STIDGCN': return STIDGCN(model)
case 'STID': return STID(model)
case 'STAEFormer': return STAEformer(model)
case 'EXP': return EXP(model)

44
run.py
View File

@@ -18,6 +18,8 @@ from trainer.trainer_selector import select_trainer
import yaml
def main():
args = parse_args()
@@ -32,26 +34,28 @@ def main():
# Initialize model
model = init_model(args['model'], device=args['device'])
# if args['mode'] == "benchmark":
# # For compute-cost profiling, set mode to benchmark
# import torch.profiler as profiler
# dummy_input = torch.randn((64, 12, args['model']['num_nodes'], 3), device=args['device'])
# min_val = dummy_input.min(dim=-1, keepdim=True)[0]
# max_val = dummy_input.max(dim=-1, keepdim=True)[0]
#
# dummy_input = (dummy_input - min_val) / (max_val - min_val + 1e-6)
# with profiler.profile(
# activities=[
# profiler.ProfilerActivity.CPU,
# profiler.ProfilerActivity.CUDA
# ],
# with_stack=True,
# profile_memory=True,
# record_shapes=True
# ) as prof:
# out = model(dummy_input)
# print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
# return 0
if args['mode'] == "benchmark":
# For compute-cost profiling, set mode to benchmark
import torch.profiler as profiler
dummy_input = torch.randn((64, 12, args['model']['num_nodes'], 3), device=args['device'])
min_val = dummy_input.min(dim=-1, keepdim=True)[0]
max_val = dummy_input.max(dim=-1, keepdim=True)[0]
dummy_input = (dummy_input - min_val) / (max_val - min_val + 1e-6)
with profiler.profile(
activities=[
profiler.ProfilerActivity.CPU,
profiler.ProfilerActivity.CUDA
],
with_stack=True,
profile_memory=True,
record_shapes=True
) as prof:
out = model(dummy_input)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
return 0
# Load dataset
train_loader, val_loader, test_loader, scaler, *extra_data = get_dataloader(

View File

@@ -160,10 +160,6 @@ class Trainer:
y_pred = torch.cat(y_pred, dim=0)
y_true = torch.cat(y_true, dim=0)
# save y_pred and y_true here if you need them
# torch.save(y_pred, "./test/PEMS07/y_pred_D.pt") # [3566,12,170,1]
# torch.save(y_true, "./test/PEMS08/y_true.pt") # [3566,12,170,1]
for t in range(y_true.shape[1]):
mae, rmse, mape = all_metrics(y_pred[:, t, ...], y_true[:, t, ...],
args['mae_thresh'], args['mape_thresh'])

View File

@@ -161,10 +161,6 @@ class Trainer:
y_pred = torch.cat(y_pred, dim=0)
y_true = torch.cat(y_true, dim=0)
# save y_pred and y_true here if you need them
# torch.save(y_pred, "./test/PEMS07/y_pred_D.pt") # [3566,12,170,1]
# torch.save(y_true, "./test/PEMS08/y_true.pt") # [3566,12,170,1]
for t in range(y_true.shape[1]):
mae, rmse, mape = all_metrics(y_pred[:, t, ...], y_true[:, t, ...],
args['mae_thresh'], args['mape_thresh'])

261
trainer/STMLP_Trainer.py Normal file
View File

@@ -0,0 +1,261 @@
import math
import os
import sys
import time
import copy
import torch.nn.functional as F
import torch
from torch import nn
from tqdm import tqdm
from lib.logger import get_logger
from lib.loss_function import all_metrics
from model.STMLP.STMLP import STMLP
class Trainer:
def __init__(self, model, loss, optimizer, train_loader, val_loader, test_loader,
scaler, args, lr_scheduler=None):
self.model = model
self.loss = loss
self.optimizer = optimizer
self.train_loader = train_loader
self.val_loader = val_loader
self.test_loader = test_loader
self.scaler = scaler
self.args = args['train']
self.lr_scheduler = lr_scheduler
self.train_per_epoch = len(train_loader)
self.val_per_epoch = len(val_loader) if val_loader else 0
# Paths for saving models and logs
self.best_path = os.path.join(self.args['log_dir'], 'best_model.pth')
self.best_test_path = os.path.join(self.args['log_dir'], 'best_test_model.pth')
self.loss_figure_path = os.path.join(self.args['log_dir'], 'loss.png')
self.pretrain_dir = f'./pre-train/{args["model"]["type"]}/{args["data"]["type"]}'
self.pretrain_path = os.path.join(self.pretrain_dir, 'best_model.pth')
self.pretrain_best_path = os.path.join(self.pretrain_dir, 'best_test_model.pth')
# Initialize logger
if not os.path.isdir(self.args['log_dir']) and not self.args['debug']:
os.makedirs(self.args['log_dir'], exist_ok=True)
if not os.path.isdir(self.pretrain_dir) and not self.args['debug']:
os.makedirs(self.pretrain_dir, exist_ok=True)
self.logger = get_logger(self.args['log_dir'], name=self.model.__class__.__name__, debug=self.args['debug'])
self.logger.info(f"Experiment log path in: {self.args['log_dir']}")
if self.args['teacher_stu']:
self.tmodel = self.loadTeacher(args)
else:
self.logger.info(f"当前使用预训练模式,预训练后请移动教师模型到"
f"./pre-train/{args['model']['type']}/{args['data']['type']}/best_model.pth"
f"然后在config中配置train.teacher_stu模式为True开启蒸馏模式")
def _run_epoch(self, epoch, dataloader, mode):
# self.tmodel.eval()
if mode == 'train':
self.model.train()
optimizer_step = True
else:
self.model.eval()
optimizer_step = False
total_loss = 0
epoch_time = time.time()
with torch.set_grad_enabled(optimizer_step):
with tqdm(total=len(dataloader), desc=f'{mode.capitalize()} Epoch {epoch}') as pbar:
for batch_idx, (data, target) in enumerate(dataloader):
if self.args['teacher_stu']:
label = target[..., :self.args['output_dim']]
output, out_, _ = self.model(data)
gout, tout, sout = self.tmodel(data)
if self.args['real_value']:
output = self.scaler.inverse_transform(output)
loss1 = self.loss(output, label)
scl = self.loss_cls(out_, sout)
kl_loss = nn.KLDivLoss(reduction="batchmean", log_target=True).cuda()
gout = F.log_softmax(gout, dim=-1).cuda()
mlp_emb_ = F.log_softmax(output, dim=-1).cuda()
tkloss = kl_loss(mlp_emb_.cuda().float(), gout.cuda().float())
loss = loss1 + 10 * tkloss + 1 * scl
else:
label = target[..., :self.args['output_dim']]
output, out_, _ = self.model(data)
if self.args['real_value']:
output = self.scaler.inverse_transform(output)
loss = self.loss(output, label)
if optimizer_step and self.optimizer is not None:
self.optimizer.zero_grad()
loss.backward()
if self.args['grad_norm']:
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm'])
self.optimizer.step()
total_loss += loss.item()
if mode == 'train' and (batch_idx + 1) % self.args['log_step'] == 0:
self.logger.info(
f'Train Epoch {epoch}: {batch_idx + 1}/{len(dataloader)} Loss: {loss.item():.6f}')
# update the tqdm progress bar
pbar.update(1)
pbar.set_postfix(loss=loss.item())
avg_loss = total_loss / len(dataloader)
self.logger.info(
f'{mode.capitalize()} Epoch {epoch}: average Loss: {avg_loss:.6f}, time: {time.time() - epoch_time:.2f} s')
return avg_loss
def train_epoch(self, epoch):
return self._run_epoch(epoch, self.train_loader, 'train')
def val_epoch(self, epoch):
return self._run_epoch(epoch, self.val_loader or self.test_loader, 'val')
def test_epoch(self, epoch):
return self._run_epoch(epoch, self.test_loader, 'test')
def train(self):
best_model, best_test_model = None, None
best_loss, best_test_loss = float('inf'), float('inf')
not_improved_count = 0
self.logger.info("Training process started")
for epoch in range(1, self.args['epochs'] + 1):
train_epoch_loss = self.train_epoch(epoch)
val_epoch_loss = self.val_epoch(epoch)
test_epoch_loss = self.test_epoch(epoch)
if train_epoch_loss > 1e6:
self.logger.warning('Gradient explosion detected. Ending...')
break
if val_epoch_loss < best_loss:
best_loss = val_epoch_loss
not_improved_count = 0
best_model = copy.deepcopy(self.model.state_dict())
torch.save(best_model, self.best_path)
torch.save(best_model, self.pretrain_path)
self.logger.info('Best validation model saved!')
else:
not_improved_count += 1
if self.args['early_stop'] and not_improved_count == self.args['early_stop_patience']:
self.logger.info(
f"Validation performance didn't improve for {self.args['early_stop_patience']} epochs. Training stops.")
break
if test_epoch_loss < best_test_loss:
best_test_loss = test_epoch_loss
best_test_model = copy.deepcopy(self.model.state_dict())
torch.save(best_test_model, self.best_test_path)
torch.save(best_model, self.pretrain_best_path)
if not self.args['debug']:
torch.save(best_model, self.best_path)
torch.save(best_test_model, self.best_test_path)
self.logger.info(f"Best models saved at {self.best_path} and {self.best_test_path}")
self._finalize_training(best_model, best_test_model)
def _finalize_training(self, best_model, best_test_model):
self.model.load_state_dict(best_model)
self.logger.info("Testing on best validation model")
self.test(self.model, self.args, self.test_loader, self.scaler, self.logger)
self.model.load_state_dict(best_test_model)
self.logger.info("Testing on best test model")
self.test(self.model, self.args, self.test_loader, self.scaler, self.logger)
def loadTeacher(self, args):
model_path = f'./pre-train/{args["model"]["type"]}/{args["data"]["type"]}/best_model.pth'
try:
# try to load the teacher model weights
state_dict = torch.load(model_path)
self.logger.info(f"Loaded teacher model weights from: {model_path}")
# initialize and return the teacher model
args['model']['model_type'] = 'teacher'
tmodel = STMLP(args['model'])
tmodel = tmodel.to(args['device'])
tmodel.load_state_dict(state_dict, strict=False)
return tmodel
except FileNotFoundError:
# if the weight file cannot be found, log it and adjust args
self.logger.error(
f"Teacher model weights not found: {model_path}. Switching to pre-training mode to train the teacher weights.\n"
f"After pre-training finishes, launching the model again will run in distillation mode")
self.args['teacher_stu'] = False
return None
def loss_cls(self, x1, x2):
temperature = 0.05
x1 = F.normalize(x1, p=2, dim=-1)
x2 = F.normalize(x2, p=2, dim=-1)
weight = F.cosine_similarity(x1, x2, dim=-1)
batch_size = x1.size()[0]
# neg score
out = torch.cat([x1, x2], dim=0)
neg = torch.exp(torch.matmul(out, out.transpose(2, 3).contiguous()) / temperature)
pos = torch.exp(torch.sum(x1 * x2, dim=-1) * weight / temperature)
# pos = torch.exp(torch.sum(x1 * x2, dim=-1) / temperature)
pos = torch.cat([pos, pos], dim=0).sum(dim=1)
Ng = neg.sum(dim=-1).sum(dim=1)
loss = (- torch.log(pos / (pos + Ng))).mean()
return loss
@staticmethod
def test(model, args, data_loader, scaler, logger, path=None):
if path:
checkpoint = torch.load(path)
model.load_state_dict(checkpoint['state_dict'])
model.to(args['device'])
model.eval()
y_pred, y_true = [], []
with torch.no_grad():
for data, target in data_loader:
label = target[..., :args['output_dim']]
output, _, _ = model(data)
y_pred.append(output)
y_true.append(label)
if args['real_value']:
y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0))
else:
y_pred = torch.cat(y_pred, dim=0)
y_true = torch.cat(y_true, dim=0)
# save y_pred and y_true here if you need them
# torch.save(y_pred, "./test/PEMS07/y_pred_D.pt") # [3566,12,170,1]
# torch.save(y_true, "./test/PEMSD8/y_true.pt") # [3566,12,170,1]
for t in range(y_true.shape[1]):
mae, rmse, mape = all_metrics(y_pred[:, t, ...], y_true[:, t, ...],
args['mae_thresh'], args['mape_thresh'])
logger.info(f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}")
mae, rmse, mape = all_metrics(y_pred, y_true, args['mae_thresh'], args['mape_thresh'])
logger.info(f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}")
@staticmethod
def _compute_sampling_threshold(global_step, k):
return k / (k + math.exp(global_step / k))
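
The pre-training versus distillation switch that loadTeacher implements with try/except boils down to whether the teacher checkpoint exists; a standalone sketch (the PEMSD4 path below is only an example of the ./pre-train/<model type>/<dataset type>/best_model.pth layout):

import os
import torch

teacher_ckpt = './pre-train/STMLP/PEMSD4/best_model.pth'

if os.path.isfile(teacher_ckpt):
    # distillation mode: train the student against the loaded teacher
    teacher_state = torch.load(teacher_ckpt, map_location='cpu')
    teacher_stu = True
else:
    # pre-training mode: train the teacher first, then place its checkpoint at the path above
    teacher_state = None
    teacher_stu = False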

View File

@@ -2,6 +2,7 @@ from trainer.Trainer import Trainer
from trainer.cdeTrainer.cdetrainer import Trainer as cdeTrainer
from trainer.DCRNN_Trainer import Trainer as DCRNN_Trainer
from trainer.PDG2SEQ_Trainer import Trainer as PDG2SEQ_Trainer
from trainer.STMLP_Trainer import Trainer as STMLP_Trainer
from trainer.E32Trainer import Trainer as EXP_Trainer
@@ -14,6 +15,8 @@ def select_trainer(model, loss, optimizer, train_loader, val_loader, test_loader
lr_scheduler)
case 'PDG2SEQ': return PDG2SEQ_Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args['train'],
lr_scheduler)
case 'STMLP': return STMLP_Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args,
lr_scheduler)
case 'EXP': return EXP_Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args['train'],
lr_scheduler)
case _: return Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args['train'],

View File

@@ -299,7 +299,7 @@ def read_data(args):
'pems03': ['PEMS03/pems03.npz', 'PEMS03/distance.csv'],
'pems04': ['PEMS04/pems04.npz', 'PEMS04/distance.csv'],
'pems07': ['PEMS07/pems07.npz', 'PEMS07/distance.csv'],
'pems08': ['PEMS08/pems08.npz', 'PEMS08/distance.csv'],
'pems08': ['PEMSD8/pems08.npz', 'PEMSD8/distance.csv'],
'pemsbay': ['PEMSBAY/pems_bay.npz', 'PEMSBAY/distance.csv'],
'pemsD7M': ['PeMSD7M/PeMSD7M.npz', 'PeMSD7M/distance.csv'],
'pemsD7L': ['PeMSD7L/PeMSD7L.npz', 'PeMSD7L/distance.csv']