Engineering STDEN into the current project

This commit is contained in:
harry.zhang 2025-09-01 11:52:33 +08:00
parent 568fff7e99
commit ee955e9481
23 changed files with 2435 additions and 3 deletions

3
.gitignore vendored
View File

@ -15,7 +15,6 @@ dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
@ -160,3 +159,5 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
STDEN/

136
README.md
View File

@ -1,3 +1,135 @@
# Project-I # STDEN Project
Secret Projct Spatio-Temporal Diffusion Equation Network (STDEN): a project for spatio-temporal sequence forecasting tasks.
## Project Structure
```
Project-I/
├── run.py # Main entry script
├── configs/ # Configuration files
│ ├── stde_gt.yaml # STDE_GT model config
│ ├── stde_wrs.yaml # STDE_WRS model config
│ └── stde_zgc.yaml # STDE_ZGC model config
├── dataloader/ # Data loading module
│ ├── __init__.py
│ └── stden_dataloader.py
├── trainer/ # Trainer module
│ ├── __init__.py
│ └── stden_trainer.py
├── model/ # Model module
│ ├── __init__.py
│ ├── stden_model.py
│ ├── stden_supervisor.py
│ ├── diffeq_solver.py
│ └── ode_func.py
├── lib/ # Utility library
│ ├── __init__.py
│ ├── logger.py
│ ├── utils.py
│ └── metrics.py
├── requirements.txt # Project dependencies
└── README.md # Project documentation
```
## Quick Start
### 1. Install dependencies
```bash
pip install -r requirements.txt
```
### 2. Train a model
```bash
# Train the STDE_GT model
python run.py --model_name stde_gt --mode train
# Train the STDE_WRS model
python run.py --model_name stde_wrs --mode train
# Train the STDE_ZGC model
python run.py --model_name stde_zgc --mode train
```
### 3. Evaluate a model
```bash
# Evaluate the STDE_GT model
python run.py --model_name stde_gt --mode eval --save_pred
# Evaluate the STDE_WRS model
python run.py --model_name stde_wrs --mode eval --save_pred
# Evaluate the STDE_ZGC model
python run.py --model_name stde_zgc --mode eval --save_pred
```
## Configuration
The project uses YAML configuration files with three main sections (a minimal loading sketch follows the lists below):
### Data section (data)
- `dataset_dir`: dataset directory path
- `batch_size`: training batch size
- `val_batch_size`: validation batch size
- `graph_pkl_filename`: sensor graph adjacency matrix file
### Model section (model)
- `seq_len`: input sequence length
- `horizon`: number of prediction steps
- `input_dim`: input feature dimension
- `output_dim`: output feature dimension
- `latent_dim`: latent space dimension
- `n_traj_samples`: number of sampled trajectories
- `ode_method`: ODE solver method
- `rnn_units`: number of RNN hidden units
- `gcn_step`: number of graph convolution steps
### Training section (train)
- `base_lr`: base learning rate
- `epochs`: total number of training epochs
- `patience`: early stopping patience
- `optimizer`: optimizer type
- `max_grad_norm`: maximum gradient norm
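For reference, this is roughly how `run.py` (further down in this commit) reads one of these files. A minimal sketch, assuming it is run from the project root:
```python
import yaml

# Load one of the model configurations (stde_gt is used as the example here).
with open("configs/stde_gt.yaml", "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

# The three sections described above come back as plain nested dictionaries.
print(config["data"]["dataset_dir"])                             # "data/BJ_GM"
print(config["model"]["seq_len"], config["model"]["horizon"])    # 12 12
print(config["train"]["base_lr"], config["train"]["patience"])   # 0.01 20
```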
## Key Features
1. **Modular design**: clean separation of data loaders, trainers, and models
2. **Configuration driven**: YAML configuration files make parameters easy to adjust
3. **Unified interface**: every model is invoked through run.py
4. **Complete logging**: log output to both file and console
5. **TensorBoard support**: visualization of the training process
6. **Checkpoint management**: automatic saving and loading of the best model
## Supported Models
- **STDE_GT**: for the Beijing GM sensor graph data
- **STDE_WRS**: for the WRS sensor graph data
- **STDE_ZGC**: for the ZGC sensor graph data
## Data Format
The project supports two data formats (see the sketch after this list):
1. **BJ format**: a single flow.npz file, used for the Beijing dataset
2. **Standard format**: separate train.npz, val.npz, and test.npz files
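A minimal sketch of producing a standard-format dataset. The directory name `data/MY_DATASET` and the array shapes are illustrative assumptions; the loaders in this commit only require `x` and `y` arrays in each split file:
```python
import os
import numpy as np

out_dir = "data/MY_DATASET"  # placeholder dataset directory
os.makedirs(out_dir, exist_ok=True)

# Assumed layout: x -> (num_samples, seq_len, num_edges, input_dim),
#                 y -> (num_samples, horizon, num_edges, output_dim).
for split, n in [("train", 700), ("val", 150), ("test", 150)]:
    x = np.random.rand(n, 12, 24, 1).astype(np.float32)
    y = np.random.rand(n, 12, 24, 1).astype(np.float32)
    np.savez(os.path.join(out_dir, f"{split}.npz"), x=x, y=y)
```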
## Logs and Outputs
- Training logs are saved under the `logs/` directory
- Model checkpoints are saved under the `checkpoints/` directory
- TensorBoard logs are saved under the `runs/` directory
- Prediction results are saved in the `results/` subdirectory of the checkpoint directory
## Notes
1. Make sure the dataset directory structure is correct
2. Make sure the sensor graph file paths are configured correctly
3. Adjust the batch size to match your hardware
4. The necessary directories are created automatically during training
## License
This project is released under the MIT License; see the LICENSE file for details.

5
configs/__init__.py Normal file
View File

@ -0,0 +1,5 @@
# -*- coding: utf-8 -*-
"""
Configuration module
Contains the configuration files for the STDEN project
"""

85
configs/stde_gt.yaml Normal file
View File

@ -0,0 +1,85 @@
# STDE_GT model configuration file
# Spatio-temporal diffusion equation network model for the Beijing GM sensor graph data
# Basic settings
model_name: "stde_gt"
random_seed: 2021
log_level: "INFO"
log_base_dir: "logs/BJ_GM"
# Data settings
data:
  # Dataset directory
  dataset_dir: "data/BJ_GM"
  # Batch size
  batch_size: 32
  # Validation batch size
  val_batch_size: 32
  # Sensor graph adjacency matrix file
  graph_pkl_filename: "data/sensor_graph/adj_GM.npy"
# Model settings
model:
  # Input sequence length
  seq_len: 12
  # Prediction horizon (number of steps)
  horizon: 12
  # Input feature dimension
  input_dim: 1
  # Output feature dimension
  output_dim: 1
  # Latent space dimension
  latent_dim: 4
  # Number of sampled trajectories
  n_traj_samples: 3
  # ODE solver method
  ode_method: "dopri5"
  # Absolute tolerance of the ODE solver
  odeint_atol: 0.00001
  # Relative tolerance of the ODE solver
  odeint_rtol: 0.00001
  # Number of RNN hidden units
  rnn_units: 64
  # Number of RNN layers
  num_rnn_layers: 1
  # Number of graph convolution steps
  gcn_step: 2
  # Filter type (default/unkP/IncP)
  filter_type: "default"
  # Recognition RNN type
  recg_type: "gru"
  # Whether to save latent representations
  save_latent: false
  # Whether to record the number of function evaluations
  nfe: false
  # L1 regularization decay
  l1_decay: 0
# Training settings
train:
  # Base learning rate
  base_lr: 0.01
  # Dropout rate
  dropout: 0
  # Checkpoint epoch to load
  load: 0
  # Current training epoch
  epoch: 0
  # Total number of training epochs
  epochs: 100
  # Convergence threshold
  epsilon: 1.0e-3
  # Learning-rate decay ratio
  lr_decay_ratio: 0.1
  # Maximum gradient norm
  max_grad_norm: 5
  # Minimum learning rate
  min_learning_rate: 2.0e-06
  # Optimizer type
  optimizer: "adam"
  # Early-stopping patience
  patience: 20
  # Learning-rate decay steps
  steps: [20, 30, 40, 50]
  # Test frequency (test every N epochs)
  test_every_n_epochs: 5

85
configs/stde_wrs.yaml Normal file
View File

@ -0,0 +1,85 @@
# STDE_WRS model configuration file
# Spatio-temporal diffusion equation network model for the WRS sensor graph data
# Basic settings
model_name: "stde_wrs"
random_seed: 2021
log_level: "INFO"
log_base_dir: "logs/WRS"
# Data settings
data:
  # Dataset directory
  dataset_dir: "data/WRS"
  # Batch size
  batch_size: 32
  # Validation batch size
  val_batch_size: 32
  # Sensor graph adjacency matrix file
  graph_pkl_filename: "data/sensor_graph/adj_WRS.npy"
# Model settings
model:
  # Input sequence length
  seq_len: 12
  # Prediction horizon (number of steps)
  horizon: 12
  # Input feature dimension
  input_dim: 1
  # Output feature dimension
  output_dim: 1
  # Latent space dimension
  latent_dim: 4
  # Number of sampled trajectories
  n_traj_samples: 3
  # ODE solver method
  ode_method: "dopri5"
  # Absolute tolerance of the ODE solver
  odeint_atol: 0.00001
  # Relative tolerance of the ODE solver
  odeint_rtol: 0.00001
  # Number of RNN hidden units
  rnn_units: 64
  # Number of RNN layers
  num_rnn_layers: 1
  # Number of graph convolution steps
  gcn_step: 2
  # Filter type (default/unkP/IncP)
  filter_type: "default"
  # Recognition RNN type
  recg_type: "gru"
  # Whether to save latent representations
  save_latent: false
  # Whether to record the number of function evaluations
  nfe: false
  # L1 regularization decay
  l1_decay: 0
# Training settings
train:
  # Base learning rate
  base_lr: 0.01
  # Dropout rate
  dropout: 0
  # Checkpoint epoch to load
  load: 0
  # Current training epoch
  epoch: 0
  # Total number of training epochs
  epochs: 100
  # Convergence threshold
  epsilon: 1.0e-3
  # Learning-rate decay ratio
  lr_decay_ratio: 0.1
  # Maximum gradient norm
  max_grad_norm: 5
  # Minimum learning rate
  min_learning_rate: 2.0e-06
  # Optimizer type
  optimizer: "adam"
  # Early-stopping patience
  patience: 20
  # Learning-rate decay steps
  steps: [20, 30, 40, 50]
  # Test frequency (test every N epochs)
  test_every_n_epochs: 5

85
configs/stde_zgc.yaml Normal file
View File

@ -0,0 +1,85 @@
# STDE_ZGC model configuration file
# Spatio-temporal diffusion equation network model for the ZGC sensor graph data
# Basic settings
model_name: "stde_zgc"
random_seed: 2021
log_level: "INFO"
log_base_dir: "logs/ZGC"
# Data settings
data:
  # Dataset directory
  dataset_dir: "data/ZGC"
  # Batch size
  batch_size: 32
  # Validation batch size
  val_batch_size: 32
  # Sensor graph adjacency matrix file
  graph_pkl_filename: "data/sensor_graph/adj_ZGC.npy"
# Model settings
model:
  # Input sequence length
  seq_len: 12
  # Prediction horizon (number of steps)
  horizon: 12
  # Input feature dimension
  input_dim: 1
  # Output feature dimension
  output_dim: 1
  # Latent space dimension
  latent_dim: 4
  # Number of sampled trajectories
  n_traj_samples: 3
  # ODE solver method
  ode_method: "dopri5"
  # Absolute tolerance of the ODE solver
  odeint_atol: 0.00001
  # Relative tolerance of the ODE solver
  odeint_rtol: 0.00001
  # Number of RNN hidden units
  rnn_units: 64
  # Number of RNN layers
  num_rnn_layers: 1
  # Number of graph convolution steps
  gcn_step: 2
  # Filter type (default/unkP/IncP)
  filter_type: "default"
  # Recognition RNN type
  recg_type: "gru"
  # Whether to save latent representations
  save_latent: false
  # Whether to record the number of function evaluations
  nfe: false
  # L1 regularization decay
  l1_decay: 0
# Training settings
train:
  # Base learning rate
  base_lr: 0.01
  # Dropout rate
  dropout: 0
  # Checkpoint epoch to load
  load: 0
  # Current training epoch
  epoch: 0
  # Total number of training epochs
  epochs: 100
  # Convergence threshold
  epsilon: 1.0e-3
  # Learning-rate decay ratio
  lr_decay_ratio: 0.1
  # Maximum gradient norm
  max_grad_norm: 5
  # Minimum learning rate
  min_learning_rate: 2.0e-06
  # Optimizer type
  optimizer: "adam"
  # Early-stopping patience
  patience: 20
  # Learning-rate decay steps
  steps: [20, 30, 40, 50]
  # Test frequency (test every N epochs)
  test_every_n_epochs: 5

9
dataloader/__init__.py Normal file
View File

@ -0,0 +1,9 @@
# -*- coding: utf-8 -*-
"""
Data loader module
Contains the data loading and preprocessing logic for the STDEN project
"""
from .stden_dataloader import STDENDataloader
__all__ = ['STDENDataloader']

215
dataloader/stden_dataloader.py Normal file
View File

@ -0,0 +1,215 @@
# -*- coding: utf-8 -*-
"""
STDEN data loader
Handles dataset loading, preprocessing, and batching
"""
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
import logging
class STDENDataset(Dataset):
"""STDEN数据集类"""
def __init__(self, x_data, y_data, sequence_length, horizon):
"""
初始化数据集
Args:
x_data: 输入数据
y_data: 标签数据
sequence_length: 输入序列长度
horizon: 预测时间步数
"""
self.x_data = x_data
self.y_data = y_data
self.sequence_length = sequence_length
self.horizon = horizon
self.num_samples = len(x_data) - sequence_length - horizon + 1
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
# 获取输入序列
x_seq = self.x_data[idx:idx + self.sequence_length]
# 获取目标序列
y_seq = self.y_data[idx + self.sequence_length:idx + self.sequence_length + self.horizon]
return torch.FloatTensor(x_seq), torch.FloatTensor(y_seq)
class StandardScaler:
"""数据标准化器"""
def __init__(self, mean=None, std=None):
self.mean = mean
self.std = std
def fit(self, data):
"""拟合数据,计算均值和标准差"""
self.mean = np.mean(data, axis=0)
self.std = np.std(data, axis=0)
# 避免除零
self.std[self.std == 0] = 1.0
def transform(self, data):
"""标准化数据"""
return (data - self.mean) / self.std
def inverse_transform(self, data):
"""反标准化数据"""
return (data * self.std) + self.mean
class STDENDataloader:
"""STDEN数据加载器主类"""
def __init__(self, config):
"""
初始化数据加载器
Args:
config: 配置字典
"""
self.config = config
self.logger = logging.getLogger('STDEN')
# 数据配置
self.dataset_dir = config['data']['dataset_dir']
self.batch_size = config['data']['batch_size']
self.val_batch_size = config['data'].get('val_batch_size', self.batch_size)
self.sequence_length = config['model']['seq_len']
self.horizon = config['model']['horizon']
# 加载数据
self.data = self._load_dataset()
self.scaler = self.data['scaler']
# 创建数据加载器
self.train_loader = self._create_dataloader(
self.data['x_train'], self.data['y_train'],
self.batch_size, shuffle=True
)
self.val_loader = self._create_dataloader(
self.data['x_val'], self.data['y_val'],
self.val_batch_size, shuffle=False
)
self.test_loader = self._create_dataloader(
self.data['x_test'], self.data['y_test'],
self.val_batch_size, shuffle=False
)
self.logger.info(f"数据加载完成 - 训练集: {len(self.data['x_train'])} 样本")
self.logger.info(f"验证集: {len(self.data['x_val'])} 样本, 测试集: {len(self.data['x_test'])} 样本")
def _load_dataset(self):
"""加载数据集"""
dataset_path = Path(self.dataset_dir)
if not dataset_path.exists():
raise FileNotFoundError(f"数据集目录不存在: {self.dataset_dir}")
# 检查数据集类型BJ_GM或其他
if 'BJ' in self.dataset_dir:
return self._load_bj_dataset(dataset_path)
else:
return self._load_standard_dataset(dataset_path)
def _load_bj_dataset(self, dataset_path):
"""加载BJ_GM数据集"""
flow_file = dataset_path / 'flow.npz'
if not flow_file.exists():
raise FileNotFoundError(f"BJ数据集文件不存在: {flow_file}")
data = dict(np.load(flow_file))
# 提取训练、验证、测试数据
x_train = data['x_train']
y_train = data['y_train']
x_val = data['x_val']
y_val = data['y_val']
x_test = data['x_test']
y_test = data['y_test']
return self._process_data(x_train, y_train, x_val, y_val, x_test, y_test)
def _load_standard_dataset(self, dataset_path):
"""加载标准格式数据集"""
data = {}
for category in ['train', 'val', 'test']:
cat_file = dataset_path / f"{category}.npz"
if not cat_file.exists():
raise FileNotFoundError(f"数据集文件不存在: {cat_file}")
cat_data = np.load(cat_file)
data[f'x_{category}'] = cat_data['x']
data[f'y_{category}'] = cat_data['y']
return self._process_data(
data['x_train'], data['y_train'],
data['x_val'], data['y_val'],
data['x_test'], data['y_test']
)
def _process_data(self, x_train, y_train, x_val, y_val, x_test, y_test):
"""处理数据(标准化等)"""
# 创建标准化器
scaler = StandardScaler()
scaler.fit(x_train)
# 标准化所有数据
x_train_scaled = scaler.transform(x_train)
y_train_scaled = scaler.transform(y_train)
x_val_scaled = scaler.transform(x_val)
y_val_scaled = scaler.transform(y_val)
x_test_scaled = scaler.transform(x_test)
y_test_scaled = scaler.transform(y_test)
return {
'x_train': x_train_scaled,
'y_train': y_train_scaled,
'x_val': x_val_scaled,
'y_val': y_val_scaled,
'x_test': x_test_scaled,
'y_test': y_test_scaled,
'scaler': scaler
}
def _create_dataloader(self, x_data, y_data, batch_size, shuffle=False):
"""创建PyTorch数据加载器"""
dataset = STDENDataset(x_data, y_data, self.sequence_length, self.horizon)
return DataLoader(
dataset,
batch_size=batch_size,
shuffle=shuffle,
num_workers=0, # 避免多进程问题
drop_last=False
)
def get_data_loaders(self):
"""获取所有数据加载器"""
return {
'train': self.train_loader,
'val': self.val_loader,
'test': self.test_loader
}
def get_scaler(self):
"""获取标准化器"""
return self.scaler
def get_data_info(self):
"""获取数据信息"""
return {
'input_dim': self.data['x_train'].shape[-1],
'output_dim': self.data['y_train'].shape[-1],
'sequence_length': self.sequence_length,
'horizon': self.horizon,
'num_nodes': self.data['x_train'].shape[1] if len(self.data['x_train'].shape) > 1 else 1
}
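
A small sanity check of the sliding-window logic in `STDENDataset` above. This is a sketch with a made-up toy series, assuming it is run from the project root:
```python
import numpy as np
from dataloader.stden_dataloader import STDENDataset

series = np.random.rand(100, 24).astype(np.float32)  # 100 time steps, 24 edges (toy data)
dataset = STDENDataset(series, series, sequence_length=12, horizon=12)

print(len(dataset))        # 100 - 12 - 12 + 1 = 77 windows
x0, y0 = dataset[0]
print(x0.shape, y0.shape)  # torch.Size([12, 24]) torch.Size([12, 24])
```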

5
examples/__init__.py Normal file
View File

@ -0,0 +1,5 @@
# -*- coding: utf-8 -*-
"""
Examples module
Usage examples for the STDEN project
"""

115
examples/train_example.py Normal file
View File

@ -0,0 +1,115 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
STDEN training example
Demonstrates how to train a model with the new project structure
"""
import sys
from pathlib import Path
# 添加项目根目录到Python路径
project_root = Path(__file__).parent.parent
sys.path.append(str(project_root))
from lib.logger import setup_logger
from dataloader.stden_dataloader import STDENDataloader
from trainer.stden_trainer import STDENTrainer
import yaml
def main():
"""主函数示例"""
# 配置字典也可以从YAML文件加载
config = {
'model_name': 'stde_gt',
'log_level': 'INFO',
'log_base_dir': 'logs/example',
'device': 'cpu', # 或 'cuda'
'data': {
'dataset_dir': 'data/BJ_GM',
'batch_size': 16,
'val_batch_size': 16,
'graph_pkl_filename': 'data/sensor_graph/adj_GM.npy'
},
'model': {
'seq_len': 12,
'horizon': 12,
'input_dim': 1,
'output_dim': 1,
'latent_dim': 4,
'n_traj_samples': 3,
'ode_method': 'dopri5',
'odeint_atol': 0.00001,
'odeint_rtol': 0.00001,
'rnn_units': 64,
'num_rnn_layers': 1,
'gcn_step': 2,
'filter_type': 'default',
'recg_type': 'gru',
'save_latent': False,
'nfe': False,
'l1_decay': 0
},
'train': {
'base_lr': 0.01,
'dropout': 0,
'load': 0,
'epoch': 0,
'epochs': 50, # 减少训练轮数用于示例
'epsilon': 1.0e-3,
'lr_decay_ratio': 0.1,
'max_grad_norm': 5,
'min_learning_rate': 2.0e-06,
'optimizer': 'adam',
'patience': 10,
'steps': [10, 20, 30],
'test_every_n_epochs': 5
}
}
try:
# 设置日志
logger = setup_logger(config)
logger.info("开始STDEN项目训练示例")
# 注意:这里需要实际的邻接矩阵数据
# 为了示例,我们创建一个虚拟的邻接矩阵
import numpy as np
config['adj_matrix'] = np.random.rand(10, 10) # 10x10的随机邻接矩阵
# 创建数据加载器
logger.info("创建数据加载器...")
try:
dataloader = STDENDataloader(config)
logger.info("数据加载器创建成功")
except FileNotFoundError as e:
logger.warning(f"数据加载器创建失败(预期行为,因为示例中没有实际数据): {e}")
logger.info("继续演示项目结构...")
return
# 创建训练器
logger.info("创建训练器...")
trainer = STDENTrainer(config, dataloader)
# 开始训练
logger.info("开始训练...")
trainer.train()
# 评估模型
logger.info("开始评估...")
metrics = trainer.evaluate(save_predictions=True)
logger.info("训练示例完成!")
except Exception as e:
logger.error(f"训练过程中发生错误: {e}")
raise
if __name__ == '__main__':
main()

0
lib/__init__.py Normal file
View File

83
lib/logger.py Normal file
View File

@ -0,0 +1,83 @@
# -*- coding: utf-8 -*-
"""
Logging configuration module
Provides a unified log format and configuration
"""
import logging
import os
from pathlib import Path
from datetime import datetime
def setup_logger(config):
"""
设置日志记录器
Args:
config: 配置字典包含日志相关配置
Returns:
logger: 配置好的日志记录器
"""
# 获取日志配置
log_level = config.get('log_level', 'INFO')
log_base_dir = config.get('log_base_dir', 'logs')
# 创建日志目录
log_dir = Path(log_base_dir)
log_dir.mkdir(parents=True, exist_ok=True)
# 创建日志文件名(包含时间戳)
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
model_name = config.get('model_name', 'stden')
log_filename = f"{model_name}_{timestamp}.log"
log_path = log_dir / log_filename
# 创建日志记录器
logger = logging.getLogger('STDEN')
logger.setLevel(getattr(logging, log_level.upper()))
# 清除现有的处理器
logger.handlers.clear()
# 创建文件处理器
file_handler = logging.FileHandler(log_path, encoding='utf-8')
file_handler.setLevel(getattr(logging, log_level.upper()))
# 创建控制台处理器
console_handler = logging.StreamHandler()
console_handler.setLevel(getattr(logging, log_level.upper()))
# 创建格式化器
formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
# 设置格式化器
file_handler.setFormatter(formatter)
console_handler.setFormatter(formatter)
# 添加处理器
logger.addHandler(file_handler)
logger.addHandler(console_handler)
# 记录日志配置信息
logger.info(f"日志文件路径: {log_path}")
logger.info(f"日志级别: {log_level}")
return logger
def get_logger(name='STDEN'):
"""
获取已配置的日志记录器
Args:
name: 日志记录器名称
Returns:
logger: 日志记录器
"""
return logging.getLogger(name)
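
A minimal usage sketch for the two helpers above; the log directory and model name are just example values:
```python
from lib.logger import setup_logger, get_logger

config = {"log_level": "INFO", "log_base_dir": "logs/example", "model_name": "stde_gt"}
logger = setup_logger(config)            # writes logs/example/stde_gt_<timestamp>.log
logger.info("hello from the logging example")

# Elsewhere in the code base the same logger can be fetched by name.
get_logger("STDEN").debug("debug message (suppressed at INFO level)")
```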

29
lib/metrics.py Normal file
View File

@ -0,0 +1,29 @@
import torch
def masked_mae_loss(y_pred, y_true):
y_true[y_true < 1e-4] = 0
mask = (y_true != 0).float()
mask /= mask.mean() # re-weight so the masked (zero-target) entries do not dilute the mean
loss = torch.abs(y_pred - y_true)
loss = loss * mask
# trick for nans: https://discuss.pytorch.org/t/how-to-set-nan-in-tensor-to-0/3918/3
loss[loss != loss] = 0
return loss.mean()
def masked_mape_loss(y_pred, y_true):
y_true[y_true < 1e-4] = 0
mask = (y_true != 0).float()
mask /= mask.mean()
loss = torch.abs((y_pred - y_true) / y_true)
loss = loss * mask
loss[loss != loss] = 0
return loss.mean()
def masked_rmse_loss(y_pred, y_true):
y_true[y_true < 1e-4] = 0
mask = (y_true != 0).float()
mask /= mask.mean()
loss = torch.pow(y_pred - y_true, 2)
loss = loss * mask
loss[loss != loss] = 0
return torch.sqrt(loss.mean())
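
A quick sketch of the masking behaviour above: entries whose ground truth is (near) zero are excluded, and `mask /= mask.mean()` re-weights the remaining entries so the result is the mean error over the non-zero targets only:
```python
import torch
from lib.metrics import masked_mae_loss

y_true = torch.tensor([[0.0, 2.0], [4.0, 0.0]])  # the two zeros are masked out
y_pred = torch.tensor([[1.0, 3.0], [3.0, 5.0]])

# Only the (2.0 vs 3.0) and (4.0 vs 3.0) pairs contribute, each with absolute error 1.
print(masked_mae_loss(y_pred, y_true))  # tensor(1.)
```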

228
lib/utils.py Normal file
View File

@ -0,0 +1,228 @@
import logging
import numpy as np
import os
import time
import scipy.sparse as sp
import sys
import torch
import torch.nn as nn
class DataLoader(object):
def __init__(self, xs, ys, batch_size, pad_with_last_sample=True, shuffle=False):
"""
:param xs:
:param ys:
:param batch_size:
:param pad_with_last_sample: pad with the last sample to make the number of samples divisible by batch_size.
"""
self.batch_size = batch_size
self.current_ind = 0
if pad_with_last_sample:
num_padding = (batch_size - (len(xs) % batch_size)) % batch_size
x_padding = np.repeat(xs[-1:], num_padding, axis=0)
y_padding = np.repeat(ys[-1:], num_padding, axis=0)
xs = np.concatenate([xs, x_padding], axis=0)
ys = np.concatenate([ys, y_padding], axis=0)
self.size = len(xs)
self.num_batch = int(self.size // self.batch_size)
if shuffle:
permutation = np.random.permutation(self.size)
xs, ys = xs[permutation], ys[permutation]
self.xs = xs
self.ys = ys
def get_iterator(self):
self.current_ind = 0
def _wrapper():
while self.current_ind < self.num_batch:
start_ind = self.batch_size * self.current_ind
end_ind = min(self.size, self.batch_size * (self.current_ind + 1))
x_i = self.xs[start_ind: end_ind, ...]
y_i = self.ys[start_ind: end_ind, ...]
yield (x_i, y_i)
self.current_ind += 1
return _wrapper()
class StandardScaler:
"""
Standardize the input
"""
def __init__(self, mean, std):
self.mean = mean
self.std = std
def transform(self, data):
return (data - self.mean) / self.std
def inverse_transform(self, data):
return (data * self.std) + self.mean
def calculate_random_walk_matrix(adj_mx):
adj_mx = sp.coo_matrix(adj_mx)
d = np.array(adj_mx.sum(1))
d_inv = np.power(d, -1).flatten()
d_inv[np.isinf(d_inv)] = 0.
d_mat_inv = sp.diags(d_inv)
random_walk_mx = d_mat_inv.dot(adj_mx).tocoo()
return random_walk_mx
def config_logging(log_dir, log_filename='info.log', level=logging.INFO):
# Add file handler and stdout handler
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# Create the log directory if necessary.
try:
os.makedirs(log_dir)
except OSError:
pass
file_handler = logging.FileHandler(os.path.join(log_dir, log_filename))
file_handler.setFormatter(formatter)
file_handler.setLevel(level=level)
# Add console handler.
console_formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setFormatter(console_formatter)
console_handler.setLevel(level=level)
logging.basicConfig(handlers=[file_handler, console_handler], level=level)
def get_logger(log_dir, name, log_filename='info.log', level=logging.INFO):
logger = logging.getLogger(name)
logger.setLevel(level)
# Add file handler and stdout handler
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
file_handler = logging.FileHandler(os.path.join(log_dir, log_filename))
file_handler.setFormatter(formatter)
# Add console handler.
console_formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setFormatter(console_formatter)
logger.addHandler(file_handler)
logger.addHandler(console_handler)
# Add google cloud log handler
logger.info('Log directory: %s', log_dir)
return logger
def get_log_dir(kwargs):
log_dir = kwargs['train'].get('log_dir')
if log_dir is None:
batch_size = kwargs['data'].get('batch_size')
filter_type = kwargs['model'].get('filter_type')
gcn_step = kwargs['model'].get('gcn_step')
horizon = kwargs['model'].get('horizon')
latent_dim = kwargs['model'].get('latent_dim')
n_traj_samples = kwargs['model'].get('n_traj_samples')
ode_method = kwargs['model'].get('ode_method')
seq_len = kwargs['model'].get('seq_len')
rnn_units = kwargs['model'].get('rnn_units')
recg_type = kwargs['model'].get('recg_type')
if filter_type == 'unkP':
filter_type_abbr = 'UP'
elif filter_type == 'IncP':
filter_type_abbr = 'NV'
else:
filter_type_abbr = 'DF'
run_id = 'STDEN_%s-%d_%s-%d_L-%d_N-%d_M-%s_bs-%d_%d-%d_%s/' % (
recg_type, rnn_units, filter_type_abbr, gcn_step, latent_dim, n_traj_samples, ode_method, batch_size, seq_len, horizon, time.strftime('%m%d%H%M%S'))
base_dir = kwargs.get('log_base_dir')
log_dir = os.path.join(base_dir, run_id)
if not os.path.exists(log_dir):
os.makedirs(log_dir)
return log_dir
def load_dataset(dataset_dir, batch_size, val_batch_size=None, **kwargs):
if('BJ' in dataset_dir):
data = dict(np.load(os.path.join(dataset_dir, 'flow.npz'))) # convert readonly NpzFile to writable dict Object
for category in ['train', 'val', 'test']:
data['x_' + category] = data['x_' + category] #[..., :4] # ignore the time index
else:
data = {}
for category in ['train', 'val', 'test']:
cat_data = np.load(os.path.join(dataset_dir, category + '.npz'))
data['x_' + category] = cat_data['x']
data['y_' + category] = cat_data['y']
scaler = StandardScaler(mean=data['x_train'].mean(), std=data['x_train'].std())
# Data format
for category in ['train', 'val', 'test']:
data['x_' + category] = scaler.transform(data['x_' + category])
data['y_' + category] = scaler.transform(data['y_' + category])
data['train_loader'] = DataLoader(data['x_train'], data['y_train'], batch_size, shuffle=True)
data['val_loader'] = DataLoader(data['x_val'], data['y_val'], val_batch_size, shuffle=False)
data['test_loader'] = DataLoader(data['x_test'], data['y_test'], val_batch_size, shuffle=False)
data['scaler'] = scaler
return data
def load_graph_data(pkl_filename):
adj_mx = np.load(pkl_filename)
return adj_mx
def graph_grad(adj_mx):
"""Fetch the graph gradient operator."""
num_nodes = adj_mx.shape[0]
num_edges = (adj_mx > 0.).sum()
grad = torch.zeros(num_nodes, num_edges)
e = 0
for i in range(num_nodes):
for j in range(num_nodes):
if adj_mx[i, j] == 0:
continue
grad[i, e] = 1.
grad[j, e] = -1.
e += 1
return grad
def init_network_weights(net, std = 0.1):
"""
Just for nn.Linear net.
"""
for m in net.modules():
if isinstance(m, nn.Linear):
nn.init.normal_(m.weight, mean=0, std=std)
nn.init.constant_(m.bias, val=0)
def split_last_dim(data):
last_dim = data.size()[-1]
last_dim = last_dim//2
res = data[..., :last_dim], data[..., last_dim:]
return res
def get_device(tensor):
device = torch.device("cpu")
if tensor.is_cuda:
device = tensor.get_device()
return device
def sample_standard_gaussian(mu, sigma):
device = get_device(mu)
d = torch.distributions.normal.Normal(torch.Tensor([0.]).to(device), torch.Tensor([1.]).to(device))
r = d.sample(mu.size()).squeeze(-1)
return r * sigma.float() + mu.float()
def create_net(n_inputs, n_outputs, n_layers = 0,
n_units = 100, nonlinear = nn.Tanh):
layers = [nn.Linear(n_inputs, n_units)]
for i in range(n_layers):
layers.append(nonlinear())
layers.append(nn.Linear(n_units, n_units))
layers.append(nonlinear())
layers.append(nn.Linear(n_units, n_outputs))
return nn.Sequential(*layers)
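
To illustrate `graph_grad` above: for every directed edge (i, j) with a positive weight it places +1 at node i and -1 at node j, which yields a node-by-edge incidence (gradient) matrix. A small sketch with a made-up three-node cycle:
```python
import numpy as np
from lib.utils import graph_grad

# Toy directed adjacency matrix with three edges: 0->1, 1->2, 2->0.
adj = np.array([[0., 1., 0.],
                [0., 0., 1.],
                [1., 0., 0.]])

grad = graph_grad(adj)  # shape (num_nodes, num_edges) = (3, 3)
print(grad)
# tensor([[ 1.,  0., -1.],
#         [-1.,  1.,  0.],
#         [ 0., -1.,  1.]])
```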

0
model/__init__.py Normal file
View File

49
model/diffeq_solver.py Normal file
View File

@ -0,0 +1,49 @@
import torch
import torch.nn as nn
import time
from torchdiffeq import odeint
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class DiffeqSolver(nn.Module):
def __init__(self, odefunc, method, latent_dim,
odeint_rtol = 1e-4, odeint_atol = 1e-5):
nn.Module.__init__(self)
self.ode_method = method
self.odefunc = odefunc
self.latent_dim = latent_dim
self.rtol = odeint_rtol
self.atol = odeint_atol
def forward(self, first_point, time_steps_to_pred):
"""
Decode the trajectory through the ODE solver.
:param time_steps_to_pred: horizon
:param first_point: (n_traj_samples, batch_size, num_nodes * latent_dim)
:return: pred_y: # shape (horizon, n_traj_samples, batch_size, self.num_nodes * self.output_dim)
"""
n_traj_samples, batch_size = first_point.size()[0], first_point.size()[1]
first_point = first_point.reshape(n_traj_samples * batch_size, -1) # reduce the complexity by merging dimension
# pred_y shape: (horizon, n_traj_samples * batch_size, num_nodes * latent_dim)
start_time = time.time()
self.odefunc.nfe = 0
pred_y = odeint(self.odefunc,
first_point,
time_steps_to_pred,
rtol=self.rtol,
atol=self.atol,
method=self.ode_method)
time_fe = time.time() - start_time
# pred_y shape: (horizon, n_traj_samples, batch_size, num_nodes * latent_dim)
pred_y = pred_y.reshape(pred_y.size()[0], n_traj_samples, batch_size, -1)
# assert(pred_y.size()[1] == n_traj_samples)
# assert(pred_y.size()[2] == batch_size)
return pred_y, (self.odefunc.nfe, time_fe)
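
A toy usage sketch for `DiffeqSolver`. The decay dynamics and tensor sizes are invented; the only requirements taken from the class above are an `nfe` counter on the ODE function and an initial state shaped `(n_traj_samples, batch_size, state_dim)`. It assumes `torchdiffeq` is installed:
```python
import torch
import torch.nn as nn
from model.diffeq_solver import DiffeqSolver

class ToyODEFunc(nn.Module):
    """Simple linear decay dz/dt = -z, with the nfe counter DiffeqSolver expects."""
    def __init__(self):
        super().__init__()
        self.nfe = 0

    def forward(self, t, y):
        self.nfe += 1
        return -y

solver = DiffeqSolver(ToyODEFunc(), method="dopri5", latent_dim=8)
z0 = torch.randn(3, 2, 8)               # (n_traj_samples, batch_size, state_dim)
t = torch.linspace(0.0, 1.0, steps=12)  # 12 prediction steps
pred, (nfe, seconds) = solver(z0, t)
print(pred.shape, nfe)                  # torch.Size([12, 3, 2, 8]) and the evaluation count
```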

165
model/ode_func.py Normal file
View File

@ -0,0 +1,165 @@
import numpy as np
import torch
import torch.nn as nn
from lib import utils
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class LayerParams:
def __init__(self, rnn_network: nn.Module, layer_type: str):
self._rnn_network = rnn_network
self._params_dict = {}
self._biases_dict = {}
self._type = layer_type
def get_weights(self, shape):
if shape not in self._params_dict:
nn_param = nn.Parameter(torch.empty(*shape, device=device))
nn.init.xavier_normal_(nn_param)
self._params_dict[shape] = nn_param
self._rnn_network.register_parameter('{}_weight_{}'.format(self._type, str(shape)),
nn_param)
return self._params_dict[shape]
def get_biases(self, length, bias_start=0.0):
if length not in self._biases_dict:
biases = nn.Parameter(torch.empty(length, device=device))
nn.init.constant_(biases, bias_start)
self._biases_dict[length] = biases
self._rnn_network.register_parameter('{}_biases_{}'.format(self._type, str(length)),
biases)
return self._biases_dict[length]
class ODEFunc(nn.Module):
def __init__(self, num_units, latent_dim, adj_mx, gcn_step, num_nodes,
gen_layers=1, nonlinearity='tanh', filter_type="default"):
"""
:param num_units: dimensionality of the hidden layers
:param latent_dim: dimensionality used for ODE (input and output). Analog of a continuous latent state
:param adj_mx:
:param gcn_step:
:param num_nodes:
:param gen_layers: hidden layers in each ode func.
:param nonlinearity:
:param filter_type: default
:param use_gc_for_ru: whether to use Graph convolution to calculate the reset and update gates.
"""
super(ODEFunc, self).__init__()
self._activation = torch.tanh if nonlinearity == 'tanh' else torch.relu
self._num_nodes = num_nodes
self._num_units = num_units # hidden dimension
self._latent_dim = latent_dim
self._gen_layers = gen_layers
self.nfe = 0
self._filter_type = filter_type
if(self._filter_type == "unkP"):
ode_func_net = utils.create_net(latent_dim, latent_dim, n_units=num_units)
utils.init_network_weights(ode_func_net)
self.gradient_net = ode_func_net
else:
self._gcn_step = gcn_step
self._gconv_params = LayerParams(self, 'gconv')
self._supports = []
supports = []
supports.append(utils.calculate_random_walk_matrix(adj_mx).T)
supports.append(utils.calculate_random_walk_matrix(adj_mx.T).T)
for support in supports:
self._supports.append(self._build_sparse_matrix(support))
@staticmethod
def _build_sparse_matrix(L):
L = L.tocoo()
indices = np.column_stack((L.row, L.col))
# this is to ensure row-major ordering to equal torch.sparse.sparse_reorder(L)
indices = indices[np.lexsort((indices[:, 0], indices[:, 1]))]
L = torch.sparse_coo_tensor(indices.T, L.data, L.shape, device=device)
return L
def forward(self, t_local, y, backwards = False):
"""
Perform one step in solving ODE. Given current data point y and current time point t_local, returns gradient dy/dt at this time point
t_local: current time point
y: value at the current time point, shape (B, num_nodes * latent_dim)
:return
- Output: A `2-D` tensor with shape `(B, num_nodes * latent_dim)`.
"""
self.nfe += 1
grad = self.get_ode_gradient_nn(t_local, y)
if backwards:
grad = -grad
return grad
def get_ode_gradient_nn(self, t_local, inputs):
if(self._filter_type == "unkP"):
grad = self._fc(inputs)
elif (self._filter_type == "IncP"):
grad = - self.ode_func_net(inputs)
else: # default is diffusion process
# theta shape: (B, num_nodes * latent_dim)
theta = torch.sigmoid(self._gconv(inputs, self._latent_dim, bias_start=1.0))
grad = - theta * self.ode_func_net(inputs)
return grad
def ode_func_net(self, inputs):
c = inputs
for i in range(self._gen_layers):
c = self._gconv(c, self._num_units)
c = self._activation(c)
c = self._gconv(c, self._latent_dim)
c = self._activation(c)
return c
def _fc(self, inputs):
batch_size = inputs.size()[0]
grad = self.gradient_net(inputs.view(batch_size * self._num_nodes, self._latent_dim))
return grad.reshape(batch_size, self._num_nodes * self._latent_dim) # (batch_size, num_nodes, latent_dim)
@staticmethod
def _concat(x, x_):
x_ = x_.unsqueeze(0)
return torch.cat([x, x_], dim=0)
def _gconv(self, inputs, output_size, bias_start=0.0):
# Reshape input and state to (batch_size, num_nodes, input_dim/state_dim)
batch_size = inputs.shape[0]
inputs = torch.reshape(inputs, (batch_size, self._num_nodes, -1))
# state = torch.reshape(state, (batch_size, self._num_nodes, -1))
# inputs_and_state = torch.cat([inputs, state], dim=2)
input_size = inputs.size(2)
x = inputs
x0 = x.permute(1, 2, 0) # (num_nodes, total_arg_size, batch_size)
x0 = torch.reshape(x0, shape=[self._num_nodes, input_size * batch_size])
x = torch.unsqueeze(x0, 0)
if self._gcn_step == 0:
pass
else:
for support in self._supports:
x1 = torch.sparse.mm(support, x0)
x = self._concat(x, x1)
for k in range(2, self._gcn_step + 1):
x2 = 2 * torch.sparse.mm(support, x1) - x0
x = self._concat(x, x2)
x1, x0 = x2, x1
num_matrices = len(self._supports) * self._gcn_step + 1 # Adds for x itself.
x = torch.reshape(x, shape=[num_matrices, self._num_nodes, input_size, batch_size])
x = x.permute(3, 1, 2, 0) # (batch_size, num_nodes, input_size, order)
x = torch.reshape(x, shape=[batch_size * self._num_nodes, input_size * num_matrices])
weights = self._gconv_params.get_weights((input_size * num_matrices, output_size))
x = torch.matmul(x, weights) # (batch_size * self._num_nodes, output_size)
biases = self._gconv_params.get_biases(output_size, bias_start)
x += biases
# Reshape res back to 2D: (batch_size, num_node, state_dim) -> (batch_size, num_node * state_dim)
return torch.reshape(x, [batch_size, self._num_nodes * output_size])

206
model/stden_model.py Normal file
View File

@ -0,0 +1,206 @@
import time
import torch
import torch.nn as nn
from torch.nn.modules.rnn import GRU
from model.ode_func import ODEFunc
from model.diffeq_solver import DiffeqSolver
from lib import utils
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
class EncoderAttrs:
def __init__(self, adj_mx, **model_kwargs):
self.adj_mx = adj_mx
self.num_nodes = adj_mx.shape[0]
self.num_edges = (adj_mx > 0.).sum()
self.gcn_step = int(model_kwargs.get('gcn_step', 2))
self.filter_type = model_kwargs.get('filter_type', 'default')
self.num_rnn_layers = int(model_kwargs.get('num_rnn_layers', 1))
self.rnn_units = int(model_kwargs.get('rnn_units'))
self.latent_dim = int(model_kwargs.get('latent_dim', 4))
class STDENModel(nn.Module, EncoderAttrs):
def __init__(self, adj_mx, logger, **model_kwargs):
nn.Module.__init__(self)
EncoderAttrs.__init__(self, adj_mx, **model_kwargs)
self._logger = logger
####################################################
# recognition net
####################################################
self.encoder_z0 = Encoder_z0_RNN(adj_mx, **model_kwargs)
####################################################
# ode solver
####################################################
self.n_traj_samples = int(model_kwargs.get('n_traj_samples', 1))
self.ode_method = model_kwargs.get('ode_method', 'dopri5')
self.atol = float(model_kwargs.get('odeint_atol', 1e-4))
self.rtol = float(model_kwargs.get('odeint_rtol', 1e-3))
self.num_gen_layer = int(model_kwargs.get('gen_layers', 1))
self.ode_gen_dim = int(model_kwargs.get('gen_dim', 64))
ode_set_str = "ODE setting --latent {} --samples {} --method {} \
--atol {:6f} --rtol {:6f} --gen_layer {} --gen_dim {}".format(\
self.latent_dim, self.n_traj_samples, self.ode_method, \
self.atol, self.rtol, self.num_gen_layer, self.ode_gen_dim)
odefunc = ODEFunc(self.ode_gen_dim, # hidden dimension
self.latent_dim,
adj_mx,
self.gcn_step,
self.num_nodes,
filter_type=self.filter_type
).to(device)
self.diffeq_solver = DiffeqSolver(odefunc,
self.ode_method,
self.latent_dim,
odeint_rtol=self.rtol,
odeint_atol=self.atol
)
self._logger.info(ode_set_str)
self.save_latent = bool(model_kwargs.get('save_latent', False))
self.latent_feat = None # used to extract the latent feature
####################################################
# decoder
####################################################
self.horizon = int(model_kwargs.get('horizon', 1))
self.out_feat = int(model_kwargs.get('output_dim', 1))
self.decoder = Decoder(
self.out_feat,
adj_mx,
self.num_nodes,
self.num_edges,
).to(device)
##########################################
def forward(self, inputs, labels=None, batches_seen=None):
"""
seq2seq forward pass
:param inputs: shape (seq_len, batch_size, num_edges * input_dim)
:param labels: shape (horizon, batch_size, num_edges * output_dim)
:param batches_seen: batches seen till now
:return: outputs: (self.horizon, batch_size, self.num_edges * self.output_dim)
"""
perf_time = time.time()
# shape: [1, batch, num_nodes * latent_dim]
first_point_mu, first_point_std = self.encoder_z0(inputs)
self._logger.debug("Recognition complete with {:.1f}s".format(time.time() - perf_time))
# sample 'n_traj_samples' trajectory
perf_time = time.time()
means_z0 = first_point_mu.repeat(self.n_traj_samples, 1, 1)
sigma_z0 = first_point_std.repeat(self.n_traj_samples, 1, 1)
first_point_enc = utils.sample_standard_gaussian(means_z0, sigma_z0)
time_steps_to_predict = torch.arange(start=0, end=self.horizon, step=1).float().to(device)
time_steps_to_predict = time_steps_to_predict / len(time_steps_to_predict)
# Shape of sol_ys (horizon, n_traj_samples, batch_size, self.num_nodes * self.latent_dim)
sol_ys, fe = self.diffeq_solver(first_point_enc, time_steps_to_predict)
self._logger.debug("ODE solver complete with {:.1f}s".format(time.time() - perf_time))
if(self.save_latent):
# Shape of latent_feat (horizon, batch_size, self.num_nodes * self.latent_dim)
self.latent_feat = torch.mean(sol_ys.detach(), axis=1)
perf_time = time.time()
outputs = self.decoder(sol_ys)
self._logger.debug("Decoder complete with {:.1f}s".format(time.time() - perf_time))
if batches_seen == 0:
self._logger.info(
"Total trainable parameters {}".format(count_parameters(self))
)
return outputs, fe
class Encoder_z0_RNN(nn.Module, EncoderAttrs):
def __init__(self, adj_mx, **model_kwargs):
nn.Module.__init__(self)
EncoderAttrs.__init__(self, adj_mx, **model_kwargs)
self.recg_type = model_kwargs.get('recg_type', 'gru') # gru
if(self.recg_type == 'gru'):
# gru settings
self.input_dim = int(model_kwargs.get('input_dim', 1))
self.gru_rnn = GRU(self.input_dim, self.rnn_units).to(device)
else:
raise NotImplementedError("The recognition net only supports 'gru'.")
# hidden to z0 settings
self.inv_grad = utils.graph_grad(adj_mx).transpose(-2, -1)
self.inv_grad[self.inv_grad != 0.] = 0.5
self.hiddens_to_z0 = nn.Sequential(
nn.Linear(self.rnn_units, 50),
nn.Tanh(),
nn.Linear(50, self.latent_dim * 2),)
utils.init_network_weights(self.hiddens_to_z0)
def forward(self, inputs):
"""
encoder forward pass on t time steps
:param inputs: shape (seq_len, batch_size, num_edges * input_dim)
:return: mean, std: # shape (n_samples=1, batch_size, self.latent_dim)
"""
if(self.recg_type == 'gru'):
# shape of outputs: (seq_len, batch * num_edges, rnn_units)
seq_len, batch_size = inputs.size(0), inputs.size(1)
inputs = inputs.reshape(seq_len, batch_size, self.num_edges, self.input_dim)
inputs = inputs.reshape(seq_len, batch_size * self.num_edges, self.input_dim)
outputs, _ = self.gru_rnn(inputs)
last_output = outputs[-1]
# (batch_size, num_edges, rnn_units)
last_output = torch.reshape(last_output, (batch_size, self.num_edges, -1))
last_output = torch.transpose(last_output, -2, -1)
# (batch_size, num_nodes, rnn_units)
last_output = torch.matmul(last_output, self.inv_grad).transpose(-2, -1)
else:
raise NotImplementedError("The recognition net only supports 'gru'.")
mean, std = utils.split_last_dim(self.hiddens_to_z0(last_output))
mean = mean.reshape(batch_size, -1) # (batch_size, num_nodes * latent_dim)
std = std.reshape(batch_size, -1) # (batch_size, num_nodes * latent_dim)
std = std.abs()
assert(not torch.isnan(mean).any())
assert(not torch.isnan(std).any())
return mean.unsqueeze(0), std.unsqueeze(0) # for n_sample traj
class Decoder(nn.Module):
def __init__(self, output_dim, adj_mx, num_nodes, num_edges):
super(Decoder, self).__init__()
self.num_nodes = num_nodes
self.num_edges = num_edges
self.grap_grad = utils.graph_grad(adj_mx)
self.output_dim = output_dim
def forward(self, inputs):
"""
:param inputs: (horizon, n_traj_samples, batch_size, num_nodes * latent_dim)
:return outputs: (horizon, batch_size, num_edges * output_dim), average result of n_traj_samples.
"""
assert(len(inputs.size()) == 4)
horizon, n_traj_samples, batch_size = inputs.size()[:3]
inputs = inputs.reshape(horizon, n_traj_samples, batch_size, self.num_nodes, -1).transpose(-2, -1)
latent_dim = inputs.size(-2)
# transform z with shape `(..., num_nodes)` to f with shape `(..., num_edges)`.
outputs = torch.matmul(inputs, self.grap_grad)
outputs = outputs.reshape(horizon, n_traj_samples, batch_size, latent_dim, self.num_edges, self.output_dim)
outputs = torch.mean(
torch.mean(outputs, axis=3),
axis=1
)
outputs = outputs.reshape(horizon, batch_size, -1)
return outputs
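
The shape bookkeeping below summarizes the forward pass documented in the docstrings above; the concrete numbers are arbitrary examples:
```python
# Symbols: T = seq_len, H = horizon, B = batch, N = nodes, E = edges,
#          L = latent_dim, S = n_traj_samples, D = input/output dim (1 here).
T, H, B, N, E, L, S, D = 12, 12, 32, 10, 24, 4, 3, 1

x_shape   = (T, B, E * D)     # encoder input: edge-level observations
z0_shape  = (S, B, N * L)     # sampled initial latent state on the nodes
sol_shape = (H, S, B, N * L)  # latent ODE trajectory from DiffeqSolver
out_shape = (H, B, E * D)     # decoder output, mapped back to the edges
print(x_shape, z0_shape, sol_shape, out_shape)
```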

415
model/stden_supervisor.py Normal file
View File

@ -0,0 +1,415 @@
import os
import time
from random import SystemRandom
import numpy as np
import pandas as pd
import torch
from torch.utils.tensorboard import SummaryWriter
from lib import utils
from model.stden_model import STDENModel
from lib.metrics import masked_mae_loss, masked_mape_loss, masked_rmse_loss
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class STDENSupervisor:
def __init__(self, adj_mx, **kwargs):
self._kwargs = kwargs
self._data_kwargs = kwargs.get('data')
self._model_kwargs = kwargs.get('model')
self._train_kwargs = kwargs.get('train')
self.max_grad_norm = self._train_kwargs.get('max_grad_norm', 1.)
# logging.
self._log_dir = utils.get_log_dir(kwargs)
self._writer = SummaryWriter('runs/' + self._log_dir)
log_level = self._kwargs.get('log_level', 'INFO')
self._logger = utils.get_logger(self._log_dir, __name__, 'info.log', level=log_level)
# data set
self._data = utils.load_dataset(**self._data_kwargs)
self.standard_scaler = self._data['scaler']
self._logger.info('Scaler mean: {:.6f}, std {:.6f}.'.format(self.standard_scaler.mean, self.standard_scaler.std))
self.num_edges = (adj_mx > 0.).sum()
self.input_dim = int(self._model_kwargs.get('input_dim', 1))
self.seq_len = int(self._model_kwargs.get('seq_len')) # for the encoder
self.output_dim = int(self._model_kwargs.get('output_dim', 1))
self.use_curriculum_learning = bool(
self._model_kwargs.get('use_curriculum_learning', False))
self.horizon = int(self._model_kwargs.get('horizon', 1)) # for the decoder
# setup model
stden_model = STDENModel(adj_mx, self._logger, **self._model_kwargs)
self.stden_model = stden_model.cuda() if torch.cuda.is_available() else stden_model
self._logger.info("Model created")
self.experimentID = self._train_kwargs.get('load', 0)
if self.experimentID == 0:
# Make a new experiment ID
self.experimentID = int(SystemRandom().random()*100000)
self.ckpt_path = os.path.join("ckpt/", "experiment_" + str(self.experimentID))
self._epoch_num = self._train_kwargs.get('epoch', 0)
if self._epoch_num > 0:
self._logger.info('Loading model...')
self.load_model()
def save_model(self, epoch):
model_dir = self.ckpt_path
if not os.path.exists(model_dir):
os.makedirs(model_dir)
config = dict(self._kwargs)
config['model_state_dict'] = self.stden_model.state_dict()
config['epoch'] = epoch
model_path = os.path.join(model_dir, 'epo{}.tar'.format(epoch))
torch.save(config, model_path)
self._logger.info("Saved model at {}".format(epoch))
return model_path
def load_model(self):
self._setup_graph()
model_path = os.path.join(self.ckpt_path, 'epo{}.tar'.format(self._epoch_num))
assert os.path.exists(model_path), 'Weights at epoch %d not found' % self._epoch_num
checkpoint = torch.load(model_path, map_location='cpu')
self.stden_model.load_state_dict(checkpoint['model_state_dict'])
self._logger.info("Loaded model at {}".format(self._epoch_num))
def _setup_graph(self):
with torch.no_grad():
self.stden_model.eval()
val_iterator = self._data['val_loader'].get_iterator()
for _, (x, y) in enumerate(val_iterator):
x, y = self._prepare_data(x, y)
output = self.stden_model(x)
break
def train(self, **kwargs):
self._logger.info('Model mode: train')
kwargs.update(self._train_kwargs)
return self._train(**kwargs)
def _train(self, base_lr,
steps, patience=50, epochs=100, lr_decay_ratio=0.1, log_every=1, save_model=1,
test_every_n_epochs=10, epsilon=1e-8, **kwargs):
# steps is used in learning rate - will see if need to use it?
min_val_loss = float('inf')
wait = 0
optimizer = torch.optim.Adam(self.stden_model.parameters(), lr=base_lr, eps=epsilon)
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=steps,
gamma=lr_decay_ratio)
self._logger.info('Start training ...')
# this will fail if model is loaded with a changed batch_size
num_batches = self._data['train_loader'].num_batch
self._logger.info("num_batches: {}".format(num_batches))
batches_seen = num_batches * self._epoch_num
# used for nfe
c = []
res, keys = [], []
for epoch_num in range(self._epoch_num, epochs):
self.stden_model.train()
train_iterator = self._data['train_loader'].get_iterator()
losses = []
start_time = time.time()
c.clear() #nfe
for i, (x, y) in enumerate(train_iterator):
if(i >= num_batches):
break
optimizer.zero_grad()
x, y = self._prepare_data(x, y)
output, fe = self.stden_model(x, y, batches_seen)
if batches_seen == 0:
# this is a workaround to accommodate dynamically registered parameters
optimizer = torch.optim.Adam(self.stden_model.parameters(), lr=base_lr, eps=epsilon)
loss = self._compute_loss(y, output)
self._logger.debug("FE: number - {}, time - {:.3f} s, err - {:.3f}".format(*fe, loss.item()))
c.append([*fe, loss.item()])
self._logger.debug(loss.item())
losses.append(loss.item())
batches_seen += 1 # global step in tensorboard
loss.backward()
# gradient clipping
torch.nn.utils.clip_grad_norm_(self.stden_model.parameters(), self.max_grad_norm)
optimizer.step()
del x, y, output, loss # drop references so these tensors can be freed
torch.cuda.empty_cache() # release cached GPU memory back to the allocator
# used for nfe
res.append(pd.DataFrame(c, columns=['nfe', 'time', 'err']))
keys.append(epoch_num)
self._logger.info("epoch complete")
lr_scheduler.step()
self._logger.info("evaluating now!")
val_loss, _ = self.evaluate(dataset='val', batches_seen=batches_seen)
end_time = time.time()
self._writer.add_scalar('training loss',
np.mean(losses),
batches_seen)
if (epoch_num % log_every) == log_every - 1:
message = 'Epoch [{}/{}] ({}) train_mae: {:.4f}, val_mae: {:.4f}, lr: {:.6f}, ' \
'{:.1f}s'.format(epoch_num, epochs, batches_seen,
np.mean(losses), val_loss, lr_scheduler.get_lr()[0],
(end_time - start_time))
self._logger.info(message)
if (epoch_num % test_every_n_epochs) == test_every_n_epochs - 1:
test_loss, _ = self.evaluate(dataset='test', batches_seen=batches_seen)
message = 'Epoch [{}/{}] ({}) train_mae: {:.4f}, test_mae: {:.4f}, lr: {:.6f}, ' \
'{:.1f}s'.format(epoch_num, epochs, batches_seen,
np.mean(losses), test_loss, lr_scheduler.get_lr()[0],
(end_time - start_time))
self._logger.info(message)
if val_loss < min_val_loss:
wait = 0
if save_model:
model_file_name = self.save_model(epoch_num)
self._logger.info(
'Val loss decrease from {:.4f} to {:.4f}, '
'saving to {}'.format(min_val_loss, val_loss, model_file_name))
min_val_loss = val_loss
elif val_loss >= min_val_loss:
wait += 1
if wait == patience:
self._logger.warning('Early stopping at epoch: %d' % epoch_num)
break
if bool(self._model_kwargs.get('nfe', False)):
res = pd.concat(res, keys=keys)
# self._logger.info("res.shape: ", res.shape)
res.index.names = ['epoch', 'iter']
filter_type = self._model_kwargs.get('filter_type', 'unknown')
atol = float(self._model_kwargs.get('odeint_atol', 1e-5))
rtol = float(self._model_kwargs.get('odeint_rtol', 1e-5))
nfe_file = os.path.join(
self._data_kwargs.get('dataset_dir', 'data'),
'nfe_{}_a{}_r{}.pkl'.format(filter_type, int(atol*1e5), int(rtol*1e5)))
res.to_pickle(nfe_file)
# res.to_csv(nfe_file)
def _prepare_data(self, x, y):
x, y = self._get_x_y(x, y)
x, y = self._get_x_y_in_correct_dims(x, y)
return x.to(device), y.to(device)
def _get_x_y(self, x, y):
"""
:param x: shape (batch_size, seq_len, num_edges, input_dim)
:param y: shape (batch_size, horizon, num_edges, input_dim)
:returns x shape (seq_len, batch_size, num_edges, input_dim)
y shape (horizon, batch_size, num_edges, input_dim)
"""
x = torch.from_numpy(x).float()
y = torch.from_numpy(y).float()
self._logger.debug("X: {}".format(x.size()))
self._logger.debug("y: {}".format(y.size()))
x = x.permute(1, 0, 2, 3)
y = y.permute(1, 0, 2, 3)
return x, y
def _get_x_y_in_correct_dims(self, x, y):
"""
:param x: shape (seq_len, batch_size, num_edges, input_dim)
:param y: shape (horizon, batch_size, num_edges, input_dim)
:return: x: shape (seq_len, batch_size, num_edges * input_dim)
y: shape (horizon, batch_size, num_edges * output_dim)
"""
batch_size = x.size(1)
self._logger.debug("size of x {}".format(x.size()))
x = x.view(self.seq_len, batch_size, self.num_edges * self.input_dim)
y = y[..., :self.output_dim].view(self.horizon, batch_size,
self.num_edges * self.output_dim)
return x, y
def _compute_loss(self, y_true, y_predicted):
y_true = self.standard_scaler.inverse_transform(y_true)
y_predicted = self.standard_scaler.inverse_transform(y_predicted)
return masked_mae_loss(y_predicted, y_true)
def _compute_loss_eval(self, y_true, y_predicted):
y_true = self.standard_scaler.inverse_transform(y_true)
y_predicted = self.standard_scaler.inverse_transform(y_predicted)
return masked_mae_loss(y_predicted, y_true).item(), masked_mape_loss(y_predicted, y_true).item(), masked_rmse_loss(y_predicted, y_true).item()
def evaluate(self, dataset='val', batches_seen=0, save=False):
"""
Computes the MAE, MAPE and RMSE losses, and collects the predictions when `save` is True.
:return: mean MAE loss and (optionally) a dict with the scaled predictions and ground truth
"""
with torch.no_grad():
self.stden_model.eval()
val_iterator = self._data['{}_loader'.format(dataset)].get_iterator()
mae_losses = []
mape_losses = []
rmse_losses = []
y_dict = None
if(save):
y_truths = []
y_preds = []
for _, (x, y) in enumerate(val_iterator):
x, y = self._prepare_data(x, y)
output, fe = self.stden_model(x)
mae, mape, rmse = self._compute_loss_eval(y, output)
mae_losses.append(mae)
mape_losses.append(mape)
rmse_losses.append(rmse)
if(save):
y_truths.append(y.cpu())
y_preds.append(output.cpu())
mean_loss = {
'mae': np.mean(mae_losses),
'mape': np.mean(mape_losses),
'rmse': np.mean(rmse_losses)
}
self._logger.info('Evaluation: - mae - {:.4f} - mape - {:.4f} - rmse - {:.4f}'.format(mean_loss['mae'], mean_loss['mape'], mean_loss['rmse']))
self._writer.add_scalar('{} loss'.format(dataset), mean_loss['mae'], batches_seen)
if(save):
y_preds = np.concatenate(y_preds, axis=1)
y_truths = np.concatenate(y_truths, axis=1) # concatenate on batch dimension
y_truths_scaled = []
y_preds_scaled = []
# self._logger.debug("y_preds shape: {}, y_truth shape {}".format(y_preds.shape, y_truths.shape))
for t in range(y_preds.shape[0]):
y_truth = self.standard_scaler.inverse_transform(y_truths[t])
y_pred = self.standard_scaler.inverse_transform(y_preds[t])
y_truths_scaled.append(y_truth)
y_preds_scaled.append(y_pred)
y_preds_scaled = np.stack(y_preds_scaled)
y_truths_scaled = np.stack(y_truths_scaled)
y_dict = {'prediction': y_preds_scaled, 'truth': y_truths_scaled}
# save_dir = self._data_kwargs.get('dataset_dir', 'data')
# save_path = os.path.join(save_dir, 'pred.npz')
# np.savez(save_path, prediction=y_preds_scaled, turth=y_truths_scaled)
return mean_loss['mae'], y_dict
def eval_more(self, dataset='val', save=False, seq_len=[3, 6, 9, 12], extract_latent=False):
"""
Computes mae rmse mape loss and the prediction if `save` is set True.
"""
self._logger.info('Model mode: Evaluation')
with torch.no_grad():
self.stden_model.eval()
val_iterator = self._data['{}_loader'.format(dataset)].get_iterator()
mae_losses = []
mape_losses = []
rmse_losses = []
if(save):
y_truths = []
y_preds = []
if(extract_latent):
latents = []
# used for nfe
c = []
for _, (x, y) in enumerate(val_iterator):
x, y = self._prepare_data(x, y)
output, fe = self.stden_model(x)
mae, mape, rmse = [], [], []
for seq in seq_len:
_mae, _mape, _rmse = self._compute_loss_eval(y[seq-1], output[seq-1])
mae.append(_mae)
mape.append(_mape)
rmse.append(_rmse)
mae_losses.append(mae)
mape_losses.append(mape)
rmse_losses.append(rmse)
c.append([*fe, np.mean(mae)])
if(save):
y_truths.append(y.cpu())
y_preds.append(output.cpu())
if(extract_latent):
latents.append(self.stden_model.latent_feat.cpu())
mean_loss = {
'mae': np.mean(mae_losses, axis=0),
'mape': np.mean(mape_losses, axis=0),
'rmse': np.mean(rmse_losses, axis=0)
}
for i, seq in enumerate(seq_len):
self._logger.info('Evaluation seq {}: - mae - {:.4f} - mape - {:.4f} - rmse - {:.4f}'.format(
seq, mean_loss['mae'][i], mean_loss['mape'][i], mean_loss['rmse'][i]))
if(save):
# shape (horizon, num_samples, feat_dim)
y_preds = np.concatenate(y_preds, axis=1)
y_truths = np.concatenate(y_truths, axis=1) # concatenate on batch dimension
y_preds_scaled = self.standard_scaler.inverse_transform(y_preds)
y_truths_scaled = self.standard_scaler.inverse_transform(y_truths)
save_dir = self._data_kwargs.get('dataset_dir', 'data')
save_path = os.path.join(save_dir, 'pred_{}_{}.npz'.format(self.experimentID, self._epoch_num))
np.savez_compressed(save_path, prediction=y_preds_scaled, truth=y_truths_scaled)
if(extract_latent):
# concatenate on batch dimension
latents = np.concatenate(latents, axis=1)
# Shape of latents (horizon, num_samples, self.num_edges * self.output_dim)
save_dir = self._data_kwargs.get('dataset_dir', 'data')
filter_type = self._model_kwargs.get('filter_type', 'unknown')
save_path = os.path.join(save_dir, '{}_latent_{}_{}.npz'.format(filter_type, self.experimentID, self._epoch_num))
np.savez_compressed(save_path, latent=latents)
if bool(self._model_kwargs.get('nfe', False)):
res = pd.DataFrame(c, columns=['nfe', 'time', 'err'])
res.index.name = 'iter'
filter_type = self._model_kwargs.get('filter_type', 'unknown')
atol = float(self._model_kwargs.get('odeint_atol', 1e-5))
rtol = float(self._model_kwargs.get('odeint_rtol', 1e-5))
nfe_file = os.path.join(
self._data_kwargs.get('dataset_dir', 'data'),
'nfe_{}_a{}_r{}.pkl'.format(filter_type, int(atol*1e5), int(rtol*1e5)))
res.to_pickle(nfe_file)

23
requirements.txt Normal file
View File

@ -0,0 +1,23 @@
# STDEN project dependencies
# Deep learning framework
torch>=1.9.0
torchvision>=0.10.0
# ODE solvers (imported by model/diffeq_solver.py)
torchdiffeq
# Scientific computing
numpy>=1.21.0
scipy>=1.7.0
pandas>=1.3.0
# Machine learning utilities
scikit-learn>=1.0.0
# Visualization
matplotlib>=3.4.0
tensorboard>=2.6.0
# Configuration and logging
PyYAML>=5.4.0
# Miscellaneous tools
tqdm>=4.62.0
pathlib2>=2.3.0

105
run.py Normal file
View File

@ -0,0 +1,105 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Main entry script for the STDEN project
Dispatches to the matching configuration, data loader, and trainer based on the model name
"""
import argparse
import yaml
import torch
import numpy as np
import os
import sys
from pathlib import Path
# 添加项目根目录到Python路径
project_root = Path(__file__).parent
sys.path.append(str(project_root))
from lib.logger import setup_logger
from lib.utils import load_graph_data
from trainer.stden_trainer import STDENTrainer
from dataloader.stden_dataloader import STDENDataloader
def load_config(config_path):
"""加载YAML配置文件"""
with open(config_path, 'r', encoding='utf-8') as f:
config = yaml.safe_load(f)
return config
def setup_environment(config):
"""设置环境变量和随机种子"""
# 设置随机种子
random_seed = config.get('random_seed', 2021)
torch.manual_seed(random_seed)
np.random.seed(random_seed)
# 设置设备
device = 'cuda' if torch.cuda.is_available() and not config.get('use_cpu_only', False) else 'cpu'
config['device'] = device
return config
def main():
parser = argparse.ArgumentParser(description='STDEN项目训练和评估')
parser.add_argument('--model_name', type=str, required=True,
choices=['stde_gt', 'stde_wrs', 'stde_zgc'],
help='模型名称,对应配置文件')
parser.add_argument('--mode', type=str, default='train',
choices=['train', 'eval'],
help='运行模式:训练或评估')
parser.add_argument('--config_dir', type=str, default='configs',
help='配置文件目录')
parser.add_argument('--use_cpu_only', action='store_true',
help='仅使用CPU')
parser.add_argument('--save_pred', action='store_true',
help='保存预测结果(仅评估模式)')
args = parser.parse_args()
# 构建配置文件路径
config_path = Path(args.config_dir) / f"{args.model_name}.yaml"
if not config_path.exists():
print(f"错误:配置文件 {config_path} 不存在")
sys.exit(1)
# 加载配置
config = load_config(config_path)
config['use_cpu_only'] = args.use_cpu_only
config['mode'] = args.mode
# 设置环境
config = setup_environment(config)
# 设置日志
logger = setup_logger(config)
logger.info(f"开始运行 {args.model_name} 模型,模式:{args.mode}")
logger.info(f"使用设备:{config['device']}")
# 加载图数据
graph_pkl_filename = config['data']['graph_pkl_filename']
adj_matrix = load_graph_data(graph_pkl_filename)
config['adj_matrix'] = adj_matrix
# 创建数据加载器
dataloader = STDENDataloader(config)
# 创建训练器
trainer = STDENTrainer(config, dataloader)
# 根据模式执行
if args.mode == 'train':
trainer.train()
else: # eval mode
trainer.evaluate(save_predictions=args.save_pred)
logger.info(f"{args.model_name} 模型运行完成")
if __name__ == '__main__':
main()

9
trainer/__init__.py Normal file
View File

@ -0,0 +1,9 @@
# -*- coding: utf-8 -*-
"""
Trainer module
Contains the training and evaluation logic for the STDEN project
"""
from .stden_trainer import STDENTrainer
__all__ = ['STDENTrainer']

383
trainer/stden_trainer.py Normal file
View File

@ -0,0 +1,383 @@
# -*- coding: utf-8 -*-
"""
STDEN trainer
Handles model training, validation, and evaluation
"""
import os
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import logging
from pathlib import Path
import json
from model.stden_model import STDENModel
from lib.metrics import masked_mae_loss, masked_mape_loss, masked_rmse_loss
class STDENTrainer:
"""STDEN训练器主类"""
def __init__(self, config, dataloader):
"""
初始化训练器
Args:
config: 配置字典
dataloader: 数据加载器实例
"""
self.config = config
self.dataloader = dataloader
self.logger = logging.getLogger('STDEN')
# 设置设备
self.device = torch.device(config['device'])
# 模型配置
self.model_config = config['model']
self.train_config = config['train']
# 创建模型
self.model = self._create_model()
# 创建优化器和学习率调度器
self.optimizer, self.scheduler = self._create_optimizer()
# 设置损失函数
self.criterion = masked_mae_loss
# 创建检查点目录
self.checkpoint_dir = self._create_checkpoint_dir()
# 创建TensorBoard写入器
self.writer = self._create_tensorboard_writer()
# 训练状态
self.current_epoch = 0
self.best_val_loss = float('inf')
self.patience_counter = 0
self.logger.info("训练器初始化完成")
def _create_model(self):
"""创建STDEN模型"""
# 获取邻接矩阵
adj_matrix = self.config['adj_matrix']
# 创建模型
model = STDENModel(
adj_matrix=adj_matrix,
logger=self.logger,
**self.model_config
)
# 移动到指定设备
model = model.to(self.device)
self.logger.info(f"模型创建完成,参数数量: {sum(p.numel() for p in model.parameters())}")
return model
def _create_optimizer(self):
"""创建优化器和学习率调度器"""
# 获取训练配置
base_lr = self.train_config['base_lr']
optimizer_name = self.train_config.get('optimizer', 'adam').lower()
# 创建优化器
if optimizer_name == 'adam':
optimizer = optim.Adam(self.model.parameters(), lr=base_lr)
elif optimizer_name == 'sgd':
optimizer = optim.SGD(self.model.parameters(), lr=base_lr)
else:
raise ValueError(f"不支持的优化器: {optimizer_name}")
# 创建学习率调度器
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
optimizer,
mode='min',
factor=self.train_config.get('lr_decay_ratio', 0.1),
patience=self.train_config.get('patience', 20),
min_lr=self.train_config.get('min_learning_rate', 1e-6),
verbose=True
)
return optimizer, scheduler
def _create_checkpoint_dir(self):
"""创建检查点目录"""
checkpoint_dir = Path("checkpoints") / f"experiment_{int(time.time())}"
checkpoint_dir.mkdir(parents=True, exist_ok=True)
# 保存配置
config_path = checkpoint_dir / "config.json"
with open(config_path, 'w', encoding='utf-8') as f:
json.dump({k: v for k, v in self.config.items() if k != 'adj_matrix'}, f, indent=2, ensure_ascii=False)  # the numpy adjacency matrix is not JSON-serializable
return checkpoint_dir
def _create_tensorboard_writer(self):
"""创建TensorBoard写入器"""
log_dir = Path("runs") / f"experiment_{int(time.time())}"
return SummaryWriter(str(log_dir))
def train(self):
"""训练模型"""
self.logger.info("开始训练模型")
# 获取训练配置
epochs = self.train_config['epochs']
patience = self.train_config['patience']
test_every_n_epochs = self.train_config.get('test_every_n_epochs', 5)
# 获取数据加载器
train_loader = self.dataloader.train_loader
val_loader = self.dataloader.val_loader
for epoch in range(epochs):
self.current_epoch = epoch
# 训练一个epoch
train_loss = self._train_epoch(train_loader)
# 验证
val_loss = self._validate_epoch(val_loader)
# 更新学习率
self.scheduler.step(val_loss)
# 记录到TensorBoard
self.writer.add_scalar('Loss/Train', train_loss, epoch)
self.writer.add_scalar('Loss/Validation', val_loss, epoch)
self.writer.add_scalar('Learning_Rate', self.optimizer.param_groups[0]['lr'], epoch)
# 打印训练信息
self.logger.info(
f"Epoch {epoch+1}/{epochs} - "
f"Train Loss: {train_loss:.6f}, "
f"Val Loss: {val_loss:.6f}, "
f"LR: {self.optimizer.param_groups[0]['lr']:.6f}"
)
# 保存最佳模型
if val_loss < self.best_val_loss:
self.best_val_loss = val_loss
self.patience_counter = 0
self._save_checkpoint(epoch, is_best=True)
self.logger.info(f"新的最佳验证损失: {val_loss:.6f}")
else:
self.patience_counter += 1
# 定期保存检查点
if (epoch + 1) % 10 == 0:
self._save_checkpoint(epoch, is_best=False)
# 定期测试
if (epoch + 1) % test_every_n_epochs == 0:
test_metrics = self._test_epoch()
self.logger.info(f"测试指标: {test_metrics}")
# 早停检查
if self.patience_counter >= patience:
self.logger.info(f"早停触发在epoch {epoch+1}")
break
# 训练完成
self.logger.info("训练完成")
self.writer.close()
# 最终测试
final_test_metrics = self._test_epoch()
self.logger.info(f"最终测试指标: {final_test_metrics}")
def _train_epoch(self, train_loader):
"""训练一个epoch"""
self.model.train()
total_loss = 0.0
num_batches = 0
for batch_idx, (x, y) in enumerate(train_loader):
# 移动数据到设备
x = x.to(self.device)
y = y.to(self.device)
# 前向传播
self.optimizer.zero_grad()
output, _ = self.model(x)  # STDENModel.forward returns (outputs, nfe_info)
# 计算损失
loss = self.criterion(output, y)
# 反向传播
loss.backward()
# 梯度裁剪
max_grad_norm = self.train_config.get('max_grad_norm', 5.0)
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_grad_norm)
# 更新参数
self.optimizer.step()
total_loss += loss.item()
num_batches += 1
return total_loss / num_batches
def _validate_epoch(self, val_loader):
"""验证一个epoch"""
self.model.eval()
total_loss = 0.0
num_batches = 0
with torch.no_grad():
for x, y in val_loader:
x = x.to(self.device)
y = y.to(self.device)
output, _ = self.model(x)
loss = self.criterion(output, y)
total_loss += loss.item()
num_batches += 1
return total_loss / num_batches
def _test_epoch(self):
"""测试模型"""
self.model.eval()
test_loader = self.dataloader.test_loader
total_mae = 0.0
total_mape = 0.0
total_rmse = 0.0
num_batches = 0
with torch.no_grad():
for x, y in test_loader:
x = x.to(self.device)
y = y.to(self.device)
output, _ = self.model(x)
# 计算各种指标
mae = masked_mae_loss(output, y)
mape = masked_mape_loss(output, y)
rmse = masked_rmse_loss(output, y)
total_mae += mae.item()
total_mape += mape.item()
total_rmse += rmse.item()
num_batches += 1
metrics = {
'MAE': total_mae / num_batches,
'MAPE': total_mape / num_batches,
'RMSE': total_rmse / num_batches
}
return metrics
def _save_checkpoint(self, epoch, is_best=False):
"""保存检查点"""
checkpoint = {
'epoch': epoch,
'model_state_dict': self.model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'scheduler_state_dict': self.scheduler.state_dict(),
'best_val_loss': self.best_val_loss,
'config': self.config
}
# 保存最新检查点
latest_path = self.checkpoint_dir / "latest.pth"
torch.save(checkpoint, latest_path)
# 保存最佳检查点
if is_best:
best_path = self.checkpoint_dir / "best.pth"
torch.save(checkpoint, best_path)
# 保存特定epoch的检查点
epoch_path = self.checkpoint_dir / f"epoch_{epoch+1}.pth"
torch.save(checkpoint, epoch_path)
self.logger.info(f"检查点已保存: {epoch_path}")
def load_checkpoint(self, checkpoint_path):
"""加载检查点"""
checkpoint = torch.load(checkpoint_path, map_location=self.device)
self.model.load_state_dict(checkpoint['model_state_dict'])
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
self.current_epoch = checkpoint['epoch']
self.best_val_loss = checkpoint['best_val_loss']
self.logger.info(f"检查点已加载: {checkpoint_path}")
def evaluate(self, save_predictions=False):
"""评估模型"""
self.logger.info("开始评估模型")
# 加载最佳模型
best_checkpoint_path = self.checkpoint_dir / "best.pth"
if best_checkpoint_path.exists():
self.load_checkpoint(best_checkpoint_path)
else:
self.logger.warning("未找到最佳检查点,使用当前模型")
# 测试
test_metrics = self._test_epoch()
# 打印结果
self.logger.info("评估结果:")
for metric_name, metric_value in test_metrics.items():
self.logger.info(f" {metric_name}: {metric_value:.6f}")
# 保存预测结果
if save_predictions:
self._save_predictions()
return test_metrics
def _save_predictions(self):
"""保存预测结果"""
self.model.eval()
test_loader = self.dataloader.test_loader
predictions = []
targets = []
with torch.no_grad():
for x, y in test_loader:
x = x.to(self.device)
output, _ = self.model(x)
# 反标准化
scaler = self.dataloader.get_scaler()
output_denorm = scaler.inverse_transform(output.cpu().numpy())
y_denorm = scaler.inverse_transform(y.numpy())
predictions.append(output_denorm)
targets.append(y_denorm)
# 合并所有批次
predictions = np.concatenate(predictions, axis=0)
targets = np.concatenate(targets, axis=0)
# 保存到文件
results_dir = self.checkpoint_dir / "results"
results_dir.mkdir(exist_ok=True)
np.save(results_dir / "predictions.npy", predictions)
np.save(results_dir / "targets.npy", targets)
self.logger.info(f"预测结果已保存到: {results_dir}")
def __del__(self):
"""析构函数,确保资源被正确释放"""
if hasattr(self, 'writer'):
self.writer.close()