Project-I/configs/stde_wrs.yaml

86 lines
1.7 KiB
YAML
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# STDE_WRS模型配置文件
# 用于WRS传感器图数据的时空扩散方程网络模型
# 基础配置
model_name: "stde_wrs"
random_seed: 2021
log_level: "INFO"
log_base_dir: "logs/WRS"
# 数据配置
data:
# 数据集目录
dataset_dir: "data/WRS"
# 批处理大小
batch_size: 32
# 验证集批处理大小
val_batch_size: 32
# 传感器图邻接矩阵文件
graph_pkl_filename: "data/sensor_graph/adj_WRS.npy"
# 模型配置
model:
# 输入序列长度
seq_len: 12
# 预测时间步数
horizon: 12
# 输入特征维度
input_dim: 1
# 输出特征维度
output_dim: 1
# 潜在空间维度
latent_dim: 4
# 轨迹采样数量
n_traj_samples: 3
# ODE求解方法
ode_method: "dopri5"
# ODE求解器绝对误差容差
odeint_atol: 0.00001
# ODE求解器相对误差容差
odeint_rtol: 0.00001
# RNN隐藏单元数量
rnn_units: 64
# RNN层数
num_rnn_layers: 1
# 图卷积步数
gcn_step: 2
# 滤波器类型 (default/unkP/IncP)
filter_type: "default"
# 循环神经网络类型
recg_type: "gru"
# 是否保存潜在表示
save_latent: false
# 是否记录函数评估次数
nfe: false
# L1正则化衰减
l1_decay: 0
# 训练配置
train:
# 基础学习率
base_lr: 0.01
# Dropout比率
dropout: 0
# 加载的检查点epoch
load: 0
# 当前训练epoch
epoch: 0
# 总训练epoch数
epochs: 100
# 收敛阈值
epsilon: 1.0e-3
# 学习率衰减比率
lr_decay_ratio: 0.1
# 最大梯度范数
max_grad_norm: 5
# 最小学习率
min_learning_rate: 2.0e-06
# 优化器类型
optimizer: "adam"
# 早停耐心值
patience: 20
# 学习率衰减步数
steps: [20, 30, 40, 50]
# 测试频率每N个epoch测试一次
test_every_n_epochs: 5