86 lines
1.7 KiB
YAML
86 lines
1.7 KiB
YAML
# STDE_ZGC模型配置文件
|
||
# 用于ZGC传感器图数据的时空扩散方程网络模型
|
||
|
||
# 基础配置
|
||
model_name: "stde_zgc"
|
||
random_seed: 2021
|
||
log_level: "INFO"
|
||
log_base_dir: "logs/ZGC"
|
||
|
||
# 数据配置
|
||
data:
|
||
# 数据集目录
|
||
dataset_dir: "data/ZGC"
|
||
# 批处理大小
|
||
batch_size: 32
|
||
# 验证集批处理大小
|
||
val_batch_size: 32
|
||
# 传感器图邻接矩阵文件
|
||
graph_pkl_filename: "data/sensor_graph/adj_ZGC.npy"
|
||
|
||
# 模型配置
|
||
model:
|
||
# 输入序列长度
|
||
seq_len: 12
|
||
# 预测时间步数
|
||
horizon: 12
|
||
# 输入特征维度
|
||
input_dim: 1
|
||
# 输出特征维度
|
||
output_dim: 1
|
||
# 潜在空间维度
|
||
latent_dim: 4
|
||
# 轨迹采样数量
|
||
n_traj_samples: 3
|
||
# ODE求解方法
|
||
ode_method: "dopri5"
|
||
# ODE求解器绝对误差容差
|
||
odeint_atol: 0.00001
|
||
# ODE求解器相对误差容差
|
||
odeint_rtol: 0.00001
|
||
# RNN隐藏单元数量
|
||
rnn_units: 64
|
||
# RNN层数
|
||
num_rnn_layers: 1
|
||
# 图卷积步数
|
||
gcn_step: 2
|
||
# 滤波器类型 (default/unkP/IncP)
|
||
filter_type: "default"
|
||
# 循环神经网络类型
|
||
recg_type: "gru"
|
||
# 是否保存潜在表示
|
||
save_latent: false
|
||
# 是否记录函数评估次数
|
||
nfe: false
|
||
# L1正则化衰减
|
||
l1_decay: 0
|
||
|
||
# 训练配置
|
||
train:
|
||
# 基础学习率
|
||
base_lr: 0.01
|
||
# Dropout比率
|
||
dropout: 0
|
||
# 加载的检查点epoch
|
||
load: 0
|
||
# 当前训练epoch
|
||
epoch: 0
|
||
# 总训练epoch数
|
||
epochs: 100
|
||
# 收敛阈值
|
||
epsilon: 1.0e-3
|
||
# 学习率衰减比率
|
||
lr_decay_ratio: 0.1
|
||
# 最大梯度范数
|
||
max_grad_norm: 5
|
||
# 最小学习率
|
||
min_learning_rate: 2.0e-06
|
||
# 优化器类型
|
||
optimizer: "adam"
|
||
# 早停耐心值
|
||
patience: 20
|
||
# 学习率衰减步数
|
||
steps: [20, 30, 40, 50]
|
||
# 测试频率(每N个epoch测试一次)
|
||
test_every_n_epochs: 5
|