Productionize STDEN into the current project

commit ee955e9481 (parent 568fff7e99)
@@ -15,7 +15,6 @@ dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/

@@ -160,3 +159,5 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

STDEN/

README.md (136 changed lines)
@@ -1,3 +1,135 @@
-# Project-I
-Secret Projct
# STDEN Project

Spatio-Temporal Diffusion Equation Network (STDEN): a project for spatio-temporal sequence forecasting.

## Project Structure

```
Project-I/
├── run.py                  # Main entry point
├── configs/                # Configuration files
│   ├── stde_gt.yaml        # STDE_GT model config
│   ├── stde_wrs.yaml       # STDE_WRS model config
│   └── stde_zgc.yaml       # STDE_ZGC model config
├── dataloader/             # Data loading module
│   ├── __init__.py
│   └── stden_dataloader.py
├── trainer/                # Trainer module
│   ├── __init__.py
│   └── stden_trainer.py
├── model/                  # Model module
│   ├── __init__.py
│   ├── stden_model.py
│   ├── stden_supervisor.py
│   ├── diffeq_solver.py
│   └── ode_func.py
├── lib/                    # Utility library
│   ├── __init__.py
│   ├── logger.py
│   ├── utils.py
│   └── metrics.py
├── requirements.txt        # Project dependencies
└── README.md               # Project documentation
```
## Quick Start

### 1. Install dependencies

```bash
pip install -r requirements.txt
```

### 2. Train a model

```bash
# Train the STDE_GT model
python run.py --model_name stde_gt --mode train

# Train the STDE_WRS model
python run.py --model_name stde_wrs --mode train

# Train the STDE_ZGC model
python run.py --model_name stde_zgc --mode train
```

### 3. Evaluate a model

```bash
# Evaluate the STDE_GT model
python run.py --model_name stde_gt --mode eval --save_pred

# Evaluate the STDE_WRS model
python run.py --model_name stde_wrs --mode eval --save_pred

# Evaluate the STDE_ZGC model
python run.py --model_name stde_zgc --mode eval --save_pred
```
## Configuration

The project uses YAML configuration files with three main sections. A minimal loading sketch is shown after the lists below.

### Data configuration (data)
- `dataset_dir`: dataset directory path
- `batch_size`: training batch size
- `val_batch_size`: validation batch size
- `graph_pkl_filename`: sensor-graph adjacency matrix file

### Model configuration (model)
- `seq_len`: input sequence length
- `horizon`: number of prediction steps
- `input_dim`: input feature dimension
- `output_dim`: output feature dimension
- `latent_dim`: latent space dimension
- `n_traj_samples`: number of sampled trajectories
- `ode_method`: ODE solver method
- `rnn_units`: number of RNN hidden units
- `gcn_step`: number of graph-convolution steps

### Training configuration (train)
- `base_lr`: base learning rate
- `epochs`: total number of training epochs
- `patience`: early-stopping patience
- `optimizer`: optimizer type
- `max_grad_norm`: maximum gradient norm (for gradient clipping)
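As an illustration (a minimal sketch assuming only PyYAML and the config paths listed in the project structure; this is not the actual run.py logic), one of these files can be loaded like this:

```python
import yaml

# Load one of the configs listed above and read a few keys.
with open("configs/stde_gt.yaml", "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

print(config["model_name"])           # "stde_gt"
print(config["data"]["batch_size"])   # training batch size
print(config["model"]["seq_len"])     # input sequence length
print(config["train"]["base_lr"])     # base learning rate
```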
## Key Features

1. **Modular design**: clean separation between data loader, trainer, and model
2. **Configuration-driven**: YAML config files make parameter tuning easy
3. **Unified interface**: every model is invoked through run.py
4. **Complete logging**: log output to both file and console
5. **TensorBoard support**: visualization of the training process
6. **Checkpoint management**: automatic saving and loading of the best model

## Supported Models

- **STDE_GT**: for the Beijing GM sensor-graph data
- **STDE_WRS**: for the WRS sensor-graph data
- **STDE_ZGC**: for the ZGC sensor-graph data

## Data Format

Two data formats are supported; a short sketch of the expected array keys follows this list:

1. **BJ format**: a single flow.npz file, used for the Beijing datasets
2. **Standard format**: separate train.npz, val.npz, and test.npz files
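The key names below follow the loading code added in lib/utils.py in this commit; the array shapes are illustrative assumptions only:

```python
import numpy as np

# Standard format: one .npz per split with keys 'x' and 'y'.
# Illustrative shape: (num_samples, seq_len, num_edges, input_dim).
x = np.random.rand(100, 12, 50, 1).astype(np.float32)
y = np.random.rand(100, 12, 50, 1).astype(np.float32)
np.savez("train.npz", x=x, y=y)            # the project expects these under dataset_dir

split = np.load("train.npz")
print(split["x"].shape, split["y"].shape)

# BJ format: a single flow.npz that already holds every split,
# with keys x_train, y_train, x_val, y_val, x_test, y_test.
```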
## Logs and Outputs

- Training logs are saved to the `logs/` directory
- Model checkpoints are saved to the `checkpoints/` directory
- TensorBoard logs are saved to the `runs/` directory
- Prediction results are saved to the `results/` subdirectory of the checkpoint directory

## Notes

1. Make sure the dataset directory structure is correct
2. Make sure the sensor-graph file path is configured correctly
3. Adjust the batch size to your hardware
4. The necessary directories are created automatically during training

## License

This project is released under the MIT License; see the LICENSE file for details.
@@ -0,0 +1,5 @@
# -*- coding: utf-8 -*-
"""
Configuration package.
Contains the various configurations for the STDEN project.
"""
@ -0,0 +1,85 @@
|
|||
# STDE_GT模型配置文件
|
||||
# 用于北京GM传感器图数据的时空扩散方程网络模型
|
||||
|
||||
# 基础配置
|
||||
model_name: "stde_gt"
|
||||
random_seed: 2021
|
||||
log_level: "INFO"
|
||||
log_base_dir: "logs/BJ_GM"
|
||||
|
||||
# 数据配置
|
||||
data:
|
||||
# 数据集目录
|
||||
dataset_dir: "data/BJ_GM"
|
||||
# 批处理大小
|
||||
batch_size: 32
|
||||
# 验证集批处理大小
|
||||
val_batch_size: 32
|
||||
# 传感器图邻接矩阵文件
|
||||
graph_pkl_filename: "data/sensor_graph/adj_GM.npy"
|
||||
|
||||
# 模型配置
|
||||
model:
|
||||
# 输入序列长度
|
||||
seq_len: 12
|
||||
# 预测时间步数
|
||||
horizon: 12
|
||||
# 输入特征维度
|
||||
input_dim: 1
|
||||
# 输出特征维度
|
||||
output_dim: 1
|
||||
# 潜在空间维度
|
||||
latent_dim: 4
|
||||
# 轨迹采样数量
|
||||
n_traj_samples: 3
|
||||
# ODE求解方法
|
||||
ode_method: "dopri5"
|
||||
# ODE求解器绝对误差容差
|
||||
odeint_atol: 0.00001
|
||||
# ODE求解器相对误差容差
|
||||
odeint_rtol: 0.00001
|
||||
# RNN隐藏单元数量
|
||||
rnn_units: 64
|
||||
# RNN层数
|
||||
num_rnn_layers: 1
|
||||
# 图卷积步数
|
||||
gcn_step: 2
|
||||
# 滤波器类型 (default/unkP/IncP)
|
||||
filter_type: "default"
|
||||
# 循环神经网络类型
|
||||
recg_type: "gru"
|
||||
# 是否保存潜在表示
|
||||
save_latent: false
|
||||
# 是否记录函数评估次数
|
||||
nfe: false
|
||||
# L1正则化衰减
|
||||
l1_decay: 0
|
||||
|
||||
# 训练配置
|
||||
train:
|
||||
# 基础学习率
|
||||
base_lr: 0.01
|
||||
# Dropout比率
|
||||
dropout: 0
|
||||
# 加载的检查点epoch
|
||||
load: 0
|
||||
# 当前训练epoch
|
||||
epoch: 0
|
||||
# 总训练epoch数
|
||||
epochs: 100
|
||||
# 收敛阈值
|
||||
epsilon: 1.0e-3
|
||||
# 学习率衰减比率
|
||||
lr_decay_ratio: 0.1
|
||||
# 最大梯度范数
|
||||
max_grad_norm: 5
|
||||
# 最小学习率
|
||||
min_learning_rate: 2.0e-06
|
||||
# 优化器类型
|
||||
optimizer: "adam"
|
||||
# 早停耐心值
|
||||
patience: 20
|
||||
# 学习率衰减步数
|
||||
steps: [20, 30, 40, 50]
|
||||
# 测试频率(每N个epoch测试一次)
|
||||
test_every_n_epochs: 5
|
||||
|
|
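For reference, the `steps` and `lr_decay_ratio` entries above feed a MultiStepLR schedule in the supervisor code later in this commit. A minimal sketch with a dummy model (the printed values follow from base_lr = 0.01 and the milestones in this config):

```python
import torch

model = torch.nn.Linear(1, 1)                      # dummy model for illustration
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=[20, 30, 40, 50], gamma=0.1)

for epoch in range(60):
    optimizer.step()        # placeholder for one training epoch
    scheduler.step()
    if epoch + 1 in (20, 30, 40, 50):
        print(epoch + 1, scheduler.get_last_lr())  # 0.001, 1e-4, 1e-5, 1e-6
```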
@ -0,0 +1,85 @@
|
|||
# STDE_WRS模型配置文件
|
||||
# 用于WRS传感器图数据的时空扩散方程网络模型
|
||||
|
||||
# 基础配置
|
||||
model_name: "stde_wrs"
|
||||
random_seed: 2021
|
||||
log_level: "INFO"
|
||||
log_base_dir: "logs/WRS"
|
||||
|
||||
# 数据配置
|
||||
data:
|
||||
# 数据集目录
|
||||
dataset_dir: "data/WRS"
|
||||
# 批处理大小
|
||||
batch_size: 32
|
||||
# 验证集批处理大小
|
||||
val_batch_size: 32
|
||||
# 传感器图邻接矩阵文件
|
||||
graph_pkl_filename: "data/sensor_graph/adj_WRS.npy"
|
||||
|
||||
# 模型配置
|
||||
model:
|
||||
# 输入序列长度
|
||||
seq_len: 12
|
||||
# 预测时间步数
|
||||
horizon: 12
|
||||
# 输入特征维度
|
||||
input_dim: 1
|
||||
# 输出特征维度
|
||||
output_dim: 1
|
||||
# 潜在空间维度
|
||||
latent_dim: 4
|
||||
# 轨迹采样数量
|
||||
n_traj_samples: 3
|
||||
# ODE求解方法
|
||||
ode_method: "dopri5"
|
||||
# ODE求解器绝对误差容差
|
||||
odeint_atol: 0.00001
|
||||
# ODE求解器相对误差容差
|
||||
odeint_rtol: 0.00001
|
||||
# RNN隐藏单元数量
|
||||
rnn_units: 64
|
||||
# RNN层数
|
||||
num_rnn_layers: 1
|
||||
# 图卷积步数
|
||||
gcn_step: 2
|
||||
# 滤波器类型 (default/unkP/IncP)
|
||||
filter_type: "default"
|
||||
# 循环神经网络类型
|
||||
recg_type: "gru"
|
||||
# 是否保存潜在表示
|
||||
save_latent: false
|
||||
# 是否记录函数评估次数
|
||||
nfe: false
|
||||
# L1正则化衰减
|
||||
l1_decay: 0
|
||||
|
||||
# 训练配置
|
||||
train:
|
||||
# 基础学习率
|
||||
base_lr: 0.01
|
||||
# Dropout比率
|
||||
dropout: 0
|
||||
# 加载的检查点epoch
|
||||
load: 0
|
||||
# 当前训练epoch
|
||||
epoch: 0
|
||||
# 总训练epoch数
|
||||
epochs: 100
|
||||
# 收敛阈值
|
||||
epsilon: 1.0e-3
|
||||
# 学习率衰减比率
|
||||
lr_decay_ratio: 0.1
|
||||
# 最大梯度范数
|
||||
max_grad_norm: 5
|
||||
# 最小学习率
|
||||
min_learning_rate: 2.0e-06
|
||||
# 优化器类型
|
||||
optimizer: "adam"
|
||||
# 早停耐心值
|
||||
patience: 20
|
||||
# 学习率衰减步数
|
||||
steps: [20, 30, 40, 50]
|
||||
# 测试频率(每N个epoch测试一次)
|
||||
test_every_n_epochs: 5
|
||||
|
|
@ -0,0 +1,85 @@
|
|||
# STDE_ZGC模型配置文件
|
||||
# 用于ZGC传感器图数据的时空扩散方程网络模型
|
||||
|
||||
# 基础配置
|
||||
model_name: "stde_zgc"
|
||||
random_seed: 2021
|
||||
log_level: "INFO"
|
||||
log_base_dir: "logs/ZGC"
|
||||
|
||||
# 数据配置
|
||||
data:
|
||||
# 数据集目录
|
||||
dataset_dir: "data/ZGC"
|
||||
# 批处理大小
|
||||
batch_size: 32
|
||||
# 验证集批处理大小
|
||||
val_batch_size: 32
|
||||
# 传感器图邻接矩阵文件
|
||||
graph_pkl_filename: "data/sensor_graph/adj_ZGC.npy"
|
||||
|
||||
# 模型配置
|
||||
model:
|
||||
# 输入序列长度
|
||||
seq_len: 12
|
||||
# 预测时间步数
|
||||
horizon: 12
|
||||
# 输入特征维度
|
||||
input_dim: 1
|
||||
# 输出特征维度
|
||||
output_dim: 1
|
||||
# 潜在空间维度
|
||||
latent_dim: 4
|
||||
# 轨迹采样数量
|
||||
n_traj_samples: 3
|
||||
# ODE求解方法
|
||||
ode_method: "dopri5"
|
||||
# ODE求解器绝对误差容差
|
||||
odeint_atol: 0.00001
|
||||
# ODE求解器相对误差容差
|
||||
odeint_rtol: 0.00001
|
||||
# RNN隐藏单元数量
|
||||
rnn_units: 64
|
||||
# RNN层数
|
||||
num_rnn_layers: 1
|
||||
# 图卷积步数
|
||||
gcn_step: 2
|
||||
# 滤波器类型 (default/unkP/IncP)
|
||||
filter_type: "default"
|
||||
# 循环神经网络类型
|
||||
recg_type: "gru"
|
||||
# 是否保存潜在表示
|
||||
save_latent: false
|
||||
# 是否记录函数评估次数
|
||||
nfe: false
|
||||
# L1正则化衰减
|
||||
l1_decay: 0
|
||||
|
||||
# 训练配置
|
||||
train:
|
||||
# 基础学习率
|
||||
base_lr: 0.01
|
||||
# Dropout比率
|
||||
dropout: 0
|
||||
# 加载的检查点epoch
|
||||
load: 0
|
||||
# 当前训练epoch
|
||||
epoch: 0
|
||||
# 总训练epoch数
|
||||
epochs: 100
|
||||
# 收敛阈值
|
||||
epsilon: 1.0e-3
|
||||
# 学习率衰减比率
|
||||
lr_decay_ratio: 0.1
|
||||
# 最大梯度范数
|
||||
max_grad_norm: 5
|
||||
# 最小学习率
|
||||
min_learning_rate: 2.0e-06
|
||||
# 优化器类型
|
||||
optimizer: "adam"
|
||||
# 早停耐心值
|
||||
patience: 20
|
||||
# 学习率衰减步数
|
||||
steps: [20, 30, 40, 50]
|
||||
# 测试频率(每N个epoch测试一次)
|
||||
test_every_n_epochs: 5
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
数据加载器模块
|
||||
包含STDEN项目的数据加载和处理逻辑
|
||||
"""
|
||||
|
||||
from .stden_dataloader import STDENDataloader
|
||||
|
||||
__all__ = ['STDENDataloader']
|
||||
|
|
@ -0,0 +1,215 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
STDEN数据加载器
|
||||
负责数据的加载、预处理和批处理
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from pathlib import Path
|
||||
import logging
|
||||
|
||||
|
||||
class STDENDataset(Dataset):
|
||||
"""STDEN数据集类"""
|
||||
|
||||
def __init__(self, x_data, y_data, sequence_length, horizon):
|
||||
"""
|
||||
初始化数据集
|
||||
|
||||
Args:
|
||||
x_data: 输入数据
|
||||
y_data: 标签数据
|
||||
sequence_length: 输入序列长度
|
||||
horizon: 预测时间步数
|
||||
"""
|
||||
self.x_data = x_data
|
||||
self.y_data = y_data
|
||||
self.sequence_length = sequence_length
|
||||
self.horizon = horizon
|
||||
self.num_samples = len(x_data) - sequence_length - horizon + 1
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
def __getitem__(self, idx):
|
||||
# 获取输入序列
|
||||
x_seq = self.x_data[idx:idx + self.sequence_length]
|
||||
# 获取目标序列
|
||||
y_seq = self.y_data[idx + self.sequence_length:idx + self.sequence_length + self.horizon]
|
||||
|
||||
return torch.FloatTensor(x_seq), torch.FloatTensor(y_seq)
|
||||
|
||||
|
||||
class StandardScaler:
|
||||
"""数据标准化器"""
|
||||
|
||||
def __init__(self, mean=None, std=None):
|
||||
self.mean = mean
|
||||
self.std = std
|
||||
|
||||
def fit(self, data):
|
||||
"""拟合数据,计算均值和标准差"""
|
||||
self.mean = np.mean(data, axis=0)
|
||||
self.std = np.std(data, axis=0)
|
||||
# 避免除零
|
||||
self.std[self.std == 0] = 1.0
|
||||
|
||||
def transform(self, data):
|
||||
"""标准化数据"""
|
||||
return (data - self.mean) / self.std
|
||||
|
||||
def inverse_transform(self, data):
|
||||
"""反标准化数据"""
|
||||
return (data * self.std) + self.mean
|
||||
|
||||
|
||||
class STDENDataloader:
|
||||
"""STDEN数据加载器主类"""
|
||||
|
||||
def __init__(self, config):
|
||||
"""
|
||||
初始化数据加载器
|
||||
|
||||
Args:
|
||||
config: 配置字典
|
||||
"""
|
||||
self.config = config
|
||||
self.logger = logging.getLogger('STDEN')
|
||||
|
||||
# 数据配置
|
||||
self.dataset_dir = config['data']['dataset_dir']
|
||||
self.batch_size = config['data']['batch_size']
|
||||
self.val_batch_size = config['data'].get('val_batch_size', self.batch_size)
|
||||
self.sequence_length = config['model']['seq_len']
|
||||
self.horizon = config['model']['horizon']
|
||||
|
||||
# 加载数据
|
||||
self.data = self._load_dataset()
|
||||
self.scaler = self.data['scaler']
|
||||
|
||||
# 创建数据加载器
|
||||
self.train_loader = self._create_dataloader(
|
||||
self.data['x_train'], self.data['y_train'],
|
||||
self.batch_size, shuffle=True
|
||||
)
|
||||
self.val_loader = self._create_dataloader(
|
||||
self.data['x_val'], self.data['y_val'],
|
||||
self.val_batch_size, shuffle=False
|
||||
)
|
||||
self.test_loader = self._create_dataloader(
|
||||
self.data['x_test'], self.data['y_test'],
|
||||
self.val_batch_size, shuffle=False
|
||||
)
|
||||
|
||||
self.logger.info(f"数据加载完成 - 训练集: {len(self.data['x_train'])} 样本")
|
||||
self.logger.info(f"验证集: {len(self.data['x_val'])} 样本, 测试集: {len(self.data['x_test'])} 样本")
|
||||
|
||||
def _load_dataset(self):
|
||||
"""加载数据集"""
|
||||
dataset_path = Path(self.dataset_dir)
|
||||
|
||||
if not dataset_path.exists():
|
||||
raise FileNotFoundError(f"数据集目录不存在: {self.dataset_dir}")
|
||||
|
||||
# 检查数据集类型(BJ_GM或其他)
|
||||
if 'BJ' in self.dataset_dir:
|
||||
return self._load_bj_dataset(dataset_path)
|
||||
else:
|
||||
return self._load_standard_dataset(dataset_path)
|
||||
|
||||
def _load_bj_dataset(self, dataset_path):
|
||||
"""加载BJ_GM数据集"""
|
||||
flow_file = dataset_path / 'flow.npz'
|
||||
if not flow_file.exists():
|
||||
raise FileNotFoundError(f"BJ数据集文件不存在: {flow_file}")
|
||||
|
||||
data = dict(np.load(flow_file))
|
||||
|
||||
# 提取训练、验证、测试数据
|
||||
x_train = data['x_train']
|
||||
y_train = data['y_train']
|
||||
x_val = data['x_val']
|
||||
y_val = data['y_val']
|
||||
x_test = data['x_test']
|
||||
y_test = data['y_test']
|
||||
|
||||
return self._process_data(x_train, y_train, x_val, y_val, x_test, y_test)
|
||||
|
||||
def _load_standard_dataset(self, dataset_path):
|
||||
"""加载标准格式数据集"""
|
||||
data = {}
|
||||
|
||||
for category in ['train', 'val', 'test']:
|
||||
cat_file = dataset_path / f"{category}.npz"
|
||||
if not cat_file.exists():
|
||||
raise FileNotFoundError(f"数据集文件不存在: {cat_file}")
|
||||
|
||||
cat_data = np.load(cat_file)
|
||||
data[f'x_{category}'] = cat_data['x']
|
||||
data[f'y_{category}'] = cat_data['y']
|
||||
|
||||
return self._process_data(
|
||||
data['x_train'], data['y_train'],
|
||||
data['x_val'], data['y_val'],
|
||||
data['x_test'], data['y_test']
|
||||
)
|
||||
|
||||
def _process_data(self, x_train, y_train, x_val, y_val, x_test, y_test):
|
||||
"""处理数据(标准化等)"""
|
||||
# 创建标准化器
|
||||
scaler = StandardScaler()
|
||||
scaler.fit(x_train)
|
||||
|
||||
# 标准化所有数据
|
||||
x_train_scaled = scaler.transform(x_train)
|
||||
y_train_scaled = scaler.transform(y_train)
|
||||
x_val_scaled = scaler.transform(x_val)
|
||||
y_val_scaled = scaler.transform(y_val)
|
||||
x_test_scaled = scaler.transform(x_test)
|
||||
y_test_scaled = scaler.transform(y_test)
|
||||
|
||||
return {
|
||||
'x_train': x_train_scaled,
|
||||
'y_train': y_train_scaled,
|
||||
'x_val': x_val_scaled,
|
||||
'y_val': y_val_scaled,
|
||||
'x_test': x_test_scaled,
|
||||
'y_test': y_test_scaled,
|
||||
'scaler': scaler
|
||||
}
|
||||
|
||||
def _create_dataloader(self, x_data, y_data, batch_size, shuffle=False):
|
||||
"""创建PyTorch数据加载器"""
|
||||
dataset = STDENDataset(x_data, y_data, self.sequence_length, self.horizon)
|
||||
|
||||
return DataLoader(
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=shuffle,
|
||||
num_workers=0, # 避免多进程问题
|
||||
drop_last=False
|
||||
)
|
||||
|
||||
def get_data_loaders(self):
|
||||
"""获取所有数据加载器"""
|
||||
return {
|
||||
'train': self.train_loader,
|
||||
'val': self.val_loader,
|
||||
'test': self.test_loader
|
||||
}
|
||||
|
||||
def get_scaler(self):
|
||||
"""获取标准化器"""
|
||||
return self.scaler
|
||||
|
||||
def get_data_info(self):
|
||||
"""获取数据信息"""
|
||||
return {
|
||||
'input_dim': self.data['x_train'].shape[-1],
|
||||
'output_dim': self.data['y_train'].shape[-1],
|
||||
'sequence_length': self.sequence_length,
|
||||
'horizon': self.horizon,
|
||||
'num_nodes': self.data['x_train'].shape[1] if len(self.data['x_train'].shape) > 1 else 1
|
||||
}
|
||||
|
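A minimal usage sketch for the loader defined above (it assumes the dataset files referenced by configs/stde_gt.yaml actually exist on disk):

```python
import yaml
from dataloader.stden_dataloader import STDENDataloader

with open("configs/stde_gt.yaml", "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

loader = STDENDataloader(config)
train_loader = loader.get_data_loaders()["train"]
for x_batch, y_batch in train_loader:
    # x_batch: (batch, seq_len, ...), y_batch: (batch, horizon, ...)
    print(x_batch.shape, y_batch.shape)
    break

scaler = loader.get_scaler()       # StandardScaler fitted on x_train
print(loader.get_data_info())      # input_dim, output_dim, sequence_length, horizon, num_nodes
```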
|
@ -0,0 +1,5 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
示例模块
|
||||
包含STDEN项目的使用示例
|
||||
"""
|
||||
|
|
@ -0,0 +1,115 @@
|
|||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
STDEN项目训练示例
|
||||
演示如何使用新的项目结构进行模型训练
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# 添加项目根目录到Python路径
|
||||
project_root = Path(__file__).parent.parent
|
||||
sys.path.append(str(project_root))
|
||||
|
||||
from lib.logger import setup_logger
|
||||
from dataloader.stden_dataloader import STDENDataloader
|
||||
from trainer.stden_trainer import STDENTrainer
|
||||
import yaml
|
||||
|
||||
|
||||
def main():
|
||||
"""主函数示例"""
|
||||
|
||||
# 配置字典(也可以从YAML文件加载)
|
||||
config = {
|
||||
'model_name': 'stde_gt',
|
||||
'log_level': 'INFO',
|
||||
'log_base_dir': 'logs/example',
|
||||
'device': 'cpu', # 或 'cuda'
|
||||
|
||||
'data': {
|
||||
'dataset_dir': 'data/BJ_GM',
|
||||
'batch_size': 16,
|
||||
'val_batch_size': 16,
|
||||
'graph_pkl_filename': 'data/sensor_graph/adj_GM.npy'
|
||||
},
|
||||
|
||||
'model': {
|
||||
'seq_len': 12,
|
||||
'horizon': 12,
|
||||
'input_dim': 1,
|
||||
'output_dim': 1,
|
||||
'latent_dim': 4,
|
||||
'n_traj_samples': 3,
|
||||
'ode_method': 'dopri5',
|
||||
'odeint_atol': 0.00001,
|
||||
'odeint_rtol': 0.00001,
|
||||
'rnn_units': 64,
|
||||
'num_rnn_layers': 1,
|
||||
'gcn_step': 2,
|
||||
'filter_type': 'default',
|
||||
'recg_type': 'gru',
|
||||
'save_latent': False,
|
||||
'nfe': False,
|
||||
'l1_decay': 0
|
||||
},
|
||||
|
||||
'train': {
|
||||
'base_lr': 0.01,
|
||||
'dropout': 0,
|
||||
'load': 0,
|
||||
'epoch': 0,
|
||||
'epochs': 50, # 减少训练轮数用于示例
|
||||
'epsilon': 1.0e-3,
|
||||
'lr_decay_ratio': 0.1,
|
||||
'max_grad_norm': 5,
|
||||
'min_learning_rate': 2.0e-06,
|
||||
'optimizer': 'adam',
|
||||
'patience': 10,
|
||||
'steps': [10, 20, 30],
|
||||
'test_every_n_epochs': 5
|
||||
}
|
||||
}
|
||||
|
||||
try:
|
||||
# 设置日志
|
||||
logger = setup_logger(config)
|
||||
logger.info("开始STDEN项目训练示例")
|
||||
|
||||
# 注意:这里需要实际的邻接矩阵数据
|
||||
# 为了示例,我们创建一个虚拟的邻接矩阵
|
||||
import numpy as np
|
||||
config['adj_matrix'] = np.random.rand(10, 10) # 10x10的随机邻接矩阵
|
||||
|
||||
# 创建数据加载器
|
||||
logger.info("创建数据加载器...")
|
||||
try:
|
||||
dataloader = STDENDataloader(config)
|
||||
logger.info("数据加载器创建成功")
|
||||
except FileNotFoundError as e:
|
||||
logger.warning(f"数据加载器创建失败(预期行为,因为示例中没有实际数据): {e}")
|
||||
logger.info("继续演示项目结构...")
|
||||
return
|
||||
|
||||
# 创建训练器
|
||||
logger.info("创建训练器...")
|
||||
trainer = STDENTrainer(config, dataloader)
|
||||
|
||||
# 开始训练
|
||||
logger.info("开始训练...")
|
||||
trainer.train()
|
||||
|
||||
# 评估模型
|
||||
logger.info("开始评估...")
|
||||
metrics = trainer.evaluate(save_predictions=True)
|
||||
|
||||
logger.info("训练示例完成!")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"训练过程中发生错误: {e}")
|
||||
raise
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
|
@ -0,0 +1,83 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
日志配置模块
|
||||
提供统一的日志格式和配置
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
def setup_logger(config):
|
||||
"""
|
||||
设置日志记录器
|
||||
|
||||
Args:
|
||||
config: 配置字典,包含日志相关配置
|
||||
|
||||
Returns:
|
||||
logger: 配置好的日志记录器
|
||||
"""
|
||||
# 获取日志配置
|
||||
log_level = config.get('log_level', 'INFO')
|
||||
log_base_dir = config.get('log_base_dir', 'logs')
|
||||
|
||||
# 创建日志目录
|
||||
log_dir = Path(log_base_dir)
|
||||
log_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 创建日志文件名(包含时间戳)
|
||||
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
||||
model_name = config.get('model_name', 'stden')
|
||||
log_filename = f"{model_name}_{timestamp}.log"
|
||||
log_path = log_dir / log_filename
|
||||
|
||||
# 创建日志记录器
|
||||
logger = logging.getLogger('STDEN')
|
||||
logger.setLevel(getattr(logging, log_level.upper()))
|
||||
|
||||
# 清除现有的处理器
|
||||
logger.handlers.clear()
|
||||
|
||||
# 创建文件处理器
|
||||
file_handler = logging.FileHandler(log_path, encoding='utf-8')
|
||||
file_handler.setLevel(getattr(logging, log_level.upper()))
|
||||
|
||||
# 创建控制台处理器
|
||||
console_handler = logging.StreamHandler()
|
||||
console_handler.setLevel(getattr(logging, log_level.upper()))
|
||||
|
||||
# 创建格式化器
|
||||
formatter = logging.Formatter(
|
||||
'%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||
datefmt='%Y-%m-%d %H:%M:%S'
|
||||
)
|
||||
|
||||
# 设置格式化器
|
||||
file_handler.setFormatter(formatter)
|
||||
console_handler.setFormatter(formatter)
|
||||
|
||||
# 添加处理器
|
||||
logger.addHandler(file_handler)
|
||||
logger.addHandler(console_handler)
|
||||
|
||||
# 记录日志配置信息
|
||||
logger.info(f"日志文件路径: {log_path}")
|
||||
logger.info(f"日志级别: {log_level}")
|
||||
|
||||
return logger
|
||||
|
||||
|
||||
def get_logger(name='STDEN'):
|
||||
"""
|
||||
获取已配置的日志记录器
|
||||
|
||||
Args:
|
||||
name: 日志记录器名称
|
||||
|
||||
Returns:
|
||||
logger: 日志记录器
|
||||
"""
|
||||
return logging.getLogger(name)
|
||||
|
|
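A short usage sketch for these helpers (assuming, per the project structure in the README, that this file is lib/logger.py):

```python
from lib.logger import setup_logger, get_logger

logger = setup_logger({
    "model_name": "stde_gt",
    "log_level": "INFO",
    "log_base_dir": "logs/example",
})
logger.info("configured via setup_logger")
get_logger().info("same 'STDEN' logger retrieved elsewhere in the code")
```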
@@ -0,0 +1,29 @@
import torch

def masked_mae_loss(y_pred, y_true):
    y_true[y_true < 1e-4] = 0
    mask = (y_true != 0).float()
    mask /= mask.mean()  # re-weight so the result stays an unbiased mean over non-zero targets
    loss = torch.abs(y_pred - y_true)
    loss = loss * mask
    # trick for nans: https://discuss.pytorch.org/t/how-to-set-nan-in-tensor-to-0/3918/3
    loss[loss != loss] = 0
    return loss.mean()

def masked_mape_loss(y_pred, y_true):
    y_true[y_true < 1e-4] = 0
    mask = (y_true != 0).float()
    mask /= mask.mean()
    loss = torch.abs((y_pred - y_true) / y_true)
    loss = loss * mask
    loss[loss != loss] = 0
    return loss.mean()

def masked_rmse_loss(y_pred, y_true):
    y_true[y_true < 1e-4] = 0
    mask = (y_true != 0).float()
    mask /= mask.mean()
    loss = torch.pow(y_pred - y_true, 2)
    loss = loss * mask
    loss[loss != loss] = 0
    return torch.sqrt(loss.mean())
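A tiny illustration of the masking behaviour of the losses above (values chosen purely for illustration; the import path follows the README project structure): targets that are near zero are excluded, and the remaining entries are re-weighted so the result is still a mean over the valid entries.

```python
import torch
from lib.metrics import masked_mae_loss  # module path per the README project structure

y_true = torch.tensor([0.0, 2.0, 4.0])   # the zero entry is treated as missing
y_pred = torch.tensor([1.0, 2.5, 3.0])

# Only the last two entries contribute: (|2.5 - 2| + |3 - 4|) / 2 = 0.75
print(masked_mae_loss(y_pred, y_true))   # tensor(0.7500)
```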
@ -0,0 +1,228 @@
|
|||
import logging
|
||||
import numpy as np
|
||||
import os
|
||||
import time
|
||||
import scipy.sparse as sp
|
||||
import sys
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
class DataLoader(object):
|
||||
def __init__(self, xs, ys, batch_size, pad_with_last_sample=True, shuffle=False):
|
||||
"""
|
||||
|
||||
:param xs:
|
||||
:param ys:
|
||||
:param batch_size:
|
||||
:param pad_with_last_sample: pad with the last sample to make number of samples divisible to batch_size.
|
||||
"""
|
||||
self.batch_size = batch_size
|
||||
self.current_ind = 0
|
||||
if pad_with_last_sample:
|
||||
num_padding = (batch_size - (len(xs) % batch_size)) % batch_size
|
||||
x_padding = np.repeat(xs[-1:], num_padding, axis=0)
|
||||
y_padding = np.repeat(ys[-1:], num_padding, axis=0)
|
||||
xs = np.concatenate([xs, x_padding], axis=0)
|
||||
ys = np.concatenate([ys, y_padding], axis=0)
|
||||
self.size = len(xs)
|
||||
self.num_batch = int(self.size // self.batch_size)
|
||||
if shuffle:
|
||||
permutation = np.random.permutation(self.size)
|
||||
xs, ys = xs[permutation], ys[permutation]
|
||||
self.xs = xs
|
||||
self.ys = ys
|
||||
|
||||
def get_iterator(self):
|
||||
self.current_ind = 0
|
||||
|
||||
def _wrapper():
|
||||
while self.current_ind < self.num_batch:
|
||||
start_ind = self.batch_size * self.current_ind
|
||||
end_ind = min(self.size, self.batch_size * (self.current_ind + 1))
|
||||
x_i = self.xs[start_ind: end_ind, ...]
|
||||
y_i = self.ys[start_ind: end_ind, ...]
|
||||
yield (x_i, y_i)
|
||||
self.current_ind += 1
|
||||
|
||||
return _wrapper()
|
||||
|
||||
|
||||
class StandardScaler:
|
||||
"""
|
||||
Standard the input
|
||||
"""
|
||||
|
||||
def __init__(self, mean, std):
|
||||
self.mean = mean
|
||||
self.std = std
|
||||
|
||||
def transform(self, data):
|
||||
return (data - self.mean) / self.std
|
||||
|
||||
def inverse_transform(self, data):
|
||||
return (data * self.std) + self.mean
|
||||
|
||||
|
||||
def calculate_random_walk_matrix(adj_mx):
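    # Random-walk normalization D^{-1} A: each row of the adjacency matrix is divided
    # by that node's out-degree; rows with zero degree keep d_inv = 0 below.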
|
||||
adj_mx = sp.coo_matrix(adj_mx)
|
||||
d = np.array(adj_mx.sum(1))
|
||||
d_inv = np.power(d, -1).flatten()
|
||||
d_inv[np.isinf(d_inv)] = 0.
|
||||
d_mat_inv = sp.diags(d_inv)
|
||||
random_walk_mx = d_mat_inv.dot(adj_mx).tocoo()
|
||||
return random_walk_mx
|
||||
|
||||
def config_logging(log_dir, log_filename='info.log', level=logging.INFO):
|
||||
# Add file handler and stdout handler
|
||||
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
# Create the log directory if necessary.
|
||||
try:
|
||||
os.makedirs(log_dir)
|
||||
except OSError:
|
||||
pass
|
||||
file_handler = logging.FileHandler(os.path.join(log_dir, log_filename))
|
||||
file_handler.setFormatter(formatter)
|
||||
file_handler.setLevel(level=level)
|
||||
# Add console handler.
|
||||
console_formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
|
||||
console_handler = logging.StreamHandler(sys.stdout)
|
||||
console_handler.setFormatter(console_formatter)
|
||||
console_handler.setLevel(level=level)
|
||||
logging.basicConfig(handlers=[file_handler, console_handler], level=level)
|
||||
|
||||
|
||||
def get_logger(log_dir, name, log_filename='info.log', level=logging.INFO):
|
||||
logger = logging.getLogger(name)
|
||||
logger.setLevel(level)
|
||||
# Add file handler and stdout handler
|
||||
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
file_handler = logging.FileHandler(os.path.join(log_dir, log_filename))
|
||||
file_handler.setFormatter(formatter)
|
||||
# Add console handler.
|
||||
console_formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
|
||||
console_handler = logging.StreamHandler(sys.stdout)
|
||||
console_handler.setFormatter(console_formatter)
|
||||
logger.addHandler(file_handler)
|
||||
logger.addHandler(console_handler)
|
||||
# Add google cloud log handler
|
||||
logger.info('Log directory: %s', log_dir)
|
||||
return logger
|
||||
|
||||
|
||||
def get_log_dir(kwargs):
|
||||
log_dir = kwargs['train'].get('log_dir')
|
||||
if log_dir is None:
|
||||
batch_size = kwargs['data'].get('batch_size')
|
||||
|
||||
filter_type = kwargs['model'].get('filter_type')
|
||||
gcn_step = kwargs['model'].get('gcn_step')
|
||||
horizon = kwargs['model'].get('horizon')
|
||||
latent_dim = kwargs['model'].get('latent_dim')
|
||||
n_traj_samples = kwargs['model'].get('n_traj_samples')
|
||||
ode_method = kwargs['model'].get('ode_method')
|
||||
|
||||
seq_len = kwargs['model'].get('seq_len')
|
||||
rnn_units = kwargs['model'].get('rnn_units')
|
||||
recg_type = kwargs['model'].get('recg_type')
|
||||
|
||||
if filter_type == 'unkP':
|
||||
filter_type_abbr = 'UP'
|
||||
elif filter_type == 'IncP':
|
||||
filter_type_abbr = 'NV'
|
||||
else:
|
||||
filter_type_abbr = 'DF'
|
||||
|
||||
|
||||
run_id = 'STDEN_%s-%d_%s-%d_L-%d_N-%d_M-%s_bs-%d_%d-%d_%s/' % (
|
||||
recg_type, rnn_units, filter_type_abbr, gcn_step, latent_dim, n_traj_samples, ode_method, batch_size, seq_len, horizon, time.strftime('%m%d%H%M%S'))
|
||||
base_dir = kwargs.get('log_base_dir')
|
||||
log_dir = os.path.join(base_dir, run_id)
|
||||
if not os.path.exists(log_dir):
|
||||
os.makedirs(log_dir)
|
||||
return log_dir
|
||||
|
||||
|
||||
def load_dataset(dataset_dir, batch_size, val_batch_size=None, **kwargs):
|
||||
if('BJ' in dataset_dir):
|
||||
data = dict(np.load(os.path.join(dataset_dir, 'flow.npz'))) # convert readonly NpzFile to writable dict Object
|
||||
for category in ['train', 'val', 'test']:
|
||||
data['x_' + category] = data['x_' + category] #[..., :4] # ignore the time index
|
||||
else:
|
||||
data = {}
|
||||
for category in ['train', 'val', 'test']:
|
||||
cat_data = np.load(os.path.join(dataset_dir, category + '.npz'))
|
||||
data['x_' + category] = cat_data['x']
|
||||
data['y_' + category] = cat_data['y']
|
||||
scaler = StandardScaler(mean=data['x_train'].mean(), std=data['x_train'].std())
|
||||
# Data format
|
||||
for category in ['train', 'val', 'test']:
|
||||
data['x_' + category] = scaler.transform(data['x_' + category])
|
||||
data['y_' + category] = scaler.transform(data['y_' + category])
|
||||
data['train_loader'] = DataLoader(data['x_train'], data['y_train'], batch_size, shuffle=True)
|
||||
data['val_loader'] = DataLoader(data['x_val'], data['y_val'], val_batch_size, shuffle=False)
|
||||
data['test_loader'] = DataLoader(data['x_test'], data['y_test'], val_batch_size, shuffle=False)
|
||||
data['scaler'] = scaler
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def load_graph_data(pkl_filename):
|
||||
adj_mx = np.load(pkl_filename)
|
||||
return adj_mx
|
||||
|
||||
def graph_grad(adj_mx):
|
||||
"""Fetch the graph gradient operator."""
|
||||
num_nodes = adj_mx.shape[0]
|
||||
|
||||
num_edges = (adj_mx > 0.).sum()
|
||||
grad = torch.zeros(num_nodes, num_edges)
|
||||
e = 0
|
||||
for i in range(num_nodes):
|
||||
for j in range(num_nodes):
|
||||
if adj_mx[i, j] == 0:
|
||||
continue
|
||||
|
||||
grad[i, e] = 1.
|
||||
grad[j, e] = -1.
|
||||
e += 1
|
||||
return grad
|
||||
|
||||
def init_network_weights(net, std = 0.1):
|
||||
"""
|
||||
Just for nn.Linear net.
|
||||
"""
|
||||
for m in net.modules():
|
||||
if isinstance(m, nn.Linear):
|
||||
nn.init.normal_(m.weight, mean=0, std=std)
|
||||
nn.init.constant_(m.bias, val=0)
|
||||
|
||||
def split_last_dim(data):
|
||||
last_dim = data.size()[-1]
|
||||
last_dim = last_dim//2
|
||||
|
||||
res = data[..., :last_dim], data[..., last_dim:]
|
||||
return res
|
||||
|
||||
def get_device(tensor):
|
||||
device = torch.device("cpu")
|
||||
if tensor.is_cuda:
|
||||
device = tensor.get_device()
|
||||
return device
|
||||
|
||||
def sample_standard_gaussian(mu, sigma):
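    # Reparameterization-style sampling: draw eps ~ N(0, 1) and return mu + sigma * eps.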
|
||||
device = get_device(mu)
|
||||
|
||||
d = torch.distributions.normal.Normal(torch.Tensor([0.]).to(device), torch.Tensor([1.]).to(device))
|
||||
r = d.sample(mu.size()).squeeze(-1)
|
||||
return r * sigma.float() + mu.float()
|
||||
|
||||
def create_net(n_inputs, n_outputs, n_layers = 0,
|
||||
n_units = 100, nonlinear = nn.Tanh):
|
||||
layers = [nn.Linear(n_inputs, n_units)]
|
||||
for i in range(n_layers):
|
||||
layers.append(nonlinear())
|
||||
layers.append(nn.Linear(n_units, n_units))
|
||||
|
||||
layers.append(nonlinear())
|
||||
layers.append(nn.Linear(n_units, n_outputs))
|
||||
return nn.Sequential(*layers)
|
||||
|
|
@ -0,0 +1,49 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import time
|
||||
|
||||
from torchdiffeq import odeint
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
class DiffeqSolver(nn.Module):
|
||||
def __init__(self, odefunc, method, latent_dim,
|
||||
odeint_rtol = 1e-4, odeint_atol = 1e-5):
|
||||
nn.Module.__init__(self)
|
||||
|
||||
self.ode_method = method
|
||||
self.odefunc = odefunc
|
||||
self.latent_dim = latent_dim
|
||||
|
||||
self.rtol = odeint_rtol
|
||||
self.atol = odeint_atol
|
||||
|
||||
def forward(self, first_point, time_steps_to_pred):
|
||||
"""
|
||||
Decoder the trajectory through the ODE Solver.
|
||||
|
||||
:param time_steps_to_pred: horizon
|
||||
:param first_point: (n_traj_samples, batch_size, num_nodes * latent_dim)
|
||||
:return: pred_y: # shape (horizon, n_traj_samples, batch_size, self.num_nodes * self.output_dim)
|
||||
"""
|
||||
n_traj_samples, batch_size = first_point.size()[0], first_point.size()[1]
|
||||
first_point = first_point.reshape(n_traj_samples * batch_size, -1) # reduce the complexity by merging dimension
|
||||
|
||||
# pred_y shape: (horizon, n_traj_samples * batch_size, num_nodes * latent_dim)
|
||||
start_time = time.time()
|
||||
self.odefunc.nfe = 0
|
||||
pred_y = odeint(self.odefunc,
|
||||
first_point,
|
||||
time_steps_to_pred,
|
||||
rtol=self.rtol,
|
||||
atol=self.atol,
|
||||
method=self.ode_method)
|
||||
time_fe = time.time() - start_time
|
||||
|
||||
# pred_y shape: (horizon, n_traj_samples, batch_size, num_nodes * latent_dim)
|
||||
pred_y = pred_y.reshape(pred_y.size()[0], n_traj_samples, batch_size, -1)
|
||||
# assert(pred_y.size()[1] == n_traj_samples)
|
||||
# assert(pred_y.size()[2] == batch_size)
|
||||
|
||||
return pred_y, (self.odefunc.nfe, time_fe)
|
||||
|
||||
|
|
@ -0,0 +1,165 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from lib import utils
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
class LayerParams:
|
||||
def __init__(self, rnn_network: nn.Module, layer_type: str):
|
||||
self._rnn_network = rnn_network
|
||||
self._params_dict = {}
|
||||
self._biases_dict = {}
|
||||
self._type = layer_type
|
||||
|
||||
def get_weights(self, shape):
|
||||
if shape not in self._params_dict:
|
||||
nn_param = nn.Parameter(torch.empty(*shape, device=device))
|
||||
nn.init.xavier_normal_(nn_param)
|
||||
self._params_dict[shape] = nn_param
|
||||
self._rnn_network.register_parameter('{}_weight_{}'.format(self._type, str(shape)),
|
||||
nn_param)
|
||||
return self._params_dict[shape]
|
||||
|
||||
def get_biases(self, length, bias_start=0.0):
|
||||
if length not in self._biases_dict:
|
||||
biases = nn.Parameter(torch.empty(length, device=device))
|
||||
nn.init.constant_(biases, bias_start)
|
||||
self._biases_dict[length] = biases
|
||||
self._rnn_network.register_parameter('{}_biases_{}'.format(self._type, str(length)),
|
||||
biases)
|
||||
|
||||
return self._biases_dict[length]
|
||||
|
||||
class ODEFunc(nn.Module):
|
||||
def __init__(self, num_units, latent_dim, adj_mx, gcn_step, num_nodes,
|
||||
gen_layers=1, nonlinearity='tanh', filter_type="default"):
|
||||
"""
|
||||
:param num_units: dimensionality of the hidden layers
|
||||
:param latent_dim: dimensionality used for ODE (input and output). Analog of a continous latent state
|
||||
:param adj_mx:
|
||||
:param gcn_step:
|
||||
:param num_nodes:
|
||||
:param gen_layers: hidden layers in each ode func.
|
||||
:param nonlinearity:
|
||||
:param filter_type: default
|
||||
:param use_gc_for_ru: whether to use Graph convolution to calculate the reset and update gates.
|
||||
"""
|
||||
super(ODEFunc, self).__init__()
|
||||
self._activation = torch.tanh if nonlinearity == 'tanh' else torch.relu
|
||||
|
||||
self._num_nodes = num_nodes
|
||||
self._num_units = num_units # hidden dimension
|
||||
self._latent_dim = latent_dim
|
||||
self._gen_layers = gen_layers
|
||||
self.nfe = 0
|
||||
|
||||
self._filter_type = filter_type
|
||||
if(self._filter_type == "unkP"):
|
||||
ode_func_net = utils.create_net(latent_dim, latent_dim, n_units=num_units)
|
||||
utils.init_network_weights(ode_func_net)
|
||||
self.gradient_net = ode_func_net
|
||||
else:
|
||||
self._gcn_step = gcn_step
|
||||
self._gconv_params = LayerParams(self, 'gconv')
|
||||
self._supports = []
|
||||
supports = []
|
||||
supports.append(utils.calculate_random_walk_matrix(adj_mx).T)
|
||||
supports.append(utils.calculate_random_walk_matrix(adj_mx.T).T)
|
||||
|
||||
for support in supports:
|
||||
self._supports.append(self._build_sparse_matrix(support))
|
||||
|
||||
@staticmethod
|
||||
def _build_sparse_matrix(L):
|
||||
L = L.tocoo()
|
||||
indices = np.column_stack((L.row, L.col))
|
||||
# this is to ensure row-major ordering to equal torch.sparse.sparse_reorder(L)
|
||||
indices = indices[np.lexsort((indices[:, 0], indices[:, 1]))]
|
||||
L = torch.sparse_coo_tensor(indices.T, L.data, L.shape, device=device)
|
||||
return L
|
||||
|
||||
def forward(self, t_local, y, backwards = False):
|
||||
"""
|
||||
Perform one step in solving ODE. Given current data point y and current time point t_local, returns gradient dy/dt at this time point
|
||||
|
||||
t_local: current time point
|
||||
y: value at the current time point, shape (B, num_nodes * latent_dim)
|
||||
|
||||
:return
|
||||
- Output: A `2-D` tensor with shape `(B, num_nodes * latent_dim)`.
|
||||
"""
|
||||
self.nfe += 1
|
||||
grad = self.get_ode_gradient_nn(t_local, y)
|
||||
if backwards:
|
||||
grad = -grad
|
||||
return grad
|
||||
|
||||
def get_ode_gradient_nn(self, t_local, inputs):
|
||||
if(self._filter_type == "unkP"):
|
||||
grad = self._fc(inputs)
|
||||
elif (self._filter_type == "IncP"):
|
||||
grad = - self.ode_func_net(inputs)
|
||||
else: # default is diffusion process
|
||||
# theta shape: (B, num_nodes * latent_dim)
|
||||
theta = torch.sigmoid(self._gconv(inputs, self._latent_dim, bias_start=1.0))
|
||||
grad = - theta * self.ode_func_net(inputs)
|
||||
return grad
|
||||
|
||||
def ode_func_net(self, inputs):
|
||||
c = inputs
|
||||
for i in range(self._gen_layers):
|
||||
c = self._gconv(c, self._num_units)
|
||||
c = self._activation(c)
|
||||
c = self._gconv(c, self._latent_dim)
|
||||
c = self._activation(c)
|
||||
return c
|
||||
|
||||
def _fc(self, inputs):
|
||||
batch_size = inputs.size()[0]
|
||||
grad = self.gradient_net(inputs.view(batch_size * self._num_nodes, self._latent_dim))
|
||||
return grad.reshape(batch_size, self._num_nodes * self._latent_dim) # (batch_size, num_nodes, latent_dim)
|
||||
|
||||
@staticmethod
|
||||
def _concat(x, x_):
|
||||
x_ = x_.unsqueeze(0)
|
||||
return torch.cat([x, x_], dim=0)
|
||||
|
||||
def _gconv(self, inputs, output_size, bias_start=0.0):
|
||||
# Reshape input and state to (batch_size, num_nodes, input_dim/state_dim)
|
||||
batch_size = inputs.shape[0]
|
||||
inputs = torch.reshape(inputs, (batch_size, self._num_nodes, -1))
|
||||
# state = torch.reshape(state, (batch_size, self._num_nodes, -1))
|
||||
# inputs_and_state = torch.cat([inputs, state], dim=2)
|
||||
input_size = inputs.size(2)
|
||||
|
||||
x = inputs
|
||||
x0 = x.permute(1, 2, 0) # (num_nodes, total_arg_size, batch_size)
|
||||
x0 = torch.reshape(x0, shape=[self._num_nodes, input_size * batch_size])
|
||||
x = torch.unsqueeze(x0, 0)
|
||||
|
||||
if self._gcn_step == 0:
|
||||
pass
|
||||
else:
|
||||
for support in self._supports:
|
||||
x1 = torch.sparse.mm(support, x0)
|
||||
x = self._concat(x, x1)
|
||||
|
||||
for k in range(2, self._gcn_step + 1):
|
||||
x2 = 2 * torch.sparse.mm(support, x1) - x0
|
||||
x = self._concat(x, x2)
|
||||
x1, x0 = x2, x1
|
||||
|
||||
num_matrices = len(self._supports) * self._gcn_step + 1 # Adds for x itself.
|
||||
x = torch.reshape(x, shape=[num_matrices, self._num_nodes, input_size, batch_size])
|
||||
x = x.permute(3, 1, 2, 0) # (batch_size, num_nodes, input_size, order)
|
||||
x = torch.reshape(x, shape=[batch_size * self._num_nodes, input_size * num_matrices])
|
||||
|
||||
weights = self._gconv_params.get_weights((input_size * num_matrices, output_size))
|
||||
x = torch.matmul(x, weights) # (batch_size * self._num_nodes, output_size)
|
||||
|
||||
biases = self._gconv_params.get_biases(output_size, bias_start)
|
||||
x += biases
|
||||
# Reshape res back to 2D: (batch_size, num_node, state_dim) -> (batch_size, num_node * state_dim)
|
||||
return torch.reshape(x, [batch_size, self._num_nodes * output_size])
|
||||
|
|
@ -0,0 +1,206 @@
|
|||
import time
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from torch.nn.modules.rnn import GRU
|
||||
from model.ode_func import ODEFunc
|
||||
from model.diffeq_solver import DiffeqSolver
|
||||
|
||||
from lib import utils
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
def count_parameters(model):
|
||||
return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
|
||||
class EncoderAttrs:
|
||||
def __init__(self, adj_mx, **model_kwargs):
|
||||
self.adj_mx = adj_mx
|
||||
self.num_nodes = adj_mx.shape[0]
|
||||
self.num_edges = (adj_mx > 0.).sum()
|
||||
self.gcn_step = int(model_kwargs.get('gcn_step', 2))
|
||||
self.filter_type = model_kwargs.get('filter_type', 'default')
|
||||
self.num_rnn_layers = int(model_kwargs.get('num_rnn_layers', 1))
|
||||
self.rnn_units = int(model_kwargs.get('rnn_units'))
|
||||
self.latent_dim = int(model_kwargs.get('latent_dim', 4))
|
||||
|
||||
class STDENModel(nn.Module, EncoderAttrs):
|
||||
def __init__(self, adj_mx, logger, **model_kwargs):
|
||||
nn.Module.__init__(self)
|
||||
EncoderAttrs.__init__(self, adj_mx, **model_kwargs)
|
||||
self._logger = logger
|
||||
####################################################
|
||||
# recognition net
|
||||
####################################################
|
||||
self.encoder_z0 = Encoder_z0_RNN(adj_mx, **model_kwargs)
|
||||
|
||||
####################################################
|
||||
# ode solver
|
||||
####################################################
|
||||
self.n_traj_samples = int(model_kwargs.get('n_traj_samples', 1))
|
||||
self.ode_method = model_kwargs.get('ode_method', 'dopri5')
|
||||
self.atol = float(model_kwargs.get('odeint_atol', 1e-4))
|
||||
self.rtol = float(model_kwargs.get('odeint_rtol', 1e-3))
|
||||
self.num_gen_layer = int(model_kwargs.get('gen_layers', 1))
|
||||
self.ode_gen_dim = int(model_kwargs.get('gen_dim', 64))
|
||||
ode_set_str = "ODE setting --latent {} --samples {} --method {} \
|
||||
--atol {:6f} --rtol {:6f} --gen_layer {} --gen_dim {}".format(\
|
||||
self.latent_dim, self.n_traj_samples, self.ode_method, \
|
||||
self.atol, self.rtol, self.num_gen_layer, self.ode_gen_dim)
|
||||
odefunc = ODEFunc(self.ode_gen_dim, # hidden dimension
|
||||
self.latent_dim,
|
||||
adj_mx,
|
||||
self.gcn_step,
|
||||
self.num_nodes,
|
||||
filter_type=self.filter_type
|
||||
).to(device)
|
||||
self.diffeq_solver = DiffeqSolver(odefunc,
|
||||
self.ode_method,
|
||||
self.latent_dim,
|
||||
odeint_rtol=self.rtol,
|
||||
odeint_atol=self.atol
|
||||
)
|
||||
self._logger.info(ode_set_str)
|
||||
|
||||
self.save_latent = bool(model_kwargs.get('save_latent', False))
|
||||
self.latent_feat = None # used to extract the latent feature
|
||||
|
||||
####################################################
|
||||
# decoder
|
||||
####################################################
|
||||
self.horizon = int(model_kwargs.get('horizon', 1))
|
||||
self.out_feat = int(model_kwargs.get('output_dim', 1))
|
||||
self.decoder = Decoder(
|
||||
self.out_feat,
|
||||
adj_mx,
|
||||
self.num_nodes,
|
||||
self.num_edges,
|
||||
).to(device)
|
||||
|
||||
##########################################
|
||||
def forward(self, inputs, labels=None, batches_seen=None):
|
||||
"""
|
||||
seq2seq forward pass
|
||||
:param inputs: shape (seq_len, batch_size, num_edges * input_dim)
|
||||
:param labels: shape (horizon, batch_size, num_edges * output_dim)
|
||||
:param batches_seen: batches seen till now
|
||||
:return: outputs: (self.horizon, batch_size, self.num_edges * self.output_dim)
|
||||
"""
|
||||
perf_time = time.time()
|
||||
# shape: [1, batch, num_nodes * latent_dim]
|
||||
first_point_mu, first_point_std = self.encoder_z0(inputs)
|
||||
self._logger.debug("Recognition complete with {:.1f}s".format(time.time() - perf_time))
|
||||
|
||||
# sample 'n_traj_samples' trajectory
|
||||
perf_time = time.time()
|
||||
means_z0 = first_point_mu.repeat(self.n_traj_samples, 1, 1)
|
||||
sigma_z0 = first_point_std.repeat(self.n_traj_samples, 1, 1)
|
||||
first_point_enc = utils.sample_standard_gaussian(means_z0, sigma_z0)
|
||||
|
||||
time_steps_to_predict = torch.arange(start=0, end=self.horizon, step=1).float().to(device)
|
||||
time_steps_to_predict = time_steps_to_predict / len(time_steps_to_predict)
|
||||
|
||||
# Shape of sol_ys (horizon, n_traj_samples, batch_size, self.num_nodes * self.latent_dim)
|
||||
sol_ys, fe = self.diffeq_solver(first_point_enc, time_steps_to_predict)
|
||||
self._logger.debug("ODE solver complete with {:.1f}s".format(time.time() - perf_time))
|
||||
if(self.save_latent):
|
||||
# Shape of latent_feat (horizon, batch_size, self.num_nodes * self.latent_dim)
|
||||
self.latent_feat = torch.mean(sol_ys.detach(), axis=1)
|
||||
|
||||
perf_time = time.time()
|
||||
outputs = self.decoder(sol_ys)
|
||||
self._logger.debug("Decoder complete with {:.1f}s".format(time.time() - perf_time))
|
||||
|
||||
if batches_seen == 0:
|
||||
self._logger.info(
|
||||
"Total trainable parameters {}".format(count_parameters(self))
|
||||
)
|
||||
return outputs, fe
|
||||
|
||||
class Encoder_z0_RNN(nn.Module, EncoderAttrs):
|
||||
def __init__(self, adj_mx, **model_kwargs):
|
||||
nn.Module.__init__(self)
|
||||
EncoderAttrs.__init__(self, adj_mx, **model_kwargs)
|
||||
self.recg_type = model_kwargs.get('recg_type', 'gru') # gru
|
||||
|
||||
if(self.recg_type == 'gru'):
|
||||
# gru settings
|
||||
self.input_dim = int(model_kwargs.get('input_dim', 1))
|
||||
self.gru_rnn = GRU(self.input_dim, self.rnn_units).to(device)
|
||||
else:
|
||||
raise NotImplementedError("The recognition net only support 'gru'.")
|
||||
|
||||
# hidden to z0 settings
|
||||
self.inv_grad = utils.graph_grad(adj_mx).transpose(-2, -1)
|
||||
self.inv_grad[self.inv_grad != 0.] = 0.5
|
||||
self.hiddens_to_z0 = nn.Sequential(
|
||||
nn.Linear(self.rnn_units, 50),
|
||||
nn.Tanh(),
|
||||
nn.Linear(50, self.latent_dim * 2),)
|
||||
|
||||
utils.init_network_weights(self.hiddens_to_z0)
|
||||
|
||||
def forward(self, inputs):
|
||||
"""
|
||||
encoder forward pass on t time steps
|
||||
:param inputs: shape (seq_len, batch_size, num_edges * input_dim)
|
||||
:return: mean, std: # shape (n_samples=1, batch_size, self.latent_dim)
|
||||
"""
|
||||
if(self.recg_type == 'gru'):
|
||||
# shape of outputs: (seq_len, batch, num_senor * rnn_units)
|
||||
seq_len, batch_size = inputs.size(0), inputs.size(1)
|
||||
inputs = inputs.reshape(seq_len, batch_size, self.num_edges, self.input_dim)
|
||||
inputs = inputs.reshape(seq_len, batch_size * self.num_edges, self.input_dim)
|
||||
|
||||
outputs, _ = self.gru_rnn(inputs)
|
||||
last_output = outputs[-1]
|
||||
# (batch_size, num_edges, rnn_units)
|
||||
last_output = torch.reshape(last_output, (batch_size, self.num_edges, -1))
|
||||
last_output = torch.transpose(last_output, -2, -1)  # (batch_size, rnn_units, num_edges)
|
||||
# (batch_size, num_nodes, rnn_units)
|
||||
last_output = torch.matmul(last_output, self.inv_grad).transpose(-2, -1)
|
||||
else:
|
||||
raise NotImplementedError("The recognition net only support 'gru'.")
|
||||
|
||||
mean, std = utils.split_last_dim(self.hiddens_to_z0(last_output))
|
||||
mean = mean.reshape(batch_size, -1) # (batch_size, num_nodes * latent_dim)
|
||||
std = std.reshape(batch_size, -1) # (batch_size, num_nodes * latent_dim)
|
||||
std = std.abs()
|
||||
|
||||
assert(not torch.isnan(mean).any())
|
||||
assert(not torch.isnan(std).any())
|
||||
|
||||
return mean.unsqueeze(0), std.unsqueeze(0) # for n_sample traj
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(self, output_dim, adj_mx, num_nodes, num_edges):
|
||||
super(Decoder, self).__init__()
|
||||
|
||||
self.num_nodes = num_nodes
|
||||
self.num_edges = num_edges
|
||||
self.grap_grad = utils.graph_grad(adj_mx)
|
||||
|
||||
self.output_dim = output_dim
|
||||
|
||||
def forward(self, inputs):
|
||||
"""
|
||||
:param inputs: (horizon, n_traj_samples, batch_size, num_nodes * latent_dim)
|
||||
:return outputs: (horizon, batch_size, num_edges * output_dim), average result of n_traj_samples.
|
||||
"""
|
||||
assert(len(inputs.size()) == 4)
|
||||
horizon, n_traj_samples, batch_size = inputs.size()[:3]
|
||||
|
||||
inputs = inputs.reshape(horizon, n_traj_samples, batch_size, self.num_nodes, -1).transpose(-2, -1)
|
||||
latent_dim = inputs.size(-2)
|
||||
# transform z with shape `(..., num_nodes)` to f with shape `(..., num_edges)`.
|
||||
outputs = torch.matmul(inputs, self.grap_grad)
|
||||
|
||||
outputs = outputs.reshape(horizon, n_traj_samples, batch_size, latent_dim, self.num_edges, self.output_dim)
|
||||
outputs = torch.mean(
|
||||
torch.mean(outputs, axis=3),
|
||||
axis=1
|
||||
)
|
||||
outputs = outputs.reshape(horizon, batch_size, -1)
|
||||
return outputs
|
||||
|
||||
|
|
@ -0,0 +1,415 @@
|
|||
import os
|
||||
import time
|
||||
from random import SystemRandom
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from lib import utils
|
||||
from model.stden_model import STDENModel
|
||||
from lib.metrics import masked_mae_loss, masked_mape_loss, masked_rmse_loss
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
class STDENSupervisor:
|
||||
def __init__(self, adj_mx, **kwargs):
|
||||
self._kwargs = kwargs
|
||||
self._data_kwargs = kwargs.get('data')
|
||||
self._model_kwargs = kwargs.get('model')
|
||||
self._train_kwargs = kwargs.get('train')
|
||||
|
||||
self.max_grad_norm = self._train_kwargs.get('max_grad_norm', 1.)
|
||||
|
||||
# logging.
|
||||
self._log_dir = utils.get_log_dir(kwargs)
|
||||
self._writer = SummaryWriter('runs/' + self._log_dir)
|
||||
|
||||
log_level = self._kwargs.get('log_level', 'INFO')
|
||||
self._logger = utils.get_logger(self._log_dir, __name__, 'info.log', level=log_level)
|
||||
|
||||
# data set
|
||||
self._data = utils.load_dataset(**self._data_kwargs)
|
||||
self.standard_scaler = self._data['scaler']
|
||||
self._logger.info('Scaler mean: {:.6f}, std {:.6f}.'.format(self.standard_scaler.mean, self.standard_scaler.std))
|
||||
|
||||
self.num_edges = (adj_mx > 0.).sum()
|
||||
self.input_dim = int(self._model_kwargs.get('input_dim', 1))
|
||||
self.seq_len = int(self._model_kwargs.get('seq_len')) # for the encoder
|
||||
self.output_dim = int(self._model_kwargs.get('output_dim', 1))
|
||||
self.use_curriculum_learning = bool(
|
||||
self._model_kwargs.get('use_curriculum_learning', False))
|
||||
self.horizon = int(self._model_kwargs.get('horizon', 1)) # for the decoder
|
||||
|
||||
# setup model
|
||||
stden_model = STDENModel(adj_mx, self._logger, **self._model_kwargs)
|
||||
self.stden_model = stden_model.cuda() if torch.cuda.is_available() else stden_model
|
||||
self._logger.info("Model created")
|
||||
|
||||
self.experimentID = self._train_kwargs.get('load', 0)
|
||||
if self.experimentID == 0:
|
||||
# Make a new experiment ID
|
||||
self.experimentID = int(SystemRandom().random()*100000)
|
||||
self.ckpt_path = os.path.join("ckpt/", "experiment_" + str(self.experimentID))
|
||||
|
||||
self._epoch_num = self._train_kwargs.get('epoch', 0)
|
||||
if self._epoch_num > 0:
|
||||
self._logger.info('Loading model...')
|
||||
self.load_model()
|
||||
|
||||
def save_model(self, epoch):
|
||||
model_dir = self.ckpt_path
|
||||
if not os.path.exists(model_dir):
|
||||
os.makedirs(model_dir)
|
||||
|
||||
config = dict(self._kwargs)
|
||||
config['model_state_dict'] = self.stden_model.state_dict()
|
||||
config['epoch'] = epoch
|
||||
model_path = os.path.join(model_dir, 'epo{}.tar'.format(epoch))
|
||||
torch.save(config, model_path)
|
||||
self._logger.info("Saved model at {}".format(epoch))
|
||||
return model_path
|
||||
|
||||
def load_model(self):
|
||||
self._setup_graph()
|
||||
model_path = os.path.join(self.ckpt_path, 'epo{}.tar'.format(self._epoch_num))
|
||||
assert os.path.exists(model_path), 'Weights at epoch %d not found' % self._epoch_num
|
||||
|
||||
checkpoint = torch.load(model_path, map_location='cpu')
|
||||
self.stden_model.load_state_dict(checkpoint['model_state_dict'])
|
||||
self._logger.info("Loaded model at {}".format(self._epoch_num))
|
||||
|
||||
def _setup_graph(self):
|
||||
with torch.no_grad():
|
||||
self.stden_model.eval()
|
||||
|
||||
val_iterator = self._data['val_loader'].get_iterator()
|
||||
|
||||
for _, (x, y) in enumerate(val_iterator):
|
||||
x, y = self._prepare_data(x, y)
|
||||
output = self.stden_model(x)
|
||||
break
|
||||
|
||||
def train(self, **kwargs):
|
||||
self._logger.info('Model mode: train')
|
||||
kwargs.update(self._train_kwargs)
|
||||
return self._train(**kwargs)
|
||||
|
||||
def _train(self, base_lr,
|
||||
steps, patience=50, epochs=100, lr_decay_ratio=0.1, log_every=1, save_model=1,
|
||||
test_every_n_epochs=10, epsilon=1e-8, **kwargs):
|
||||
# steps is used in learning rate - will see if need to use it?
|
||||
min_val_loss = float('inf')
|
||||
wait = 0
|
||||
optimizer = torch.optim.Adam(self.stden_model.parameters(), lr=base_lr, eps=epsilon)
|
||||
|
||||
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=steps,
|
||||
gamma=lr_decay_ratio)
|
||||
|
||||
self._logger.info('Start training ...')
|
||||
|
||||
# this will fail if model is loaded with a changed batch_size
|
||||
num_batches = self._data['train_loader'].num_batch
|
||||
self._logger.info("num_batches: {}".format(num_batches))
|
||||
|
||||
batches_seen = num_batches * self._epoch_num
|
||||
|
||||
# used for nfe
|
||||
c = []
|
||||
res, keys = [], []
|
||||
|
||||
for epoch_num in range(self._epoch_num, epochs):
|
||||
|
||||
self.stden_model.train()
|
||||
|
||||
train_iterator = self._data['train_loader'].get_iterator()
|
||||
losses = []
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
c.clear() #nfe
|
||||
for i, (x, y) in enumerate(train_iterator):
|
||||
if(i >= num_batches):
|
||||
break
|
||||
optimizer.zero_grad()
|
||||
|
||||
x, y = self._prepare_data(x, y)
|
||||
|
||||
output, fe = self.stden_model(x, y, batches_seen)
|
||||
|
||||
if batches_seen == 0:
|
||||
# this is a workaround to accommodate dynamically registered parameters
|
||||
optimizer = torch.optim.Adam(self.stden_model.parameters(), lr=base_lr, eps=epsilon)
|
||||
|
||||
loss = self._compute_loss(y, output)
|
||||
self._logger.debug("FE: number - {}, time - {:.3f} s, err - {:.3f}".format(*fe, loss.item()))
|
||||
c.append([*fe, loss.item()])
|
||||
|
||||
self._logger.debug(loss.item())
|
||||
losses.append(loss.item())
|
||||
|
||||
batches_seen += 1 # global step in tensorboard
|
||||
loss.backward()
|
||||
|
||||
# gradient clipping
|
||||
torch.nn.utils.clip_grad_norm_(self.stden_model.parameters(), self.max_grad_norm)
|
||||
|
||||
optimizer.step()
|
||||
|
||||
del x, y, output, loss # del make these memory no-labeled trash
|
||||
torch.cuda.empty_cache() # empty_cache() recycle no-labeled trash
|
||||
|
||||
# used for nfe
|
||||
res.append(pd.DataFrame(c, columns=['nfe', 'time', 'err']))
|
||||
keys.append(epoch_num)
|
||||
|
||||
self._logger.info("epoch complete")
|
||||
lr_scheduler.step()
|
||||
self._logger.info("evaluating now!")
|
||||
|
||||
val_loss, _ = self.evaluate(dataset='val', batches_seen=batches_seen)
|
||||
|
||||
end_time = time.time()
|
||||
|
||||
self._writer.add_scalar('training loss',
|
||||
np.mean(losses),
|
||||
batches_seen)
|
||||
|
||||
if (epoch_num % log_every) == log_every - 1:
|
||||
message = 'Epoch [{}/{}] ({}) train_mae: {:.4f}, val_mae: {:.4f}, lr: {:.6f}, ' \
|
||||
'{:.1f}s'.format(epoch_num, epochs, batches_seen,
|
||||
np.mean(losses), val_loss, lr_scheduler.get_lr()[0],
|
||||
(end_time - start_time))
|
||||
self._logger.info(message)
|
||||
|
||||
if (epoch_num % test_every_n_epochs) == test_every_n_epochs - 1:
|
||||
test_loss, _ = self.evaluate(dataset='test', batches_seen=batches_seen)
|
||||
message = 'Epoch [{}/{}] ({}) train_mae: {:.4f}, test_mae: {:.4f}, lr: {:.6f}, ' \
|
||||
'{:.1f}s'.format(epoch_num, epochs, batches_seen,
|
||||
np.mean(losses), test_loss, lr_scheduler.get_lr()[0],
|
||||
(end_time - start_time))
|
||||
self._logger.info(message)
|
||||
|
||||
if val_loss < min_val_loss:
|
||||
wait = 0
|
||||
if save_model:
|
||||
model_file_name = self.save_model(epoch_num)
|
||||
self._logger.info(
|
||||
'Val loss decrease from {:.4f} to {:.4f}, '
|
||||
'saving to {}'.format(min_val_loss, val_loss, model_file_name))
|
||||
min_val_loss = val_loss
|
||||
|
||||
elif val_loss >= min_val_loss:
|
||||
wait += 1
|
||||
if wait == patience:
|
||||
self._logger.warning('Early stopping at epoch: %d' % epoch_num)
|
||||
break
|
||||
|
||||
if bool(self._model_kwargs.get('nfe', False)):
|
||||
res = pd.concat(res, keys=keys)
|
||||
# self._logger.info("res.shape: ", res.shape)
|
||||
res.index.names = ['epoch', 'iter']
|
||||
filter_type = self._model_kwargs.get('filter_type', 'unknown')
|
||||
atol = float(self._model_kwargs.get('odeint_atol', 1e-5))
|
||||
rtol = float(self._model_kwargs.get('odeint_rtol', 1e-5))
|
||||
nfe_file = os.path.join(
|
||||
self._data_kwargs.get('dataset_dir', 'data'),
|
||||
'nfe_{}_a{}_r{}.pkl'.format(filter_type, int(atol*1e5), int(rtol*1e5)))
|
||||
res.to_pickle(nfe_file)
|
||||
# res.to_csv(nfe_file)
|
||||

    def _prepare_data(self, x, y):
        x, y = self._get_x_y(x, y)
        x, y = self._get_x_y_in_correct_dims(x, y)
        return x.to(device), y.to(device)

    def _get_x_y(self, x, y):
        """
        :param x: shape (batch_size, seq_len, num_edges, input_dim)
        :param y: shape (batch_size, horizon, num_edges, input_dim)
        :returns x shape (seq_len, batch_size, num_edges, input_dim)
                 y shape (horizon, batch_size, num_edges, input_dim)
        """
        x = torch.from_numpy(x).float()
        y = torch.from_numpy(y).float()
        self._logger.debug("X: {}".format(x.size()))
        self._logger.debug("y: {}".format(y.size()))
        x = x.permute(1, 0, 2, 3)
        y = y.permute(1, 0, 2, 3)
        return x, y

    def _get_x_y_in_correct_dims(self, x, y):
        """
        :param x: shape (seq_len, batch_size, num_edges, input_dim)
        :param y: shape (horizon, batch_size, num_edges, input_dim)
        :return: x: shape (seq_len, batch_size, num_edges * input_dim)
                 y: shape (horizon, batch_size, num_edges * output_dim)
        """
        batch_size = x.size(1)
        self._logger.debug("size of x {}".format(x.size()))
        x = x.view(self.seq_len, batch_size, self.num_edges * self.input_dim)
        y = y[..., :self.output_dim].view(self.horizon, batch_size,
                                          self.num_edges * self.output_dim)
        return x, y
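    # Example (hypothetical sizes): with seq_len=12, batch_size=64, num_edges=100 and
    # input_dim=1, _get_x_y turns an input of shape (64, 12, 100, 1) into (12, 64, 100, 1),
    # and _get_x_y_in_correct_dims then flattens it to (12, 64, 100).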

    def _compute_loss(self, y_true, y_predicted):
        y_true = self.standard_scaler.inverse_transform(y_true)
        y_predicted = self.standard_scaler.inverse_transform(y_predicted)
        return masked_mae_loss(y_predicted, y_true)

    def _compute_loss_eval(self, y_true, y_predicted):
        y_true = self.standard_scaler.inverse_transform(y_true)
        y_predicted = self.standard_scaler.inverse_transform(y_predicted)
        return (masked_mae_loss(y_predicted, y_true).item(),
                masked_mape_loss(y_predicted, y_true).item(),
                masked_rmse_loss(y_predicted, y_true).item())
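    # The loss helpers above come from lib/metrics.py, which is not shown in this diff.
    # A typical masked MAE looks roughly like the sketch below (assuming zeros mark
    # missing values; the project's own implementation may differ):
    #
    #   def masked_mae_loss(y_pred, y_true):
    #       mask = (y_true != 0).float()
    #       mask /= mask.mean()
    #       loss = torch.abs(y_pred - y_true) * mask
    #       loss[loss != loss] = 0  # zero out NaNs from an all-zero mask
    #       return loss.mean()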
    def evaluate(self, dataset='val', batches_seen=0, save=False):
        """
        Computes MAE, MAPE and RMSE on the given dataset and, if `save` is True,
        also collects the inverse-transformed predictions and ground truth.
        :return: mean MAE (L1) loss and, when `save` is True, a dict of predictions/truth
        """
        with torch.no_grad():
            self.stden_model.eval()

            val_iterator = self._data['{}_loader'.format(dataset)].get_iterator()
            mae_losses = []
            mape_losses = []
            rmse_losses = []
            y_dict = None

            if save:
                y_truths = []
                y_preds = []

            for _, (x, y) in enumerate(val_iterator):
                x, y = self._prepare_data(x, y)

                output, fe = self.stden_model(x)
                mae, mape, rmse = self._compute_loss_eval(y, output)
                mae_losses.append(mae)
                mape_losses.append(mape)
                rmse_losses.append(rmse)

                if save:
                    y_truths.append(y.cpu())
                    y_preds.append(output.cpu())

            mean_loss = {
                'mae': np.mean(mae_losses),
                'mape': np.mean(mape_losses),
                'rmse': np.mean(rmse_losses)
            }

            self._logger.info('Evaluation: - mae - {:.4f} - mape - {:.4f} - rmse - {:.4f}'.format(
                mean_loss['mae'], mean_loss['mape'], mean_loss['rmse']))
            self._writer.add_scalar('{} loss'.format(dataset), mean_loss['mae'], batches_seen)

            if save:
                y_preds = np.concatenate(y_preds, axis=1)
                y_truths = np.concatenate(y_truths, axis=1)  # concatenate on batch dimension

                y_truths_scaled = []
                y_preds_scaled = []
                # self._logger.debug("y_preds shape: {}, y_truth shape {}".format(y_preds.shape, y_truths.shape))
                for t in range(y_preds.shape[0]):
                    y_truth = self.standard_scaler.inverse_transform(y_truths[t])
                    y_pred = self.standard_scaler.inverse_transform(y_preds[t])
                    y_truths_scaled.append(y_truth)
                    y_preds_scaled.append(y_pred)

                y_preds_scaled = np.stack(y_preds_scaled)
                y_truths_scaled = np.stack(y_truths_scaled)

                y_dict = {'prediction': y_preds_scaled, 'truth': y_truths_scaled}

                # save_dir = self._data_kwargs.get('dataset_dir', 'data')
                # save_path = os.path.join(save_dir, 'pred.npz')
                # np.savez(save_path, prediction=y_preds_scaled, truth=y_truths_scaled)

            return mean_loss['mae'], y_dict
    def eval_more(self, dataset='val', save=False, seq_len=[3, 6, 9, 12], extract_latent=False):
        """
        Computes MAE, MAPE and RMSE at the given horizons, and saves the predictions
        (and, optionally, the latent features) if the corresponding flags are set.
        """
        self._logger.info('Model mode: Evaluation')
        with torch.no_grad():
            self.stden_model.eval()

            val_iterator = self._data['{}_loader'.format(dataset)].get_iterator()
            mae_losses = []
            mape_losses = []
            rmse_losses = []

            if save:
                y_truths = []
                y_preds = []

            if extract_latent:
                latents = []

            # used for nfe
            c = []
            for _, (x, y) in enumerate(val_iterator):
                x, y = self._prepare_data(x, y)

                output, fe = self.stden_model(x)
                mae, mape, rmse = [], [], []
                for seq in seq_len:
                    _mae, _mape, _rmse = self._compute_loss_eval(y[seq-1], output[seq-1])
                    mae.append(_mae)
                    mape.append(_mape)
                    rmse.append(_rmse)
                mae_losses.append(mae)
                mape_losses.append(mape)
                rmse_losses.append(rmse)
                c.append([*fe, np.mean(mae)])

                if save:
                    y_truths.append(y.cpu())
                    y_preds.append(output.cpu())

                if extract_latent:
                    latents.append(self.stden_model.latent_feat.cpu())

            mean_loss = {
                'mae': np.mean(mae_losses, axis=0),
                'mape': np.mean(mape_losses, axis=0),
                'rmse': np.mean(rmse_losses, axis=0)
            }

            for i, seq in enumerate(seq_len):
                self._logger.info('Evaluation seq {}: - mae - {:.4f} - mape - {:.4f} - rmse - {:.4f}'.format(
                    seq, mean_loss['mae'][i], mean_loss['mape'][i], mean_loss['rmse'][i]))

            if save:
                # shape (horizon, num_samples, feat_dim)
                y_preds = np.concatenate(y_preds, axis=1)
                y_truths = np.concatenate(y_truths, axis=1)  # concatenate on batch dimension
                y_preds_scaled = self.standard_scaler.inverse_transform(y_preds)
                y_truths_scaled = self.standard_scaler.inverse_transform(y_truths)

                save_dir = self._data_kwargs.get('dataset_dir', 'data')
                save_path = os.path.join(save_dir, 'pred_{}_{}.npz'.format(self.experimentID, self._epoch_num))
                np.savez_compressed(save_path, prediction=y_preds_scaled, truth=y_truths_scaled)
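                # The archive can be read back with np.load(save_path); its 'prediction'
                # and 'truth' arrays have shape (horizon, num_samples, feat_dim).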

            if extract_latent:
                # concatenate on batch dimension
                latents = np.concatenate(latents, axis=1)
                # Shape of latents: (horizon, num_samples, self.num_edges * self.output_dim)

                save_dir = self._data_kwargs.get('dataset_dir', 'data')
                filter_type = self._model_kwargs.get('filter_type', 'unknown')
                save_path = os.path.join(save_dir, '{}_latent_{}_{}.npz'.format(filter_type, self.experimentID, self._epoch_num))
                np.savez_compressed(save_path, latent=latents)

            if bool(self._model_kwargs.get('nfe', False)):
                res = pd.DataFrame(c, columns=['nfe', 'time', 'err'])
                res.index.name = 'iter'
                filter_type = self._model_kwargs.get('filter_type', 'unknown')
                atol = float(self._model_kwargs.get('odeint_atol', 1e-5))
                rtol = float(self._model_kwargs.get('odeint_rtol', 1e-5))
                nfe_file = os.path.join(
                    self._data_kwargs.get('dataset_dir', 'data'),
                    'nfe_{}_a{}_r{}.pkl'.format(filter_type, int(atol*1e5), int(rtol*1e5)))
                res.to_pickle(nfe_file)
@ -0,0 +1,23 @@
# STDEN project dependencies
# Deep learning framework
torch>=1.9.0
torchvision>=0.10.0

# Scientific computing
numpy>=1.21.0
scipy>=1.7.0
pandas>=1.3.0

# Machine learning utilities
scikit-learn>=1.0.0

# Visualization
matplotlib>=3.4.0
tensorboard>=2.6.0

# Configuration and logging
PyYAML>=5.4.0

# Miscellaneous tools
tqdm>=4.62.0
pathlib2>=2.3.0
@ -0,0 +1,105 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Main entry point of the STDEN project.
Dispatches to the matching config, dataloader and trainer based on the model name.
"""

import argparse
import yaml
import torch
import numpy as np
import os
import sys
from pathlib import Path

# Add the project root to the Python path
project_root = Path(__file__).parent
sys.path.append(str(project_root))

from lib.logger import setup_logger
from lib.utils import load_graph_data
from trainer.stden_trainer import STDENTrainer
from dataloader.stden_dataloader import STDENDataloader


def load_config(config_path):
    """Load a YAML configuration file."""
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    return config


def setup_environment(config):
    """Set environment options and random seeds."""
    # Set random seeds
    random_seed = config.get('random_seed', 2021)
    torch.manual_seed(random_seed)
    np.random.seed(random_seed)
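    # For fully deterministic GPU runs one could additionally seed CUDA, e.g. with
    # torch.cuda.manual_seed_all(random_seed) and torch.backends.cudnn.deterministic = True
    # (optional; not done here).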

    # Select the device
    device = 'cuda' if torch.cuda.is_available() and not config.get('use_cpu_only', False) else 'cpu'
    config['device'] = device

    return config


def main():
    parser = argparse.ArgumentParser(description='STDEN training and evaluation')
    parser.add_argument('--model_name', type=str, required=True,
                        choices=['stde_gt', 'stde_wrs', 'stde_zgc'],
                        help='model name, matching a config file')
    parser.add_argument('--mode', type=str, default='train',
                        choices=['train', 'eval'],
                        help='run mode: train or eval')
    parser.add_argument('--config_dir', type=str, default='configs',
                        help='directory containing the config files')
    parser.add_argument('--use_cpu_only', action='store_true',
                        help='run on CPU only')
    parser.add_argument('--save_pred', action='store_true',
                        help='save predictions (eval mode only)')

    args = parser.parse_args()

    # Build the config file path
    config_path = Path(args.config_dir) / f"{args.model_name}.yaml"

    if not config_path.exists():
        print(f"Error: config file {config_path} does not exist")
        sys.exit(1)

    # Load the configuration
    config = load_config(config_path)
    config['use_cpu_only'] = args.use_cpu_only
    config['mode'] = args.mode

    # Set up the environment
    config = setup_environment(config)

    # Set up logging
    logger = setup_logger(config)
    logger.info(f"Running model {args.model_name} in {args.mode} mode")
    logger.info(f"Using device: {config['device']}")

    # Load the graph data
    graph_pkl_filename = config['data']['graph_pkl_filename']
    adj_matrix = load_graph_data(graph_pkl_filename)
    config['adj_matrix'] = adj_matrix

    # Create the dataloader
    dataloader = STDENDataloader(config)

    # Create the trainer
    trainer = STDENTrainer(config, dataloader)

    # Dispatch on the requested mode
    if args.mode == 'train':
        trainer.train()
    else:  # eval mode
        trainer.evaluate(save_predictions=args.save_pred)

    logger.info(f"{args.model_name} run finished")


if __name__ == '__main__':
    main()
@ -0,0 +1,9 @@
# -*- coding: utf-8 -*-
"""
Trainer package.
Contains the training and evaluation logic of the STDEN project.
"""

from .stden_trainer import STDENTrainer

__all__ = ['STDENTrainer']
@ -0,0 +1,383 @@
# -*- coding: utf-8 -*-
"""
STDEN trainer.
Handles training, validation and evaluation of the model.
"""

import os
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import logging
from pathlib import Path
import json

from model.stden_model import STDENModel
from lib.metrics import masked_mae_loss, masked_mape_loss, masked_rmse_loss


class STDENTrainer:
    """Main STDEN trainer class."""

    def __init__(self, config, dataloader):
        """
        Initialize the trainer.

        Args:
            config: configuration dictionary
            dataloader: dataloader instance
        """
        self.config = config
        self.dataloader = dataloader
        self.logger = logging.getLogger('STDEN')

        # Device
        self.device = torch.device(config['device'])

        # Model and training configuration
        self.model_config = config['model']
        self.train_config = config['train']

        # Build the model
        self.model = self._create_model()

        # Build the optimizer and learning-rate scheduler
        self.optimizer, self.scheduler = self._create_optimizer()

        # Loss function
        self.criterion = masked_mae_loss

        # Checkpoint directory
        self.checkpoint_dir = self._create_checkpoint_dir()

        # TensorBoard writer
        self.writer = self._create_tensorboard_writer()

        # Training state
        self.current_epoch = 0
        self.best_val_loss = float('inf')
        self.patience_counter = 0

        self.logger.info("Trainer initialized")

    def _create_model(self):
        """Create the STDEN model."""
        # Adjacency matrix of the sensor graph
        adj_matrix = self.config['adj_matrix']

        # Instantiate the model
        model = STDENModel(
            adj_matrix=adj_matrix,
            logger=self.logger,
            **self.model_config
        )

        # Move it to the target device
        model = model.to(self.device)

        self.logger.info(f"Model created, number of parameters: {sum(p.numel() for p in model.parameters())}")
        return model

    def _create_optimizer(self):
        """Create the optimizer and learning-rate scheduler."""
        # Training hyper-parameters
        base_lr = self.train_config['base_lr']
        optimizer_name = self.train_config.get('optimizer', 'adam').lower()

        # Optimizer
        if optimizer_name == 'adam':
            optimizer = optim.Adam(self.model.parameters(), lr=base_lr)
        elif optimizer_name == 'sgd':
            optimizer = optim.SGD(self.model.parameters(), lr=base_lr)
        else:
            raise ValueError(f"Unsupported optimizer: {optimizer_name}")

        # Learning-rate scheduler
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='min',
            factor=self.train_config.get('lr_decay_ratio', 0.1),
            patience=self.train_config.get('patience', 20),
            min_lr=self.train_config.get('min_learning_rate', 1e-6),
            verbose=True
        )

        return optimizer, scheduler
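    # Note: ReduceLROnPlateau is stepped with the validation loss
    # (self.scheduler.step(val_loss) in train() below), unlike epoch-based schedulers.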

    def _create_checkpoint_dir(self):
        """Create the checkpoint directory and persist the config."""
        checkpoint_dir = Path("checkpoints") / f"experiment_{int(time.time())}"
        checkpoint_dir.mkdir(parents=True, exist_ok=True)

        # Save the configuration (skip the adjacency matrix, which is not JSON-serializable)
        config_path = checkpoint_dir / "config.json"
        serializable_config = {k: v for k, v in self.config.items() if k != 'adj_matrix'}
        with open(config_path, 'w', encoding='utf-8') as f:
            json.dump(serializable_config, f, indent=2, ensure_ascii=False, default=str)

        return checkpoint_dir

    def _create_tensorboard_writer(self):
        """Create the TensorBoard writer."""
        log_dir = Path("runs") / f"experiment_{int(time.time())}"
        return SummaryWriter(str(log_dir))
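    # The run can then be monitored with: tensorboard --logdir runs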

    def train(self):
        """Train the model."""
        self.logger.info("Starting training")

        # Training configuration
        epochs = self.train_config['epochs']
        patience = self.train_config['patience']
        test_every_n_epochs = self.train_config.get('test_every_n_epochs', 5)

        # Data loaders
        train_loader = self.dataloader.train_loader
        val_loader = self.dataloader.val_loader

        for epoch in range(epochs):
            self.current_epoch = epoch

            # Train for one epoch
            train_loss = self._train_epoch(train_loader)

            # Validate
            val_loss = self._validate_epoch(val_loader)

            # Update the learning rate
            self.scheduler.step(val_loss)

            # Log to TensorBoard
            self.writer.add_scalar('Loss/Train', train_loss, epoch)
            self.writer.add_scalar('Loss/Validation', val_loss, epoch)
            self.writer.add_scalar('Learning_Rate', self.optimizer.param_groups[0]['lr'], epoch)

            # Log training progress
            self.logger.info(
                f"Epoch {epoch+1}/{epochs} - "
                f"Train Loss: {train_loss:.6f}, "
                f"Val Loss: {val_loss:.6f}, "
                f"LR: {self.optimizer.param_groups[0]['lr']:.6f}"
            )

            # Save the best model
            if val_loss < self.best_val_loss:
                self.best_val_loss = val_loss
                self.patience_counter = 0
                self._save_checkpoint(epoch, is_best=True)
                self.logger.info(f"New best validation loss: {val_loss:.6f}")
            else:
                self.patience_counter += 1

            # Periodically save a checkpoint
            if (epoch + 1) % 10 == 0:
                self._save_checkpoint(epoch, is_best=False)

            # Periodically evaluate on the test set
            if (epoch + 1) % test_every_n_epochs == 0:
                test_metrics = self._test_epoch()
                self.logger.info(f"Test metrics: {test_metrics}")

            # Early stopping check
            if self.patience_counter >= patience:
                self.logger.info(f"Early stopping triggered at epoch {epoch+1}")
                break

        # Training finished
        self.logger.info("Training finished")
        self.writer.close()

        # Final evaluation on the test set
        final_test_metrics = self._test_epoch()
        self.logger.info(f"Final test metrics: {final_test_metrics}")

    def _train_epoch(self, train_loader):
        """Train for a single epoch."""
        self.model.train()
        total_loss = 0.0
        num_batches = 0

        for batch_idx, (x, y) in enumerate(train_loader):
            # Move the batch to the device
            x = x.to(self.device)
            y = y.to(self.device)

            # Forward pass
            self.optimizer.zero_grad()
            output = self.model(x)

            # Compute the loss
            loss = self.criterion(output, y)

            # Backward pass
            loss.backward()

            # Gradient clipping
            max_grad_norm = self.train_config.get('max_grad_norm', 5.0)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_grad_norm)

            # Parameter update
            self.optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        return total_loss / num_batches

    def _validate_epoch(self, val_loader):
        """Run validation for one epoch."""
        self.model.eval()
        total_loss = 0.0
        num_batches = 0

        with torch.no_grad():
            for x, y in val_loader:
                x = x.to(self.device)
                y = y.to(self.device)

                output = self.model(x)
                loss = self.criterion(output, y)

                total_loss += loss.item()
                num_batches += 1

        return total_loss / num_batches

    def _test_epoch(self):
        """Evaluate the model on the test set."""
        self.model.eval()
        test_loader = self.dataloader.test_loader

        total_mae = 0.0
        total_mape = 0.0
        total_rmse = 0.0
        num_batches = 0

        with torch.no_grad():
            for x, y in test_loader:
                x = x.to(self.device)
                y = y.to(self.device)

                output = self.model(x)

                # Compute the metrics
                mae = masked_mae_loss(output, y)
                mape = masked_mape_loss(output, y)
                rmse = masked_rmse_loss(output, y)

                total_mae += mae.item()
                total_mape += mape.item()
                total_rmse += rmse.item()
                num_batches += 1

        metrics = {
            'MAE': total_mae / num_batches,
            'MAPE': total_mape / num_batches,
            'RMSE': total_rmse / num_batches
        }

        return metrics

    def _save_checkpoint(self, epoch, is_best=False):
        """Save a checkpoint."""
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'best_val_loss': self.best_val_loss,
            'config': self.config
        }

        # Save the latest checkpoint
        latest_path = self.checkpoint_dir / "latest.pth"
        torch.save(checkpoint, latest_path)

        # Save the best checkpoint
        if is_best:
            best_path = self.checkpoint_dir / "best.pth"
            torch.save(checkpoint, best_path)

        # Save a per-epoch checkpoint
        epoch_path = self.checkpoint_dir / f"epoch_{epoch+1}.pth"
        torch.save(checkpoint, epoch_path)

        self.logger.info(f"Checkpoint saved: {epoch_path}")

    def load_checkpoint(self, checkpoint_path):
        """Load a checkpoint."""
        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.current_epoch = checkpoint['epoch']
        self.best_val_loss = checkpoint['best_val_loss']

        self.logger.info(f"Checkpoint loaded: {checkpoint_path}")
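    # Resuming an interrupted run is then a matter of (illustrative only; the path is hypothetical):
    #   trainer = STDENTrainer(config, dataloader)
    #   trainer.load_checkpoint("checkpoints/experiment_1234567890/latest.pth")
    #   trainer.train()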

    def evaluate(self, save_predictions=False):
        """Evaluate the model."""
        self.logger.info("Starting evaluation")

        # Load the best model if available
        best_checkpoint_path = self.checkpoint_dir / "best.pth"
        if best_checkpoint_path.exists():
            self.load_checkpoint(best_checkpoint_path)
        else:
            self.logger.warning("Best checkpoint not found, using the current model")

        # Run the test set
        test_metrics = self._test_epoch()

        # Report the results
        self.logger.info("Evaluation results:")
        for metric_name, metric_value in test_metrics.items():
            self.logger.info(f"  {metric_name}: {metric_value:.6f}")

        # Optionally save the predictions
        if save_predictions:
            self._save_predictions()

        return test_metrics

    def _save_predictions(self):
        """Save the test-set predictions."""
        self.model.eval()
        test_loader = self.dataloader.test_loader

        predictions = []
        targets = []

        with torch.no_grad():
            for x, y in test_loader:
                x = x.to(self.device)
                output = self.model(x)

                # Undo the normalization
                scaler = self.dataloader.get_scaler()
                output_denorm = scaler.inverse_transform(output.cpu().numpy())
                y_denorm = scaler.inverse_transform(y.numpy())

                predictions.append(output_denorm)
                targets.append(y_denorm)

        # Concatenate all batches
        predictions = np.concatenate(predictions, axis=0)
        targets = np.concatenate(targets, axis=0)

        # Write the results to disk
        results_dir = self.checkpoint_dir / "results"
        results_dir.mkdir(exist_ok=True)

        np.save(results_dir / "predictions.npy", predictions)
        np.save(results_dir / "targets.npy", targets)

        self.logger.info(f"Predictions saved to: {results_dir}")

    def __del__(self):
        """Destructor; make sure resources are released properly."""
        if hasattr(self, 'writer'):
            self.writer.close()