Added the STMLP model with distillation support

The Trainer now saves a model checkpoint after every epoch.
STMLP automatically saves the teacher model under pre-train/.
Depending on whether a teacher model already exists, training starts in pre-training or distillation mode.
This commit is contained in:
czzhangheng 2025-04-07 17:05:59 +08:00
parent 229b6320b9
commit bc9a2667c2
15 changed files with 844 additions and 18 deletions
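The mode switch described in the message comes down to whether a pre-trained teacher checkpoint is present. A minimal sketch of that decision, mirroring the check done in STMLP_Trainer.loadTeacher further down; the standalone helper itself is illustrative and not part of the repository:

import os

def resolve_training_mode(args):
    """Illustrative helper: choose distillation when teacher weights exist, otherwise pre-train."""
    teacher_ckpt = f"./pre-train/{args['model']['type']}/{args['data']['type']}/best_model.pth"
    if args['train'].get('teacher_stu') and os.path.isfile(teacher_ckpt):
        return 'distill'
    # No teacher checkpoint yet: pre-train the teacher, then re-run with train.teacher_stu: True
    return 'pretrain'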

1
.gitignore vendored
View File

@@ -7,6 +7,7 @@ experiments/
 *.pkl
 data/
 pretrain/
+pre-train/
 # ---> Python
 # Byte-compiled / optimized / DLL files

66
config/STMLP/PEMSD3.yaml Normal file
View File

@@ -0,0 +1,66 @@
data:
  num_nodes: 358
  lag: 12
  horizon: 12
  val_ratio: 0.2
  test_ratio: 0.2
  tod: False
  normalizer: std
  column_wise: False
  default_graph: True
  add_time_in_day: True
  add_day_in_week: True
  steps_per_day: 288
  days_per_week: 7
model:
  input_dim: 1
  output_dim: 1
  input_window: 12
  output_window: 12
  gcn_true: true
  buildA_true: true
  gcn_depth: 2
  dropout: 0.3
  subgraph_size: 20
  node_dim: 40
  dilation_exponential: 1
  conv_channels: 32
  residual_channels: 32
  skip_channels: 64
  end_channels: 128
  layers: 3
  propalpha: 0.05
  tanhalpha: 3
  layer_norm_affline: true
  use_curriculum_learning: true
  step_size1: 2500
  task_level: 0
  num_split: 1
  step_size2: 100
  model_type: stmlp
train:
  loss_func: mae
  seed: 10
  batch_size: 64
  epochs: 300
  lr_init: 0.003
  weight_decay: 0
  lr_decay: False
  lr_decay_rate: 0.3
  lr_decay_step: "5,20,40,70"
  early_stop: True
  early_stop_patience: 15
  grad_norm: False
  max_grad_norm: 5
  real_value: True
  teacher_stu: True
test:
  mae_thresh: null
  mape_thresh: 0.0
log:
  log_step: 2000
  plot: False

67
config/STMLP/PEMSD4.yaml Normal file
View File

@@ -0,0 +1,67 @@
data:
  num_nodes: 307
  lag: 12
  horizon: 12
  val_ratio: 0.2
  test_ratio: 0.2
  tod: False
  normalizer: std
  column_wise: False
  default_graph: True
  add_time_in_day: True
  add_day_in_week: True
  steps_per_day: 288
  days_per_week: 7
model:
  input_dim: 1
  output_dim: 1
  input_window: 12
  output_window: 12
  gcn_true: true
  buildA_true: true
  gcn_depth: 2
  dropout: 0.3
  subgraph_size: 20
  node_dim: 40
  dilation_exponential: 1
  conv_channels: 32
  residual_channels: 32
  skip_channels: 64
  end_channels: 128
  layers: 3
  propalpha: 0.05
  tanhalpha: 3
  layer_norm_affline: true
  use_curriculum_learning: true
  step_size1: 2500
  task_level: 0
  num_split: 1
  step_size2: 100
  model_type: stmlp
train:
  loss_func: mae
  seed: 10
  batch_size: 64
  epochs: 300
  lr_init: 0.003
  weight_decay: 0
  lr_decay: False
  lr_decay_rate: 0.3
  lr_decay_step: "5,20,40,70"
  early_stop: True
  early_stop_patience: 15
  grad_norm: False
  max_grad_norm: 5
  real_value: True
  teacher: True
  teacher_stu: True
test:
  mae_thresh: null
  mape_thresh: 0.0
log:
  log_step: 2000
  plot: False

66
config/STMLP/PEMSD7.yaml Normal file
View File

@@ -0,0 +1,66 @@
data:
  num_nodes: 883
  lag: 12
  horizon: 12
  val_ratio: 0.2
  test_ratio: 0.2
  tod: False
  normalizer: std
  column_wise: False
  default_graph: True
  add_time_in_day: True
  add_day_in_week: True
  steps_per_day: 288
  days_per_week: 7
model:
  input_dim: 1
  output_dim: 1
  input_window: 12
  output_window: 12
  gcn_true: true
  buildA_true: true
  gcn_depth: 2
  dropout: 0.3
  subgraph_size: 20
  node_dim: 40
  dilation_exponential: 1
  conv_channels: 32
  residual_channels: 32
  skip_channels: 64
  end_channels: 128
  layers: 3
  propalpha: 0.05
  tanhalpha: 3
  layer_norm_affline: true
  use_curriculum_learning: true
  step_size1: 2500
  task_level: 0
  num_split: 1
  step_size2: 100
  model_type: stmlp
train:
  loss_func: mae
  seed: 10
  batch_size: 16
  epochs: 300
  lr_init: 0.003
  weight_decay: 0
  lr_decay: False
  lr_decay_rate: 0.3
  lr_decay_step: "5,20,40,70"
  early_stop: True
  early_stop_patience: 15
  grad_norm: False
  max_grad_norm: 5
  real_value: True
  teacher_stu: True
test:
  mae_thresh: null
  mape_thresh: 0.0
log:
  log_step: 2000
  plot: False

66
config/STMLP/PEMSD8.yaml Normal file
View File

@@ -0,0 +1,66 @@
data:
  num_nodes: 170
  lag: 12
  horizon: 12
  val_ratio: 0.2
  test_ratio: 0.2
  tod: False
  normalizer: std
  column_wise: False
  default_graph: True
  add_time_in_day: True
  add_day_in_week: True
  steps_per_day: 288
  days_per_week: 7
model:
  input_dim: 1
  output_dim: 1
  input_window: 12
  output_window: 12
  gcn_true: true
  buildA_true: true
  gcn_depth: 2
  dropout: 0.3
  subgraph_size: 20
  node_dim: 40
  dilation_exponential: 1
  conv_channels: 32
  residual_channels: 32
  skip_channels: 64
  end_channels: 128
  layers: 3
  propalpha: 0.05
  tanhalpha: 3
  layer_norm_affline: true
  use_curriculum_learning: true
  step_size1: 2500
  task_level: 0
  num_split: 1
  step_size2: 100
  model_type: stmlp
train:
  loss_func: mae
  seed: 10
  batch_size: 64
  epochs: 300
  lr_init: 0.003
  weight_decay: 0
  lr_decay: False
  lr_decay_rate: 0.3
  lr_decay_step: "5,20,40,70"
  early_stop: True
  early_stop_patience: 15
  grad_norm: False
  max_grad_norm: 5
  real_value: True
  teacher_stu: True
test:
  mae_thresh: null
  mape_thresh: 0.0
log:
  log_step: 2000
  plot: False
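All four configs share this layout. A minimal sketch of loading one and reading the distillation switch; run.py already imports yaml, so PyYAML is assumed available, and the snippet itself is only illustrative:

import yaml

# Load one of the STMLP configs added above and inspect the distillation switch.
with open('config/STMLP/PEMSD8.yaml') as f:
    cfg = yaml.safe_load(f)

print(cfg['model']['model_type'])   # 'stmlp'
print(cfg['train']['teacher_stu'])  # True -> distillation mode, False -> teacher pre-training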

View File

@@ -121,7 +121,7 @@ def download_kaggle_data(current_dir):
     If the destination folder already exists, conflicting files will be overwritten
     """
     try:
-        print("Downloading the KaggleHub dataset...")
+        print("Downloading the PEMS dataset...")
         path = kagglehub.dataset_download("elmahy/pems-dataset")
         # print("Path to KaggleHub dataset files:", path)

307
model/STMLP/STMLP.py Normal file
View File

@@ -0,0 +1,307 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
from data.get_adj import get_adj
import numbers


# --- Basic operators ---
class NConv(nn.Module):
    def forward(self, x, adj):
        return torch.einsum('ncwl,vw->ncvl', (x, adj)).contiguous()


class DyNconv(nn.Module):
    def forward(self, x, adj):
        return torch.einsum('ncvl,nvwl->ncwl', (x, adj)).contiguous()


class Linear(nn.Module):
    def __init__(self, c_in, c_out, bias=True):
        super().__init__()
        self.mlp = nn.Conv2d(c_in, c_out, kernel_size=1, bias=bias)

    def forward(self, x):
        return self.mlp(x)


class Prop(nn.Module):
    def __init__(self, c_in, c_out, gdep, dropout, alpha):
        super().__init__()
        self.nconv = NConv()
        self.mlp = Linear(c_in, c_out)
        self.gdep, self.dropout, self.alpha = gdep, dropout, alpha

    def forward(self, x, adj):
        adj = adj + torch.eye(adj.size(0), device=x.device)
        d = adj.sum(1)
        a = adj / d.view(-1, 1)
        h = x
        for _ in range(self.gdep):
            h = self.alpha * x + (1 - self.alpha) * self.nconv(h, a)
        return self.mlp(h)


class MixProp(nn.Module):
    def __init__(self, c_in, c_out, gdep, dropout, alpha):
        super().__init__()
        self.nconv = NConv()
        self.mlp = Linear((gdep + 1) * c_in, c_out)
        self.gdep, self.dropout, self.alpha = gdep, dropout, alpha

    def forward(self, x, adj):
        adj = adj + torch.eye(adj.size(0), device=x.device)
        d = adj.sum(1)
        a = adj / d.view(-1, 1)
        out = [x]
        h = x
        for _ in range(self.gdep):
            h = self.alpha * x + (1 - self.alpha) * self.nconv(h, a)
            out.append(h)
        return self.mlp(torch.cat(out, dim=1))


class DyMixprop(nn.Module):
    def __init__(self, c_in, c_out, gdep, dropout, alpha):
        super().__init__()
        self.nconv = DyNconv()
        self.mlp1 = Linear((gdep + 1) * c_in, c_out)
        self.mlp2 = Linear((gdep + 1) * c_in, c_out)
        self.gdep, self.dropout, self.alpha = gdep, dropout, alpha
        self.lin1, self.lin2 = Linear(c_in, c_in), Linear(c_in, c_in)

    def forward(self, x):
        x1 = torch.tanh(self.lin1(x))
        x2 = torch.tanh(self.lin2(x))
        adj = self.nconv(x1.transpose(2, 1), x2)
        adj0 = torch.softmax(adj, dim=2)
        adj1 = torch.softmax(adj.transpose(2, 1), dim=2)
        # two propagation branches
        out1, out2 = [x], [x]
        h = x
        for _ in range(self.gdep):
            h = self.alpha * x + (1 - self.alpha) * self.nconv(h, adj0)
            out1.append(h)
        h = x
        for _ in range(self.gdep):
            h = self.alpha * x + (1 - self.alpha) * self.nconv(h, adj1)
            out2.append(h)
        return self.mlp1(torch.cat(out1, dim=1)) + self.mlp2(torch.cat(out2, dim=1))


class DilatedInception(nn.Module):
    def __init__(self, cin, cout, dilation_factor=2):
        super().__init__()
        self.kernels = [2, 3, 6, 7]
        cout_each = int(cout / len(self.kernels))
        self.convs = nn.ModuleList([nn.Conv2d(cin, cout_each, kernel_size=(1, k), dilation=(1, dilation_factor))
                                    for k in self.kernels])

    def forward(self, x):
        outs = [conv(x)[..., -self.convs[-1](x).size(3):] for conv in self.convs]
        return torch.cat(outs, dim=1)


class GraphConstructor(nn.Module):
    def __init__(self, nnodes, k, dim, device, alpha=3, static_feat=None):
        super().__init__()
        self.nnodes, self.k, self.dim, self.alpha, self.device = nnodes, k, dim, alpha, device
        self.static_feat = static_feat
        if static_feat is not None:
            xd = static_feat.shape[1]
            self.lin1, self.lin2 = nn.Linear(xd, dim), nn.Linear(xd, dim)
        else:
            self.emb1 = nn.Embedding(nnodes, dim)
            self.emb2 = nn.Embedding(nnodes, dim)
            self.lin1, self.lin2 = nn.Linear(dim, dim), nn.Linear(dim, dim)

    def forward(self, idx):
        if self.static_feat is None:
            vec1, vec2 = self.emb1(idx), self.emb2(idx)
        else:
            vec1 = vec2 = self.static_feat[idx, :]
        vec1 = torch.tanh(self.alpha * self.lin1(vec1))
        vec2 = torch.tanh(self.alpha * self.lin2(vec2))
        a = torch.mm(vec1, vec2.transpose(1, 0)) - torch.mm(vec2, vec1.transpose(1, 0))
        adj = F.relu(torch.tanh(self.alpha * a))
        mask = torch.zeros(idx.size(0), idx.size(0), device=self.device)
        s1, t1 = adj.topk(self.k, 1)
        mask.scatter_(1, t1, s1.new_ones(s1.size()))
        return adj * mask


class LayerNorm(nn.Module):
    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine']

    def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
        super().__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        self.normalized_shape, self.eps, self.elementwise_affine = tuple(normalized_shape), eps, elementwise_affine
        if elementwise_affine:
            self.weight = nn.Parameter(torch.Tensor(*normalized_shape))
            self.bias = nn.Parameter(torch.Tensor(*normalized_shape))
            init.ones_(self.weight)
            init.zeros_(self.bias)
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)

    def forward(self, x, idx):
        if self.elementwise_affine:
            return F.layer_norm(x, tuple(x.shape[1:]), self.weight[:, idx, :], self.bias[:, idx, :], self.eps)
        else:
            return F.layer_norm(x, tuple(x.shape[1:]), self.weight, self.bias, self.eps)

    def extra_repr(self):
        return f'{self.normalized_shape}, eps={self.eps}, elementwise_affine={self.elementwise_affine}'


# --- Merged model class supporting both the teacher and stmlp branches ---
class STMLP(nn.Module):
    def __init__(self, args):
        super().__init__()
        # read hyper-parameters from the config dict
        self.adj_mx = get_adj(args)
        self.num_nodes = args['num_nodes']
        self.feature_dim = args['input_dim']
        self.input_window = args['input_window']
        self.output_window = args['output_window']
        self.output_dim = args['output_dim']
        self.device = args['device']
        self.gcn_true = args['gcn_true']
        self.buildA_true = args['buildA_true']
        self.gcn_depth = args['gcn_depth']
        self.dropout = args['dropout']
        self.subgraph_size = args['subgraph_size']
        self.node_dim = args['node_dim']
        self.dilation_exponential = args['dilation_exponential']
        self.conv_channels = args['conv_channels']
        self.residual_channels = args['residual_channels']
        self.skip_channels = args['skip_channels']
        self.end_channels = args['end_channels']
        self.layers = args['layers']
        self.propalpha = args['propalpha']
        self.tanhalpha = args['tanhalpha']
        self.layer_norm_affline = args['layer_norm_affline']
        self.model_type = args['model_type']  # 'teacher' or 'stmlp'
        self.idx = torch.arange(self.num_nodes).to(self.device)
        self.predefined_A = None if self.adj_mx is None else (torch.tensor(self.adj_mx) - torch.eye(self.num_nodes)).to(
            self.device)
        self.static_feat = None
        # Transformer encoder kept from the original structure
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=12, nhead=4, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=3)
        # build the layers
        self.start_conv = nn.Conv2d(self.feature_dim, self.residual_channels, kernel_size=1)
        self.gc = GraphConstructor(self.num_nodes, self.subgraph_size, self.node_dim, self.device, alpha=self.tanhalpha,
                                   static_feat=self.static_feat)
        # compute the receptive field
        kernel_size = 7
        if self.dilation_exponential > 1:
            self.receptive_field = int(
                self.output_dim + (kernel_size - 1) * (self.dilation_exponential ** self.layers - 1) / (
                        self.dilation_exponential - 1))
        else:
            self.receptive_field = self.layers * (kernel_size - 1) + self.output_dim
        self.filter_convs = nn.ModuleList()
        self.gate_convs = nn.ModuleList()
        self.residual_convs = nn.ModuleList()
        self.skip_convs = nn.ModuleList()
        self.norm = nn.ModuleList()
        self.stu_mlp = nn.ModuleList([nn.Sequential(nn.Linear(c, c), nn.Linear(c, c), nn.Linear(c, c))
                                      for c in [13, 7, 1]])
        if self.gcn_true:
            self.gconv1 = nn.ModuleList()
            self.gconv2 = nn.ModuleList()
        new_dilation = 1
        for i in range(1):
            rf_size_i = int(1 + i * (kernel_size - 1) * (self.dilation_exponential ** self.layers - 1) / (
                    self.dilation_exponential - 1)) if self.dilation_exponential > 1 else i * self.layers * (
                    kernel_size - 1) + 1
            for j in range(1, self.layers + 1):
                rf_size_j = int(rf_size_i + (kernel_size - 1) * (self.dilation_exponential ** j - 1) / (
                        self.dilation_exponential - 1)) if self.dilation_exponential > 1 else rf_size_i + j * (
                        kernel_size - 1)
                self.filter_convs.append(
                    DilatedInception(self.residual_channels, self.conv_channels, dilation_factor=new_dilation))
                self.gate_convs.append(
                    DilatedInception(self.residual_channels, self.conv_channels, dilation_factor=new_dilation))
                self.residual_convs.append(nn.Conv2d(self.conv_channels, self.residual_channels, kernel_size=1))
                k_size = (1, self.input_window - rf_size_j + 1) if self.input_window > self.receptive_field else (
                    1, self.receptive_field - rf_size_j + 1)
                self.skip_convs.append(nn.Conv2d(self.conv_channels, self.skip_channels, kernel_size=k_size))
                if self.gcn_true:
                    self.gconv1.append(MixProp(self.conv_channels, self.residual_channels, self.gcn_depth, self.dropout,
                                               self.propalpha))
                    self.gconv2.append(MixProp(self.conv_channels, self.residual_channels, self.gcn_depth, self.dropout,
                                               self.propalpha))
                norm_size = (self.residual_channels, self.num_nodes,
                             self.input_window - rf_size_j + 1) if self.input_window > self.receptive_field else (
                    self.residual_channels, self.num_nodes, self.receptive_field - rf_size_j + 1)
                self.norm.append(LayerNorm(norm_size, elementwise_affine=self.layer_norm_affline))
                new_dilation *= self.dilation_exponential
        self.end_conv_1 = nn.Conv2d(self.skip_channels, self.end_channels, kernel_size=1, bias=True)
        self.end_conv_2 = nn.Conv2d(self.end_channels, self.output_window, kernel_size=1, bias=True)
        k0 = (1, self.input_window) if self.input_window > self.receptive_field else (1, self.receptive_field)
        self.skip0 = nn.Conv2d(self.feature_dim, self.skip_channels, kernel_size=k0, bias=True)
        kE = (1, self.input_window - self.receptive_field + 1) if self.input_window > self.receptive_field else (1, 1)
        self.skipE = nn.Conv2d(self.residual_channels, self.skip_channels, kernel_size=kE, bias=True)
        # output heads: a different head is used depending on the model type
        if self.model_type == 'teacher':
            self.tt_linear1 = nn.Linear(self.residual_channels, self.input_window)
            self.tt_linear2 = nn.Linear(1, 32)
            self.ss_linear1 = nn.Linear(self.residual_channels, self.input_window)
            self.ss_linear2 = nn.Linear(1, 32)
        else:  # stmlp
            self.out_linear1 = nn.Linear(self.residual_channels, self.input_window)
            self.out_linear2 = nn.Linear(1, 32)

    def forward(self, source, idx=None):
        source = source[..., 0:1]
        sout, tout = [], []
        inputs = source.transpose(1, 3)
        assert inputs.size(3) == self.input_window, 'input sequence length mismatch'
        if self.input_window < self.receptive_field:
            inputs = F.pad(inputs, (self.receptive_field - self.input_window, 0, 0, 0))
        if self.gcn_true:
            adp = self.gc(self.idx if idx is None else idx) if self.buildA_true else self.predefined_A
        x = self.start_conv(inputs)
        skip = self.skip0(F.dropout(inputs, self.dropout, training=self.training))
        for i in range(self.layers):
            residual = x
            filters = torch.tanh(self.filter_convs[i](x))
            gate = torch.sigmoid(self.gate_convs[i](x))
            x = F.dropout(filters * gate, self.dropout, training=self.training)
            tout.append(x)
            s = self.skip_convs[i](x)
            skip = s + skip
            if self.gcn_true:
                x = self.gconv1[i](x, adp) + self.gconv2[i](x, adp.transpose(1, 0))
            else:
                x = self.stu_mlp[i](x)
            x = x + residual[:, :, :, -x.size(3):]
            x = self.norm[i](x, self.idx if idx is None else idx)
            sout.append(x)
        skip = self.skipE(x) + skip
        x = F.relu(skip)
        x = F.relu(self.end_conv_1(x))
        x = self.end_conv_2(x)
        if self.model_type == 'teacher':
            ttout = self.tt_linear2(self.tt_linear1(tout[-1].transpose(1, 3)).transpose(1, 3))
            ssout = self.ss_linear2(self.ss_linear1(sout[-1].transpose(1, 3)).transpose(1, 3))
            return x, ttout, ssout
        else:
            x_ = self.out_linear2(self.out_linear1(tout[-1].transpose(1, 3)).transpose(1, 3))
            return x, x_, x
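A small smoke test for the DilatedInception block defined above. It assumes the repository layout so that model.STMLP.STMLP (and its data.get_adj dependency) is importable; the tensor sizes are illustrative only:

import torch
from model.STMLP.STMLP import DilatedInception

block = DilatedInception(cin=32, cout=32, dilation_factor=1)
x = torch.randn(8, 32, 170, 19)  # (batch, channels, nodes, time steps)
y = block(x)
# The largest kernel is 7, so the time dimension shrinks by 6: torch.Size([8, 32, 170, 13])
print(y.shape)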

View File

@@ -13,8 +13,7 @@ from model.STFGNN.STFGNN import STFGNN
 from model.STSGCN.STSGCN import STSGCN
 from model.STGODE.STGODE import ODEGCN
 from model.PDG2SEQ.PDG2Seq import PDG2Seq
-from model.EXP.EXP import EXP
-from model.EXPB.EXP_b import EXPB
+from model.STMLP.STMLP import STMLP

 def model_selector(model):
@@ -33,6 +32,5 @@ def model_selector(model):
         case 'STSGCN': return STSGCN(model)
         case 'STGODE': return ODEGCN(model)
         case 'PDG2SEQ': return PDG2Seq(model)
-        case 'EXP': return EXP(model)
-        case 'EXPB': return EXPB(model)
+        case 'STMLP': return STMLP(model)

3
run.py
View File

@@ -17,9 +17,6 @@ from dataloader.loader_selector import get_dataloader
 from trainer.trainer_selector import select_trainer
 import yaml

 def main():
     args = parse_args()

View File

@@ -160,10 +160,6 @@ class Trainer:
         y_pred = torch.cat(y_pred, dim=0)
         y_true = torch.cat(y_true, dim=0)
-        # Save y_pred and y_true here if you need them
-        # torch.save(y_pred, "./test/PEMS07/y_pred_D.pt")  # [3566,12,170,1]
-        # torch.save(y_true, "./test/PEMS08/y_true.pt")  # [3566,12,170,1]
         for t in range(y_true.shape[1]):
             mae, rmse, mape = all_metrics(y_pred[:, t, ...], y_true[:, t, ...],
                                           args['mae_thresh'], args['mape_thresh'])

View File

@@ -161,10 +161,6 @@ class Trainer:
         y_pred = torch.cat(y_pred, dim=0)
         y_true = torch.cat(y_true, dim=0)
-        # Save y_pred and y_true here if you need them
-        # torch.save(y_pred, "./test/PEMS07/y_pred_D.pt")  # [3566,12,170,1]
-        # torch.save(y_true, "./test/PEMS08/y_true.pt")  # [3566,12,170,1]
         for t in range(y_true.shape[1]):
             mae, rmse, mape = all_metrics(y_pred[:, t, ...], y_true[:, t, ...],
                                           args['mae_thresh'], args['mape_thresh'])

261
trainer/STMLP_Trainer.py Normal file
View File

@@ -0,0 +1,261 @@
import math
import os
import sys
import time
import copy
import torch.nn.functional as F
import torch
from torch import nn
from tqdm import tqdm
from lib.logger import get_logger
from lib.loss_function import all_metrics
from model.STMLP.STMLP import STMLP


class Trainer:
    def __init__(self, model, loss, optimizer, train_loader, val_loader, test_loader,
                 scaler, args, lr_scheduler=None):
        self.model = model
        self.loss = loss
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.scaler = scaler
        self.args = args['train']
        self.lr_scheduler = lr_scheduler
        self.train_per_epoch = len(train_loader)
        self.val_per_epoch = len(val_loader) if val_loader else 0
        # Paths for saving models and logs
        self.best_path = os.path.join(self.args['log_dir'], 'best_model.pth')
        self.best_test_path = os.path.join(self.args['log_dir'], 'best_test_model.pth')
        self.loss_figure_path = os.path.join(self.args['log_dir'], 'loss.png')
        self.pretrain_dir = f'./pre-train/{args["model"]["type"]}/{args["data"]["type"]}'
        self.pretrain_path = os.path.join(self.pretrain_dir, 'best_model.pth')
        self.pretrain_best_path = os.path.join(self.pretrain_dir, 'best_test_model.pth')
        # Initialize logger
        if not os.path.isdir(self.args['log_dir']) and not self.args['debug']:
            os.makedirs(self.args['log_dir'], exist_ok=True)
        if not os.path.isdir(self.pretrain_dir) and not self.args['debug']:
            os.makedirs(self.pretrain_dir, exist_ok=True)
        self.logger = get_logger(self.args['log_dir'], name=self.model.__class__.__name__, debug=self.args['debug'])
        self.logger.info(f"Experiment log path in: {self.args['log_dir']}")
        if self.args['teacher_stu']:
            self.tmodel = self.loadTeacher(args)
        else:
            self.logger.info(f"Running in pre-training mode. After pre-training, move the teacher model to "
                             f"./pre-train/{args['model']['type']}/{args['data']['type']}/best_model.pth "
                             f"and set train.teacher_stu to True in the config to enable distillation mode")

    def _run_epoch(self, epoch, dataloader, mode):
        # self.tmodel.eval()
        if mode == 'train':
            self.model.train()
            optimizer_step = True
        else:
            self.model.eval()
            optimizer_step = False
        total_loss = 0
        epoch_time = time.time()
        with torch.set_grad_enabled(optimizer_step):
            with tqdm(total=len(dataloader), desc=f'{mode.capitalize()} Epoch {epoch}') as pbar:
                for batch_idx, (data, target) in enumerate(dataloader):
                    if self.args['teacher_stu']:
                        label = target[..., :self.args['output_dim']]
                        output, out_, _ = self.model(data)
                        gout, tout, sout = self.tmodel(data)
                        if self.args['real_value']:
                            output = self.scaler.inverse_transform(output)
                        loss1 = self.loss(output, label)
                        scl = self.loss_cls(out_, sout)
                        kl_loss = nn.KLDivLoss(reduction="batchmean", log_target=True).cuda()
                        gout = F.log_softmax(gout, dim=-1).cuda()
                        mlp_emb_ = F.log_softmax(output, dim=-1).cuda()
                        tkloss = kl_loss(mlp_emb_.cuda().float(), gout.cuda().float())
                        loss = loss1 + 10 * tkloss + 1 * scl
                    else:
                        label = target[..., :self.args['output_dim']]
                        output, out_, _ = self.model(data)
                        if self.args['real_value']:
                            output = self.scaler.inverse_transform(output)
                        loss = self.loss(output, label)
                    if optimizer_step and self.optimizer is not None:
                        self.optimizer.zero_grad()
                        loss.backward()
                        if self.args['grad_norm']:
                            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm'])
                        self.optimizer.step()
                    total_loss += loss.item()
                    if mode == 'train' and (batch_idx + 1) % self.args['log_step'] == 0:
                        self.logger.info(
                            f'Train Epoch {epoch}: {batch_idx + 1}/{len(dataloader)} Loss: {loss.item():.6f}')
                    # update the tqdm progress bar
                    pbar.update(1)
                    pbar.set_postfix(loss=loss.item())
        avg_loss = total_loss / len(dataloader)
        self.logger.info(
            f'{mode.capitalize()} Epoch {epoch}: average Loss: {avg_loss:.6f}, time: {time.time() - epoch_time:.2f} s')
        return avg_loss

    def train_epoch(self, epoch):
        return self._run_epoch(epoch, self.train_loader, 'train')

    def val_epoch(self, epoch):
        return self._run_epoch(epoch, self.val_loader or self.test_loader, 'val')

    def test_epoch(self, epoch):
        return self._run_epoch(epoch, self.test_loader, 'test')

    def train(self):
        best_model, best_test_model = None, None
        best_loss, best_test_loss = float('inf'), float('inf')
        not_improved_count = 0
        self.logger.info("Training process started")
        for epoch in range(1, self.args['epochs'] + 1):
            train_epoch_loss = self.train_epoch(epoch)
            val_epoch_loss = self.val_epoch(epoch)
            test_epoch_loss = self.test_epoch(epoch)
            if train_epoch_loss > 1e6:
                self.logger.warning('Gradient explosion detected. Ending...')
                break
            if val_epoch_loss < best_loss:
                best_loss = val_epoch_loss
                not_improved_count = 0
                best_model = copy.deepcopy(self.model.state_dict())
                torch.save(best_model, self.best_path)
                torch.save(best_model, self.pretrain_path)
                self.logger.info('Best validation model saved!')
            else:
                not_improved_count += 1
            if self.args['early_stop'] and not_improved_count == self.args['early_stop_patience']:
                self.logger.info(
                    f"Validation performance didn't improve for {self.args['early_stop_patience']} epochs. Training stops.")
                break
            if test_epoch_loss < best_test_loss:
                best_test_loss = test_epoch_loss
                best_test_model = copy.deepcopy(self.model.state_dict())
                torch.save(best_test_model, self.best_test_path)
                torch.save(best_model, self.pretrain_best_path)
        if not self.args['debug']:
            torch.save(best_model, self.best_path)
            torch.save(best_test_model, self.best_test_path)
            self.logger.info(f"Best models saved at {self.best_path} and {self.best_test_path}")
        self._finalize_training(best_model, best_test_model)

    def _finalize_training(self, best_model, best_test_model):
        self.model.load_state_dict(best_model)
        self.logger.info("Testing on best validation model")
        self.test(self.model, self.args, self.test_loader, self.scaler, self.logger)
        self.model.load_state_dict(best_test_model)
        self.logger.info("Testing on best test model")
        self.test(self.model, self.args, self.test_loader, self.scaler, self.logger)

    def loadTeacher(self, args):
        model_path = f'./pre-train/{args["model"]["type"]}/{args["data"]["type"]}/best_model.pth'
        try:
            # try to load the teacher model weights
            state_dict = torch.load(model_path)
            self.logger.info(f"Successfully loaded teacher model weights: {model_path}")
            # initialize and return the teacher model
            args['model']['model_type'] = 'teacher'
            tmodel = STMLP(args['model'])
            tmodel = tmodel.to(args['device'])
            tmodel.load_state_dict(state_dict, strict=False)
            return tmodel
        except FileNotFoundError:
            # if the weight file is missing, log it and update the args
            self.logger.error(
                f"Teacher model weights not found: {model_path}. Switching to pre-training mode to train the teacher weights.\n"
                f"Once pre-training has finished, launching the model again will run in distillation mode")
            self.args['teacher_stu'] = False
            return None

    def loss_cls(self, x1, x2):
        temperature = 0.05
        x1 = F.normalize(x1, p=2, dim=-1)
        x2 = F.normalize(x2, p=2, dim=-1)
        weight = F.cosine_similarity(x1, x2, dim=-1)
        batch_size = x1.size()[0]
        # neg score
        out = torch.cat([x1, x2], dim=0)
        neg = torch.exp(torch.matmul(out, out.transpose(2, 3).contiguous()) / temperature)
        pos = torch.exp(torch.sum(x1 * x2, dim=-1) * weight / temperature)
        # pos = torch.exp(torch.sum(x1 * x2, dim=-1) / temperature)
        pos = torch.cat([pos, pos], dim=0).sum(dim=1)
        Ng = neg.sum(dim=-1).sum(dim=1)
        loss = (- torch.log(pos / (pos + Ng))).mean()
        return loss

    @staticmethod
    def test(model, args, data_loader, scaler, logger, path=None):
        if path:
            checkpoint = torch.load(path)
            model.load_state_dict(checkpoint['state_dict'])
            model.to(args['device'])
        model.eval()
        y_pred, y_true = [], []
        with torch.no_grad():
            for data, target in data_loader:
                label = target[..., :args['output_dim']]
                output, _, _ = model(data)
                y_pred.append(output)
                y_true.append(label)
        if args['real_value']:
            y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0))
        else:
            y_pred = torch.cat(y_pred, dim=0)
        y_true = torch.cat(y_true, dim=0)
        # Save y_pred and y_true here if you need them
        # torch.save(y_pred, "./test/PEMS07/y_pred_D.pt")  # [3566,12,170,1]
        # torch.save(y_true, "./test/PEMSD8/y_true.pt")  # [3566,12,170,1]
        for t in range(y_true.shape[1]):
            mae, rmse, mape = all_metrics(y_pred[:, t, ...], y_true[:, t, ...],
                                          args['mae_thresh'], args['mape_thresh'])
            logger.info(f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}")
        mae, rmse, mape = all_metrics(y_pred, y_true, args['mae_thresh'], args['mape_thresh'])
        logger.info(f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}")

    @staticmethod
    def _compute_sampling_threshold(global_step, k):
        return k / (k + math.exp(global_step / k))
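The distillation objective in _run_epoch combines three terms: prediction MAE, a KL term against the teacher's prediction, and the contrastive loss_cls. A self-contained sketch of just the prediction-level part, with the contrastive term omitted and the weight hard-coded to the value used above; the helper name and the random smoke test are illustrative only:

import torch
import torch.nn.functional as F
from torch import nn

def prediction_distillation_loss(student_pred, teacher_pred, label, kl_weight=10.0):
    """Sketch: MAE to the labels plus batchmean KL between log-softmaxed student/teacher outputs."""
    mae = nn.L1Loss()(student_pred, label)
    kl = nn.KLDivLoss(reduction="batchmean", log_target=True)(
        F.log_softmax(student_pred, dim=-1), F.log_softmax(teacher_pred, dim=-1))
    return mae + kl_weight * kl

# tiny smoke test on random tensors shaped (batch, horizon, nodes, features)
student, teacher, label = (torch.randn(4, 12, 170, 1) for _ in range(3))
print(prediction_distillation_loss(student, teacher, label).item())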

View File

@@ -107,6 +107,7 @@ class Trainer:
                 best_loss = val_epoch_loss
                 not_improved_count = 0
                 best_model = copy.deepcopy(self.model.state_dict())
+                torch.save(best_model, self.best_path)
                 self.logger.info('Best validation model saved!')
             else:
                 not_improved_count += 1
@@ -118,6 +119,7 @@ class Trainer:
             if test_epoch_loss < best_test_loss:
                 best_test_loss = test_epoch_loss
+                torch.save(best_test_model, self.best_test_path)
                 best_test_model = copy.deepcopy(self.model.state_dict())

         if not self.args['debug']:
@@ -161,7 +163,7 @@ class Trainer:
         # Save y_pred and y_true here if you need them
         # torch.save(y_pred, "./test/PEMS07/y_pred_D.pt")  # [3566,12,170,1]
-        # torch.save(y_true, "./test/PEMS08/y_true.pt")  # [3566,12,170,1]
+        # torch.save(y_true, "./test/PEMSD8/y_true.pt")  # [3566,12,170,1]
         for t in range(y_true.shape[1]):
             mae, rmse, mape = all_metrics(y_pred[:, t, ...], y_true[:, t, ...],

View File

@@ -2,6 +2,7 @@ from trainer.Trainer import Trainer
 from trainer.cdeTrainer.cdetrainer import Trainer as cdeTrainer
 from trainer.DCRNN_Trainer import Trainer as DCRNN_Trainer
 from trainer.PDG2SEQ_Trainer import Trainer as PDG2SEQ_Trainer
+from trainer.STMLP_Trainer import Trainer as STMLP_Trainer

 def select_trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args,
@@ -13,5 +14,7 @@ def select_trainer(model, loss, optimizer, train_loader, val_loader, test_loader
                                             lr_scheduler)
         case 'PDG2SEQ': return PDG2SEQ_Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args['train'],
                                                lr_scheduler)
+        case 'STMLP': return STMLP_Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args,
+                                           lr_scheduler)
         case _: return Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args['train'],
                                lr_scheduler)

View File

@@ -299,7 +299,7 @@ def read_data(args):
         'pems03': ['PEMS03/pems03.npz', 'PEMS03/distance.csv'],
         'pems04': ['PEMS04/pems04.npz', 'PEMS04/distance.csv'],
         'pems07': ['PEMS07/pems07.npz', 'PEMS07/distance.csv'],
-        'pems08': ['PEMS08/pems08.npz', 'PEMS08/distance.csv'],
+        'pems08': ['PEMSD8/pems08.npz', 'PEMSD8/distance.csv'],
         'pemsbay': ['PEMSBAY/pems_bay.npz', 'PEMSBAY/distance.csv'],
         'pemsD7M': ['PeMSD7M/PeMSD7M.npz', 'PeMSD7M/distance.csv'],
         'pemsD7L': ['PeMSD7L/PeMSD7L.npz', 'PeMSD7L/distance.csv']
'pemsD7L': ['PeMSD7L/PeMSD7L.npz', 'PeMSD7L/distance.csv'] 'pemsD7L': ['PeMSD7L/PeMSD7L.npz', 'PeMSD7L/distance.csv']