Adapt GraphWaveNet
This commit is contained in:
parent 140ead3975
commit a9313390ac
@@ -6,40 +6,41 @@ basic:
   seed: 2023

 data:
-  batch_size: 16
+  batch_size: 64
   column_wise: false
   days_per_week: 7
   horizon: 24
   input_dim: 6
   lag: 24
   normalizer: std
-  num_nodes: 12
+  num_nodes: 35
   steps_per_day: 24
   test_ratio: 0.2
   val_ratio: 0.2

 model:
   addaptadj: true
+  apt_size: 10
   aptinit: null
-  batch_size: 16
+  batch_size: 64
   blocks: 4
   dilation_channels: 32
   dropout: 0.3
+  do_graph_conv: True
   end_channels: 512
   gcn_bool: true
-  in_dim: 2
+  in_dim: 1
-  input_dim: 6
+  input_dim: 1
   kernel_size: 2
-  layers: 2
+  layers: 4
-  num_nodes: 12
+  num_nodes: 35
-  out_dim: 12
+  out_dim: 24
-  output_dim: 6
   residual_channels: 32
   skip_channels: 256
   supports: null

 train:
-  batch_size: 16
+  batch_size: 64
   debug: false
   early_stop: true
   early_stop_patience: 15
@@ -54,7 +55,7 @@ train:
   mae_thresh: 0.0
   mape_thresh: 0.0
   max_grad_norm: 5
-  output_dim: 6
+  output_dim: 1
   plot: false
   real_value: true
   weight_decay: 0
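The switch from layers: 2 to layers: 4 (with kernel_size: 2 and blocks: 4 unchanged) changes how many past steps the temporal convolutions can cover. A minimal sketch of the receptive-field bookkeeping, mirroring the constructor loop in the model code later in this commit (the printed numbers are derived here, not taken from the repository):

import torch  # not required; plain Python is enough for the arithmetic

def receptive_field(blocks: int, layers: int, kernel_size: int) -> int:
    """Receptive field of stacked dilated convolutions with doubling dilation."""
    rf = 1
    for _ in range(blocks):
        scope = kernel_size - 1
        for _ in range(layers):
            rf += scope
            scope *= 2
    return rf

print(receptive_field(blocks=4, layers=2, kernel_size=2))  # 13 with the old setting
print(receptive_field(blocks=4, layers=4, kernel_size=2))  # 61 with the new setting

With lag: 24, the 61-step receptive field exceeds the input window, so forward() left-pads the input, which is why the padding branch in the model matters more after this change.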
@@ -20,24 +20,26 @@ data:

 model:
   addaptadj: true
+  apt_size: 10
   aptinit: null
-  batch_size: 32
+  batch_size: 16
   blocks: 4
   dilation_channels: 32
   dropout: 0.3
+  do_graph_conv: True
   end_channels: 512
   gcn_bool: true
-  in_dim: 2
+  in_dim: 1
   input_dim: 1
   kernel_size: 2
-  layers: 2
+  layers: 4
   num_nodes: 1024
-  out_dim: 12
+  out_dim: 24
-  output_dim: 1
   residual_channels: 32
   skip_channels: 256
   supports: null

 train:
   batch_size: 32
   debug: false
@@ -20,20 +20,21 @@ data:

 model:
   addaptadj: true
+  apt_size: 10
   aptinit: null
   batch_size: 32
   blocks: 4
   dilation_channels: 32
   dropout: 0.3
+  do_graph_conv: True
   end_channels: 512
   gcn_bool: true
-  in_dim: 2
+  in_dim: 1
   input_dim: 1
   kernel_size: 2
-  layers: 2
+  layers: 4
   num_nodes: 1024
-  out_dim: 12
+  out_dim: 24
-  output_dim: 1
   residual_channels: 32
   skip_channels: 256
   supports: null
@@ -6,7 +6,7 @@ basic:
   seed: 2023

 data:
-  batch_size: 16
+  batch_size: 64
   column_wise: false
   days_per_week: 7
   horizon: 24
@@ -20,26 +20,27 @@ data:

 model:
   addaptadj: true
+  apt_size: 10
   aptinit: null
-  batch_size: 16
+  batch_size: 64
   blocks: 4
   dilation_channels: 32
   dropout: 0.3
+  do_graph_conv: True
   end_channels: 512
   gcn_bool: true
-  in_dim: 2
+  in_dim: 1
   input_dim: 1
   kernel_size: 2
-  layers: 2
+  layers: 4
   num_nodes: 207
-  out_dim: 12
+  out_dim: 24
-  output_dim: 1
   residual_channels: 32
   skip_channels: 256
   supports: null

 train:
-  batch_size: 16
+  batch_size: 64
   debug: false
   early_stop: true
   early_stop_patience: 15
@@ -20,20 +20,21 @@ data:

 model:
   addaptadj: true
+  apt_size: 10
   aptinit: null
   batch_size: 32
   blocks: 4
   dilation_channels: 32
   dropout: 0.3
+  do_graph_conv: True
   end_channels: 512
   gcn_bool: true
-  in_dim: 2
+  in_dim: 1
   input_dim: 1
   kernel_size: 2
-  layers: 2
+  layers: 4
   num_nodes: 128
-  out_dim: 12
+  out_dim: 24
-  output_dim: 1
   residual_channels: 32
   skip_channels: 256
   supports: null
@@ -6,7 +6,7 @@ basic:
   seed: 2023

 data:
-  batch_size: 32
+  batch_size: 16
   column_wise: false
   days_per_week: 7
   horizon: 24
@@ -20,26 +20,27 @@ data:

 model:
   addaptadj: true
+  apt_size: 10
   aptinit: null
-  batch_size: 32
+  batch_size: 16
   blocks: 4
   dilation_channels: 32
   dropout: 0.3
+  do_graph_conv: True
   end_channels: 512
   gcn_bool: true
-  in_dim: 2
+  in_dim: 1
   input_dim: 1
   kernel_size: 2
-  layers: 2
+  layers: 4
   num_nodes: 128
-  out_dim: 12
+  out_dim: 24
-  output_dim: 1
   residual_channels: 32
   skip_channels: 256
   supports: null

 train:
-  batch_size: 32
+  batch_size: 16
   debug: false
   early_stop: true
   early_stop_patience: 15
@@ -0,0 +1,61 @@
+basic:
+  dataset: PEMS-BAY
+  device: cuda:0
+  mode: train
+  model: GWN
+  seed: 2023
+
+data:
+  batch_size: 64
+  column_wise: false
+  days_per_week: 7
+  horizon: 24
+  input_dim: 1
+  lag: 24
+  normalizer: std
+  num_nodes: 325
+  steps_per_day: 288
+  test_ratio: 0.2
+  val_ratio: 0.2
+
+model:
+  addaptadj: true
+  apt_size: 10
+  aptinit: null
+  batch_size: 64
+  blocks: 4
+  dilation_channels: 32
+  dropout: 0.3
+  do_graph_conv: True
+  end_channels: 512
+  gcn_bool: true
+  in_dim: 1
+  input_dim: 1
+  kernel_size: 2
+  layers: 4
+  num_nodes: 325
+  out_dim: 24
+  residual_channels: 32
+  skip_channels: 256
+  supports: null
+
+train:
+  batch_size: 64
+  debug: false
+  early_stop: true
+  early_stop_patience: 15
+  epochs: 300
+  grad_norm: false
+  log_step: 1000
+  loss_func: mae
+  lr_decay: false
+  lr_decay_rate: 0.3
+  lr_decay_step: 5,20,40,70
+  lr_init: 0.003
+  mae_thresh: 0.0
+  mape_thresh: 0.0
+  max_grad_norm: 5
+  output_dim: 1
+  plot: false
+  real_value: true
+  weight_decay: 0
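Several keys in this new config are deliberately duplicated across sections (batch_size appears under data, model and train, and model.out_dim mirrors data.horizon). A small, hypothetical sanity check along these lines can catch drift when configs are edited by hand; the rules and the example path are assumptions, not part of the repository:

import yaml  # PyYAML assumed available

def check_config(path: str) -> None:
    with open(path, "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)
    # batch_size is repeated in three sections and should agree
    sizes = {sec: cfg[sec]["batch_size"] for sec in ("data", "model", "train")}
    assert len(set(sizes.values())) == 1, f"batch_size mismatch: {sizes}"
    # one output channel per predicted step, so out_dim should track horizon
    assert cfg["model"]["out_dim"] == cfg["data"]["horizon"]
    assert cfg["model"]["num_nodes"] == cfg["data"]["num_nodes"]

check_config("config/GWN/PEMS-BAY.yaml")  # path pattern taken from the test script below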
@@ -27,7 +27,7 @@ model:
   dropout: 0.3
   end_channels: 512
   gcn_bool: true
-  in_dim: 2
+  in_dim: 3
   input_dim: 1
   kernel_size: 2
   layers: 2
@@ -27,7 +27,7 @@ model:
   dropout: 0.3
   end_channels: 512
   gcn_bool: true
-  in_dim: 2
+  in_dim: 1
   input_dim: 1
   kernel_size: 2
   layers: 2
@@ -27,7 +27,7 @@ model:
   dropout: 0.3
   end_channels: 512
   gcn_bool: true
-  in_dim: 2
+  in_dim: 3
   input_dim: 1
   kernel_size: 2
   layers: 2
@@ -27,7 +27,7 @@ model:
   dropout: 0.3
   end_channels: 512
   gcn_bool: true
-  in_dim: 2
+  in_dim: 3
   input_dim: 1
   kernel_size: 2
   layers: 2
@@ -20,20 +20,21 @@ data:

 model:
   addaptadj: true
+  apt_size: 10
   aptinit: null
-  batch_size: 64
+  batch_size: 32
   blocks: 4
   dilation_channels: 32
   dropout: 0.3
+  do_graph_conv: True
   end_channels: 512
   gcn_bool: true
-  in_dim: 2
+  in_dim: 1
   input_dim: 1
   kernel_size: 2
-  layers: 2
+  layers: 4
   num_nodes: 137
-  out_dim: 12
+  out_dim: 24
-  output_dim: 1
   residual_channels: 32
   skip_channels: 256
   supports: null
config/tmp.py (234 deletions)
@@ -1,234 +0,0 @@
-#!/usr/bin/env python3
-import os
-from collections import defaultdict
-from ruamel.yaml import YAML
-from ruamel.yaml.comments import CommentedMap
-
-yaml = YAML()
-yaml.preserve_quotes = True
-yaml.indent(mapping=2, sequence=4, offset=2)
-
-# Allowed data keys
-DATA_ALLOWED_KEYS = {
-    "lag",
-    "horizon",
-    "num_nodes",
-    "steps_per_day",
-    "days_per_week",
-    "test_ratio",
-    "val_ratio",
-    "batch_size",
-    "input_dim",
-    "column_wise",
-    "normalizer",
-}
-
-# Global default values
-GLOBAL_DEFAULTS = {
-    "lag": 24,
-    "horizon": 24,
-    "num_nodes": 1,
-    "steps_per_day": 24,
-    "days_per_week": 7,
-    "test_ratio": 0.2,
-    "val_ratio": 0.2,
-    "batch_size": 16,
-    "input_dim": 1,
-    "column_wise": False,
-    "normalizer": "std",
-}
-
-# Global defaults for the train section
-GLOBAL_TRAIN_DEFAULTS = {
-    "output_dim": 1
-}
-
-
-def load_yaml(path):
-    try:
-        with open(path, "r", encoding="utf-8") as f:
-            return yaml.load(f)
-    except Exception:
-        return None
-
-
-def collect_dataset_defaults(base="."):
-    """
-    Collect per-dataset defaults for the data keys and for train.output_dim.
-    """
-    data_defaults = defaultdict(dict)
-    train_output_defaults = dict()
-
-    for root, _, files in os.walk(base):
-        for name in files:
-            if not (name.endswith(".yaml") or name.endswith(".yml")):
-                continue
-            path = os.path.join(root, name)
-            cm = load_yaml(path)
-            if not isinstance(cm, CommentedMap):
-                continue
-            basic = cm.get("basic")
-            if not isinstance(basic, dict):
-                continue
-            dataset = basic.get("dataset")
-            if dataset is None:
-                continue
-            ds = str(dataset)
-
-            # data defaults
-            data_sec = cm.get("data")
-            if isinstance(data_sec, dict):
-                for key in DATA_ALLOWED_KEYS:
-                    if key not in data_defaults[ds] and key in data_sec and data_sec[key] is not None:
-                        data_defaults[ds][key] = data_sec[key]
-
-            # train.output_dim default
-            train_sec = cm.get("train")
-            if isinstance(train_sec, dict):
-                val = train_sec.get("output_dim")
-                if val is not None and ds not in train_output_defaults:
-                    train_output_defaults[ds] = val
-
-    return data_defaults, train_output_defaults
-
-
-def ensure_basic_seed(cm: CommentedMap, path: str):
-    if "basic" not in cm or not isinstance(cm["basic"], dict):
-        cm["basic"] = CommentedMap()
-    basic = cm["basic"]
-    if "seed" not in basic:
-        basic["seed"] = 2023
-        print(f"[ADD] {path}: basic.seed = 2023")
-
-
-def fill_data_defaults(cm: CommentedMap, data_defaults: dict, path: str):
-    if "data" not in cm or not isinstance(cm["data"], dict):
-        cm["data"] = CommentedMap()
-    data_sec = cm["data"]
-
-    basic = cm.get("basic", {})
-    dataset = str(basic.get("dataset")) if basic and "dataset" in basic else None
-
-    for key in sorted(DATA_ALLOWED_KEYS):
-        if key in data_sec and data_sec[key] is not None:
-            continue
-        if dataset and dataset in data_defaults and key in data_defaults[dataset]:
-            chosen = data_defaults[dataset][key]
-            src = f"default_from_dataset[{dataset}]"
-        else:
-            chosen = GLOBAL_DEFAULTS[key]
-            src = "GLOBAL_DEFAULTS"
-        data_sec[key] = chosen
-        print(f"[FILL] {path}: data.{key} <- {src} ({repr(chosen)})")
-
-
-def merge_test_log_into_train(cm: CommentedMap, path: str):
-    """
-    Merge the keys of test and log into train, then delete test and log.
-    Also make sure train.debug exists.
-    """
-    train_sec = cm.setdefault("train", CommentedMap())
-
-    for section in ["test", "log"]:
-        if section in cm and isinstance(cm[section], dict):
-            for k, v in cm[section].items():
-                if k not in train_sec:
-                    train_sec[k] = v
-                    print(f"[MERGE] {path}: train.{k} <- {section}.{k} ({repr(v)})")
-            del cm[section]
-            print(f"[DEL] {path}: deleted section '{section}'")
-
-    # train.debug
-    if "debug" not in train_sec:
-        train_sec["debug"] = False
-        print(f"[ADD] {path}: train.debug = False")
-
-
-def fill_train_output_dim(cm: CommentedMap, train_output_defaults: dict, path: str):
-    train_sec = cm.setdefault("train", CommentedMap())
-    if "output_dim" not in train_sec or train_sec["output_dim"] is None:
-        basic = cm.get("basic", {})
-        dataset = str(basic.get("dataset")) if basic and "dataset" in basic else None
-        if dataset and dataset in train_output_defaults:
-            val = train_output_defaults[dataset]
-            src = f"default_from_dataset[{dataset}]"
-        else:
-            val = GLOBAL_TRAIN_DEFAULTS["output_dim"]
-            src = "GLOBAL_TRAIN_DEFAULTS"
-        train_sec["output_dim"] = val
-        print(f"[FILL] {path}: train.output_dim <- {src} ({val})")
-
-
-def sync_train_batch_size(cm: CommentedMap, path: str):
-    """
-    If train.batch_size disagrees with data.batch_size, data takes precedence.
-    """
-    data_sec = cm.get("data", {})
-    train_sec = cm.get("train", {})
-    data_bs = data_sec.get("batch_size")
-    train_bs = train_sec.get("batch_size")
-
-    if data_bs is not None and train_bs != data_bs:
-        train_sec["batch_size"] = data_bs
-        print(f"[SYNC] {path}: train.batch_size corrected to match data.batch_size ({data_bs})")
-
-
-def sort_subkeys_and_insert_blanklines(cm: CommentedMap):
-    for sec in list(cm.keys()):
-        if isinstance(cm[sec], dict):
-            sorted_cm = CommentedMap()
-            for k in sorted(cm[sec].keys()):
-                sorted_cm[k] = cm[sec][k]
-            cm[sec] = sorted_cm
-
-    keys = list(cm.keys())
-    for i, k in enumerate(keys):
-        if i == 0:
-            try:
-                cm.yaml_set_comment_before_after_key(k, before=None)
-            except Exception:
-                pass
-        else:
-            try:
-                cm.yaml_set_comment_before_after_key(k, before="\n")
-            except Exception:
-                pass
-
-
-def process_all(base="."):
-    print(">> Collecting dataset defaults ...")
-    data_defaults, train_output_defaults = collect_dataset_defaults(base)
-    print(">> Collected data defaults per dataset:")
-    for ds, kv in data_defaults.items():
-        print(f"  - {ds}: {kv}")
-    print(">> Collected train.output_dim defaults per dataset:")
-    for ds, val in train_output_defaults.items():
-        print(f"  - {ds}: output_dim = {val}")
-
-    for root, _, files in os.walk(base):
-        for name in files:
-            if not (name.endswith(".yaml") or name.endswith(".yml")):
-                continue
-            path = os.path.join(root, name)
-            cm = load_yaml(path)
-            if not isinstance(cm, CommentedMap):
-                print(f"[SKIP] {path}: top-level not mapping or load failed")
-                continue
-
-            ensure_basic_seed(cm, path)
-            fill_data_defaults(cm, data_defaults, path)
-            merge_test_log_into_train(cm, path)
-            fill_train_output_dim(cm, train_output_defaults, path)
-            sync_train_batch_size(cm, path)  # <-- newly added logic
-            sort_subkeys_and_insert_blanklines(cm)
-
-            try:
-                with open(path, "w", encoding="utf-8") as f:
-                    yaml.dump(cm, f)
-                print(f"[OK] Written: {path}")
-            except Exception as e:
-                print(f"[ERROR] Write failed {path}: {e}")
-
-
-if __name__ == "__main__":
-    process_all(".")
@@ -1,53 +1,35 @@
-import torch, torch.nn as nn, torch.nn.functional as F
+import torch
+import torch.nn as nn
+from torch.nn import BatchNorm2d, Conv1d, Conv2d, ModuleList, Parameter
+import torch.nn.functional as F
+
+
+def nconv(x, A):
+    """Multiply x by adjacency matrix along source node axis"""
+    return torch.einsum('ncvl,vw->ncwl', (x, A)).contiguous()


-class nconv(nn.Module):
-    """
-    Implementation of the graph convolution operator.
-    Uses einsum to carry out the matrix computation of the graph convolution.
-    """
-
-    def forward(self, x, A):
-        return torch.einsum("ncvl,vw->ncwl", (x, A)).contiguous()
-
-
-class linear(nn.Module):
-    """
-    Linear transformation layer.
-    Implemented as a 1x1 convolution.
-    """
-
-    def __init__(self, c_in, c_out):
-        super().__init__()
-        self.mlp = nn.Conv2d(c_in, c_out, 1)
-
-    def forward(self, x):
-        return self.mlp(x)
-
-
-class gcn(nn.Module):
-    """
-    Graph convolution network layer.
-    Implements higher-order graph convolution over multiple adjacency matrices.
-    """
-
+class GraphConvNet(nn.Module):
     def __init__(self, c_in, c_out, dropout, support_len=3, order=2):
         super().__init__()
-        self.nconv = nconv()
         c_in = (order * support_len + 1) * c_in
-        self.mlp, self.dropout, self.order = linear(c_in, c_out), dropout, order
+        self.final_conv = Conv2d(c_in, c_out, (1, 1), padding=(0, 0), stride=(1, 1), bias=True)
+        self.dropout = dropout
+        self.order = order

-    def forward(self, x, support):
+    def forward(self, x, support: list):
         out = [x]
         for a in support:
-            x1 = self.nconv(x, a)
+            x1 = nconv(x, a)
             out.append(x1)
-            for _ in range(2, self.order + 1):
-                x1 = self.nconv(x1, a)
-                out.append(x1)
-        return F.dropout(
-            self.mlp(torch.cat(out, dim=1)), self.dropout, training=self.training
-        )
+            for k in range(2, self.order + 1):
+                x2 = nconv(x1, a)
+                out.append(x2)
+                x1 = x2
+        h = torch.cat(out, dim=1)
+        h = self.final_conv(h)
+        h = F.dropout(h, self.dropout, training=self.training)
+        return h


 class gwnet(nn.Module):
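The refactored nconv is a plain einsum over the node dimension: x has shape (batch, channels, nodes, time) and A is a (nodes, nodes) support matrix. A standalone shape check (a sketch with made-up sizes, independent of the repository code):

import torch

batch, channels, nodes, time = 2, 32, 35, 24
x = torch.randn(batch, channels, nodes, time)
A = torch.softmax(torch.randn(nodes, nodes), dim=1)  # any (V, V) support works

# 'ncvl,vw->ncwl': node w aggregates features from every node v, weighted by A[v, w]
out = torch.einsum("ncvl,vw->ncwl", (x, A)).contiguous()
print(out.shape)  # torch.Size([2, 32, 35, 24]) -- same shape, mixed along the node axis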
@@ -59,126 +41,121 @@ class gwnet(nn.Module):
     def __init__(self, args):
         super().__init__()
         # Initialize basic hyper-parameters
-        self.dropout, self.blocks, self.layers = (
-            args["dropout"],
-            args["blocks"],
-            args["layers"],
-        )
-        self.gcn_bool, self.addaptadj = args["gcn_bool"], args["addaptadj"]
-
-        # Initialize the convolution layers and modules
-        self.filter_convs, self.gate_convs = nn.ModuleList(), nn.ModuleList()
-        self.residual_convs, self.skip_convs, self.bn, self.gconv = (
-            nn.ModuleList(),
-            nn.ModuleList(),
-            nn.ModuleList(),
-            nn.ModuleList(),
-        )
-        self.start_conv = nn.Conv2d(args["in_dim"], args["residual_channels"], 1)
-        self.supports = args.get("supports", None)
-
-        # Compute the receptive field
+        self.dropout = args["dropout"]
+        self.blocks = args["blocks"]
+        self.layers = args["layers"]
+        self.do_graph_conv = args.get("do_graph_conv", True)
+        self.cat_feat_gc = args.get("cat_feat_gc", False)
+        self.addaptadj = args.get("addaptadj", True)
+        supports = None
+        aptinit = args.get("aptinit", None)
+        in_dim = args.get("in_dim")
+        out_dim = args.get("out_dim")
+        residual_channels = args.get("residual_channels")
+        dilation_channels = args.get("dilation_channels")
+        skip_channels = args.get("skip_channels")
+        end_channels = args.get("end_channels")
+        kernel_size = args.get("kernel_size")
+        apt_size = args.get("apt_size", 10)
+
+        if self.cat_feat_gc:
+            self.start_conv = nn.Conv2d(in_channels=1,  # hard code to avoid errors
+                                        out_channels=residual_channels,
+                                        kernel_size=(1, 1))
+            self.cat_feature_conv = nn.Conv2d(in_channels=in_dim - 1,
+                                              out_channels=residual_channels,
+                                              kernel_size=(1, 1))
+        else:
+            self.start_conv = nn.Conv2d(in_channels=in_dim,
+                                        out_channels=residual_channels,
+                                        kernel_size=(1, 1))
+
+        self.fixed_supports = supports or []
         receptive_field = 1
-        self.supports_len = len(self.supports) if self.supports is not None else 0

-        # If using the adaptive adjacency matrix, initialize its parameters
-        if self.gcn_bool and self.addaptadj:
-            aptinit = args.get("aptinit", None)
+        self.supports_len = len(self.fixed_supports)
+        if self.do_graph_conv and self.addaptadj:
             if aptinit is None:
-                if self.supports is None:
-                    self.supports = []
-                self.nodevec1 = nn.Parameter(
-                    torch.randn(args["num_nodes"], 10, device=args["device"])
-                )
-                self.nodevec2 = nn.Parameter(
-                    torch.randn(10, args["num_nodes"], device=args["device"])
-                )
-                self.supports_len += 1
+                nodevecs = torch.randn(args["num_nodes"], apt_size), torch.randn(apt_size, args["num_nodes"])
             else:
-                if self.supports is None:
-                    self.supports = []
-                m, p, n = torch.svd(aptinit)
-                initemb1 = torch.mm(m[:, :10], torch.diag(p[:10] ** 0.5))
-                initemb2 = torch.mm(torch.diag(p[:10] ** 0.5), n[:, :10].t())
-                self.nodevec1 = nn.Parameter(initemb1)
-                self.nodevec2 = nn.Parameter(initemb2)
-                self.supports_len += 1
-
-        # Read the model hyper-parameters
-        ks, res, dil, skip, endc, out_dim = (
-            args["kernel_size"],
-            args["residual_channels"],
-            args["dilation_channels"],
-            args["skip_channels"],
-            args["end_channels"],
-            args["out_dim"],
-        )
-
-        # Build the model layers
+                nodevecs = self.svd_init(args["num_nodes"], apt_size, aptinit)
+            self.supports_len += 1
+            self.nodevec1, self.nodevec2 = [Parameter(n.to(args["device"]), requires_grad=True) for n in nodevecs]
+
+        depth = list(range(self.blocks * self.layers))
+
+        # 1x1 convolution for residual and skip connections (slightly different see docstring)
+        self.residual_convs = ModuleList([Conv2d(dilation_channels, residual_channels, (1, 1)) for _ in depth])
+        self.skip_convs = ModuleList([Conv2d(dilation_channels, skip_channels, (1, 1)) for _ in depth])
+        self.bn = ModuleList([BatchNorm2d(residual_channels) for _ in depth])
+        self.graph_convs = ModuleList([GraphConvNet(dilation_channels, residual_channels, self.dropout, support_len=self.supports_len)
+                                       for _ in depth])
+
+        self.filter_convs = ModuleList()
+        self.gate_convs = ModuleList()
         for b in range(self.blocks):
-            add_scope, new_dil = ks - 1, 1
+            additional_scope = kernel_size - 1
+            D = 1  # dilation
             for i in range(self.layers):
-                # Add the temporal convolution layers
-                self.filter_convs.append(nn.Conv2d(res, dil, (1, ks), dilation=new_dil))
-                self.gate_convs.append(nn.Conv2d(res, dil, (1, ks), dilation=new_dil))
-                self.residual_convs.append(nn.Conv2d(dil, res, 1))
-                self.skip_convs.append(nn.Conv2d(dil, skip, 1))
-                self.bn.append(nn.BatchNorm2d(res))
-                new_dil *= 2
-                receptive_field += add_scope
-                add_scope *= 2
-                if self.gcn_bool:
-                    self.gconv.append(
-                        gcn(dil, res, args["dropout"], support_len=self.supports_len)
-                    )
-
-        # Output layers
-        self.end_conv_1 = nn.Conv2d(skip, endc, 1)
-        self.end_conv_2 = nn.Conv2d(endc, out_dim, 1)
+                # dilated convolutions
+                self.filter_convs.append(Conv2d(residual_channels, dilation_channels, (1, kernel_size), dilation=D))
+                self.gate_convs.append(Conv2d(residual_channels, dilation_channels, (1, kernel_size), dilation=D))
+                D *= 2
+                receptive_field += additional_scope
+                additional_scope *= 2
         self.receptive_field = receptive_field
+
+        self.end_conv_1 = Conv2d(skip_channels, end_channels, (1, 1), bias=True)
+        self.end_conv_2 = Conv2d(end_channels, out_dim, (1, 1), bias=True)

     def forward(self, input):
-        """
-        Forward pass.
-        Runs the model's inference procedure.
-        """
-        # Data preprocessing
-        input = input[..., 0:2].transpose(1, 3)
-        input = F.pad(input, (1, 0, 0, 0))
-        in_len = input.size(3)
-        x = (
-            F.pad(input, (self.receptive_field - in_len, 0, 0, 0))
-            if in_len < self.receptive_field
-            else input
-        )
-
-        # Initial convolution
-        x, skip, new_supports = self.start_conv(x), 0, None
-
-        # If using the adaptive adjacency matrix, compute the new adjacency matrix
-        if self.gcn_bool and self.addaptadj and self.supports is not None:
+        x = input[..., 0:1].transpose(1, 3)
+        # Input shape is (bs, features, n_nodes, n_timesteps)
+        in_len = x.size(3)
+        if in_len < self.receptive_field:
+            x = nn.functional.pad(x, (self.receptive_field - in_len, 0, 0, 0))
+        if self.cat_feat_gc:
+            f1, f2 = x[:, [0]], x[:, 1:]
+            x1 = self.start_conv(f1)
+            x2 = F.leaky_relu(self.cat_feature_conv(f2))
+            x = x1 + x2
+        else:
+            x = self.start_conv(x)
+        skip = 0
+        adjacency_matrices = self.fixed_supports
+        # calculate the current adaptive adj matrix once per iteration
+        if self.addaptadj:
             adp = F.softmax(F.relu(torch.mm(self.nodevec1, self.nodevec2)), dim=1)
-            new_supports = self.supports + [adp]
+            adjacency_matrices = self.fixed_supports + [adp]

-        # Forward pass through the main network layers
+        # WaveNet layers
         for i in range(self.blocks * self.layers):
             residual = x
-            # Temporal convolution
-            f = self.filter_convs[i](residual).tanh()
-            g = self.gate_convs[i](residual).sigmoid()
-            x = f * g
-            s = self.skip_convs[i](x)
-            skip = (
-                skip[:, :, :, -s.size(3) :] if isinstance(skip, torch.Tensor) else 0
-            ) + s
+            # dilated convolution
+            filter = torch.tanh(self.filter_convs[i](residual))
+            gate = torch.sigmoid(self.gate_convs[i](residual))
+            x = filter * gate
+            # parametrized skip connection
+            s = self.skip_convs[i](x)  # what are we skipping??
+            try:  # if i > 0 this works
+                skip = skip[:, :, :, -s.size(3):]  # TODO(SS): Mean/Max Pool?
+            except:
+                skip = 0
+            skip = s + skip
+            if i == (self.blocks * self.layers - 1):  # last X getting ignored anyway
+                break

-            # Graph convolution
-            if self.gcn_bool and self.supports is not None:
-                x = self.gconv[i](x, new_supports if self.addaptadj else self.supports)
+            if self.do_graph_conv:
+                graph_out = self.graph_convs[i](x, adjacency_matrices)
+                x = x + graph_out if self.cat_feat_gc else graph_out
             else:
                 x = self.residual_convs[i](x)
-            x = x + residual[:, :, :, -x.size(3) :]
+            x = x + residual[:, :, :, -x.size(3):]  # TODO(SS): Mean/Max Pool?
             x = self.bn[i](x)

-        # Output layers
-        return self.end_conv_2(F.relu(self.end_conv_1(F.relu(skip))))
+        x = F.relu(skip)  # ignore last X?
+        x = F.relu(self.end_conv_1(x))
+        x = self.end_conv_2(x)  # downsample to (bs, seq_length, 207, nfeatures)
+        # x = x.transpose(1, 3)
+        return x
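As a quick smoke test of the adapted model above (a sketch only: the import path is a guess, the hyper-parameters are copied from the configs in this commit but the graph is shrunk to a toy size, and CPU is used instead of cuda:0):

import torch
from model.GWN import gwnet  # hypothetical module path; point it at the class defined above

args = {
    "dropout": 0.3, "blocks": 4, "layers": 4, "kernel_size": 2,
    "in_dim": 1, "out_dim": 24, "residual_channels": 32, "dilation_channels": 32,
    "skip_channels": 256, "end_channels": 512, "apt_size": 10,
    "addaptadj": True, "do_graph_conv": True, "aptinit": None,
    "num_nodes": 8, "device": "cpu",  # toy graph instead of the real 207/325 nodes
}
model = gwnet(args).eval()

# (batch, lag, num_nodes, features) as the trainer feeds it; lag 24 < receptive field 61, so forward() pads
x = torch.randn(2, 24, args["num_nodes"], 1)
with torch.no_grad():
    y = model(x)
print(y.shape)  # expected along the lines of (2, out_dim, num_nodes, 1)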
@@ -1,97 +1,98 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from torch.autograd import Variable
-import sys
+import torch, torch.nn as nn, torch.nn.functional as F


 class nconv(nn.Module):
-    def __init__(self):
-        super(nconv, self).__init__()
+    """
+    Implementation of the graph convolution operator.
+    Uses einsum to carry out the matrix computation of the graph convolution.
+    """

     def forward(self, x, A):
-        x = torch.einsum("ncvl,vw->ncwl", (x, A))
-        return x.contiguous()
+        return torch.einsum("ncvl,vw->ncwl", (x, A)).contiguous()


 class linear(nn.Module):
+    """
+    Linear transformation layer.
+    Implemented as a 1x1 convolution.
+    """
+
     def __init__(self, c_in, c_out):
-        super(linear, self).__init__()
-        self.mlp = torch.nn.Conv2d(
-            c_in, c_out, kernel_size=(1, 1), padding=(0, 0), stride=(1, 1), bias=True
-        )
+        super().__init__()
+        self.mlp = nn.Conv2d(c_in, c_out, 1)

     def forward(self, x):
         return self.mlp(x)


 class gcn(nn.Module):
+    """
+    Graph convolution network layer.
+    Implements higher-order graph convolution over multiple adjacency matrices.
+    """
+
     def __init__(self, c_in, c_out, dropout, support_len=3, order=2):
-        super(gcn, self).__init__()
+        super().__init__()
         self.nconv = nconv()
         c_in = (order * support_len + 1) * c_in
-        self.mlp = linear(c_in, c_out)
-        self.dropout = dropout
-        self.order = order
+        self.mlp, self.dropout, self.order = linear(c_in, c_out), dropout, order

     def forward(self, x, support):
         out = [x]
         for a in support:
             x1 = self.nconv(x, a)
             out.append(x1)
-            for k in range(2, self.order + 1):
-                x2 = self.nconv(x1, a)
-                out.append(x2)
-                x1 = x2
-
-        h = torch.cat(out, dim=1)
-        h = self.mlp(h)
-        h = F.dropout(h, self.dropout, training=self.training)
-        return h
+            for _ in range(2, self.order + 1):
+                x1 = self.nconv(x1, a)
+                out.append(x1)
+        return F.dropout(
+            self.mlp(torch.cat(out, dim=1)), self.dropout, training=self.training
+        )


 class gwnet(nn.Module):
+    """
+    Main Graph WaveNet model class.
+    Combines graph convolution and temporal convolution networks for spatio-temporal forecasting.
+    """
+
     def __init__(self, args):
-        super(gwnet, self).__init__()
-        self.dropout = args["dropout"]
-        self.blocks = args["blocks"]
-        self.layers = args["layers"]
-        self.gcn_bool = args["gcn_bool"]
-        self.addaptadj = args["addaptadj"]
-
-        self.filter_convs = nn.ModuleList()
-        self.gate_convs = nn.ModuleList()
-        self.residual_convs = nn.ModuleList()
-        self.skip_convs = nn.ModuleList()
-        self.bn = nn.ModuleList()
-        self.gconv = nn.ModuleList()
-
-        self.start_conv = nn.Conv2d(
-            in_channels=args["in_dim"],
-            out_channels=args["residual_channels"],
-            kernel_size=(1, 1),
-        )
+        super().__init__()
+        # Initialize basic hyper-parameters
+        self.dropout, self.blocks, self.layers = (
+            args["dropout"],
+            args["blocks"],
+            args["layers"],
+        )
+        self.gcn_bool, self.addaptadj = args["gcn_bool"], args["addaptadj"]
+
+        # Initialize the convolution layers and modules
+        self.filter_convs, self.gate_convs = nn.ModuleList(), nn.ModuleList()
+        self.residual_convs, self.skip_convs, self.bn, self.gconv = (
+            nn.ModuleList(),
+            nn.ModuleList(),
+            nn.ModuleList(),
+            nn.ModuleList(),
+        )
+        self.start_conv = nn.Conv2d(args["in_dim"], args["residual_channels"], 1)
         self.supports = args.get("supports", None)

+        # Compute the receptive field
         receptive_field = 1
+        self.supports_len = len(self.supports) if self.supports is not None else 0

-        self.supports_len = 0
-        if self.supports is not None:
-            self.supports_len += len(self.supports)
-
+        # If using the adaptive adjacency matrix, initialize its parameters
         if self.gcn_bool and self.addaptadj:
             aptinit = args.get("aptinit", None)
             if aptinit is None:
                 if self.supports is None:
                     self.supports = []
                 self.nodevec1 = nn.Parameter(
-                    torch.randn(args["num_nodes"], 10).to(args["device"]),
-                    requires_grad=True,
-                ).to(args["device"])
+                    torch.randn(args["num_nodes"], 10, device=args["device"])
+                )
                 self.nodevec2 = nn.Parameter(
-                    torch.randn(10, args["num_nodes"]).to(args["device"]),
-                    requires_grad=True,
-                ).to(args["device"])
+                    torch.randn(10, args["num_nodes"], device=args["device"])
+                )
                 self.supports_len += 1
             else:
                 if self.supports is None:
@@ -99,156 +100,85 @@ class gwnet(nn.Module):
                 m, p, n = torch.svd(aptinit)
                 initemb1 = torch.mm(m[:, :10], torch.diag(p[:10] ** 0.5))
                 initemb2 = torch.mm(torch.diag(p[:10] ** 0.5), n[:, :10].t())
-                self.nodevec1 = nn.Parameter(initemb1, requires_grad=True).to(
-                    args["device"]
-                )
-                self.nodevec2 = nn.Parameter(initemb2, requires_grad=True).to(
-                    args["device"]
-                )
+                self.nodevec1 = nn.Parameter(initemb1)
+                self.nodevec2 = nn.Parameter(initemb2)
                 self.supports_len += 1

-        kernel_size = args["kernel_size"]
-        residual_channels = args["residual_channels"]
-        dilation_channels = args["dilation_channels"]
-        kernel_size = args["kernel_size"]
-        skip_channels = args["skip_channels"]
-        end_channels = args["end_channels"]
-        out_dim = args["out_dim"]
-        dropout = args["dropout"]
+        # Read the model hyper-parameters
+        ks, res, dil, skip, endc, out_dim = (
+            args["kernel_size"],
+            args["residual_channels"],
+            args["dilation_channels"],
+            args["skip_channels"],
+            args["end_channels"],
+            args["out_dim"],
+        )

+        # Build the model layers
         for b in range(self.blocks):
-            additional_scope = kernel_size - 1
-            new_dilation = 1
+            add_scope, new_dil = ks - 1, 1
             for i in range(self.layers):
-                # dilated convolutions
-                self.filter_convs.append(
-                    nn.Conv2d(
-                        in_channels=residual_channels,
-                        out_channels=dilation_channels,
-                        kernel_size=(1, kernel_size),
-                        dilation=new_dilation,
-                    )
-                )
-                self.gate_convs.append(
-                    nn.Conv2d(
-                        in_channels=residual_channels,
-                        out_channels=dilation_channels,
-                        kernel_size=(1, kernel_size),
-                        dilation=new_dilation,
-                    )
-                )
-
-                # 1x1 convolution for residual connection
-                self.residual_convs.append(
-                    nn.Conv2d(
-                        in_channels=dilation_channels,
-                        out_channels=residual_channels,
-                        kernel_size=(1, 1),
-                    )
-                )
-
-                # 1x1 convolution for skip connection
-                self.skip_convs.append(
-                    nn.Conv2d(
-                        in_channels=dilation_channels,
-                        out_channels=skip_channels,
-                        kernel_size=(1, 1),
-                    )
-                )
-                self.bn.append(nn.BatchNorm2d(residual_channels))
-                new_dilation *= 2
-                receptive_field += additional_scope
-                additional_scope *= 2
+                # Add the temporal convolution layers
+                self.filter_convs.append(nn.Conv2d(res, dil, (1, ks), dilation=new_dil))
+                self.gate_convs.append(nn.Conv2d(res, dil, (1, ks), dilation=new_dil))
+                self.residual_convs.append(nn.Conv2d(dil, res, 1))
+                self.skip_convs.append(nn.Conv2d(dil, skip, 1))
+                self.bn.append(nn.BatchNorm2d(res))
+                new_dil *= 2
+                receptive_field += add_scope
+                add_scope *= 2
                 if self.gcn_bool:
                     self.gconv.append(
-                        gcn(
-                            dilation_channels,
-                            residual_channels,
-                            dropout,
-                            support_len=self.supports_len,
-                        )
+                        gcn(dil, res, args["dropout"], support_len=self.supports_len)
                     )

-        self.end_conv_1 = nn.Conv2d(
-            in_channels=skip_channels,
-            out_channels=end_channels,
-            kernel_size=(1, 1),
-            bias=True,
-        )
-
-        self.end_conv_2 = nn.Conv2d(
-            in_channels=end_channels,
-            out_channels=out_dim,
-            kernel_size=(1, 1),
-            bias=True,
-        )
-
+        # Output layers
+        self.end_conv_1 = nn.Conv2d(skip, endc, 1)
+        self.end_conv_2 = nn.Conv2d(endc, out_dim, 1)
         self.receptive_field = receptive_field

     def forward(self, input):
-        input = input[..., 0:2]
-        input = input.transpose(1, 3)
-        input = nn.functional.pad(input, (1, 0, 0, 0))
+        """
+        Forward pass.
+        Runs the model's inference procedure.
+        """
+        # Data preprocessing
+        input = input[..., 0:2].transpose(1, 3)
+        input = F.pad(input, (1, 0, 0, 0))
         in_len = input.size(3)
-        if in_len < self.receptive_field:
-            x = nn.functional.pad(input, (self.receptive_field - in_len, 0, 0, 0))
-        else:
-            x = input
-        x = self.start_conv(x)
-        skip = 0
+        x = (
+            F.pad(input, (self.receptive_field - in_len, 0, 0, 0))
+            if in_len < self.receptive_field
+            else input
+        )

-        # calculate the current adaptive adj matrix once per iteration
-        new_supports = None
+        # Initial convolution
+        x, skip, new_supports = self.start_conv(x), 0, None
+
+        # If using the adaptive adjacency matrix, compute the new adjacency matrix
         if self.gcn_bool and self.addaptadj and self.supports is not None:
             adp = F.softmax(F.relu(torch.mm(self.nodevec1, self.nodevec2)), dim=1)
             new_supports = self.supports + [adp]

-        # WaveNet layers
+        # Forward pass through the main network layers
         for i in range(self.blocks * self.layers):
-            # |----------------------------------------| *residual*
-            # |                                        |
-            # |    |-- conv -- tanh --|                |
-            # -> dilate -|----|          * ----|-- 1x1 -- + -->  *input*
-            #                 |-- conv -- sigm --|     |
-            #                                         1x1
-            #                                          |
-            # ---------------------------------------> + ------------->  *skip*
-
-            # (dilation, init_dilation) = self.dilations[i]
-            # residual = dilation_func(x, dilation, init_dilation, i)
             residual = x
-            # dilated convolution
-            filter = self.filter_convs[i](residual)
-            filter = torch.tanh(filter)
-            gate = self.gate_convs[i](residual)
-            gate = torch.sigmoid(gate)
-            x = filter * gate
-
-            # parametrized skip connection
-            s = x
-            s = self.skip_convs[i](s)
-            try:
-                skip = skip[:, :, :, -s.size(3) :]
-            except:
-                skip = 0
-            skip = s + skip
-
+            # Temporal convolution
+            f = self.filter_convs[i](residual).tanh()
+            g = self.gate_convs[i](residual).sigmoid()
+            x = f * g
+            s = self.skip_convs[i](x)
+            skip = (
+                skip[:, :, :, -s.size(3) :] if isinstance(skip, torch.Tensor) else 0
+            ) + s
+
+            # Graph convolution
             if self.gcn_bool and self.supports is not None:
-                if self.addaptadj:
-                    x = self.gconv[i](x, new_supports)
-                else:
-                    x = self.gconv[i](x, self.supports)
+                x = self.gconv[i](x, new_supports if self.addaptadj else self.supports)
             else:
                 x = self.residual_convs[i](x)

             x = x + residual[:, :, :, -x.size(3) :]

             x = self.bn[i](x)

-        x = F.relu(skip)
-        x = F.relu(self.end_conv_1(x))
-        x = self.end_conv_2(x)
-        return x
+        # Output layers
+        return self.end_conv_2(F.relu(self.end_conv_1(F.relu(skip))))
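Both versions of the model build the adaptive adjacency matrix the same way: two learned node embeddings are multiplied, passed through ReLU, and row-normalized with softmax. In isolation (a standalone sketch with made-up sizes, independent of the repository code):

import torch
import torch.nn.functional as F

num_nodes, apt_size = 8, 10
E1 = torch.nn.Parameter(torch.randn(num_nodes, apt_size))  # nodevec1
E2 = torch.nn.Parameter(torch.randn(apt_size, num_nodes))  # nodevec2

adp = F.softmax(F.relu(torch.mm(E1, E2)), dim=1)  # dense learned adjacency
print(adp.shape)       # torch.Size([8, 8])
print(adp.sum(dim=1))  # every row sums to 1 because of the softmax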
@@ -0,0 +1,95 @@
+#!/bin/bash
+
+# Set the default model name and dataset list
+MODEL_NAME="GWN"
+DATASETS=(
+    "METR-LA"
+    "PEMS-BAY"
+    "NYCBike-InFlow"
+    "NYCBike-OutFlow"
+    "AirQuality"
+    "SolarEnergy"
+)
+
+# Initialize counters
+success_count=0
+failure_count=0
+missing_count=0
+total_count=0
+success_datasets=()
+failure_datasets=()
+missing_datasets=()
+
+# Allow command-line arguments to override the defaults
+if [ $# -gt 0 ]; then
+    MODEL_NAME=$1
+    # If more arguments were passed, use them as the dataset list
+    if [ $# -gt 1 ]; then
+        DATASETS=(${@:2})
+    fi
+fi
+
+echo "Model: $MODEL_NAME"
+echo "Datasets: ${DATASETS[*]}"
+echo "Starting tests..."
+echo ""
+
+# Test every dataset in turn
+for dataset in "${DATASETS[@]}"; do
+    total_count=$((total_count + 1))
+    # Build the config file path
+    CONFIG_PATH="config/${MODEL_NAME}/${dataset}.yaml"
+
+    echo "Testing dataset: $dataset"
+    echo "Using config file: $CONFIG_PATH"
+
+    # Check that the config file exists
+    if [ ! -f "$CONFIG_PATH" ]; then
+        echo "Error: config file $CONFIG_PATH does not exist!"
+        missing_count=$((missing_count + 1))
+        missing_datasets+=("$dataset")
+        echo "----------------------------------------"
+        continue
+    fi
+
+    # Run the test command and capture its output
+    echo "Running: python run.py --config $CONFIG_PATH"
+    output=$(python run.py --config "$CONFIG_PATH" 2>&1)
+
+    # If no explicit marker is found, fall back to the exit code
+    if [ $? -eq 0 ]; then
+        echo "Dataset $dataset passed! (based on exit code)"
+        success_count=$((success_count + 1))
+        success_datasets+=("$dataset")
+    else
+        echo "Dataset $dataset failed! (based on exit code)"
+        failure_count=$((failure_count + 1))
+        failure_datasets+=("$dataset")
+    fi
+
+    echo "----------------------------------------"
+done
+
+# Print the summary
+echo "======================================="
+echo "Test summary"
+echo "======================================="
+echo "Total datasets: $total_count"
+echo "Passed: $success_count"
+echo "Failed: $failure_count"
+echo "Missing config files: $missing_count"
+
+if [ ${#success_datasets[@]} -gt 0 ]; then
+    echo "Passed datasets: ${success_datasets[*]}"
+fi
+
+if [ ${#failure_datasets[@]} -gt 0 ]; then
+    echo "Failed datasets: ${failure_datasets[*]}"
+fi
+
+if [ ${#missing_datasets[@]} -gt 0 ]; then
+    echo "Datasets with missing configs: ${missing_datasets[*]}"
+fi
+
+echo "======================================="
+echo "All tests finished!"
@@ -177,6 +177,14 @@ class Trainer:
             # Forward pass
             label = target[..., : self.args["output_dim"]]
             output = self.model(data).to(self.device)
+            # if output.shape != label.shape:
+            #     import sys
+            #     print(f"[Wrong]: Output shape: {output.shape}, Label shape: {label.shape}")
+            #     sys.exit(1)
+            # else:
+            #     import sys
+            #     print(f"[Right]: Output shape: {output.shape}, Label shape: {label.shape}")
+            #     sys.exit(0)
             loss = self.loss(output, label)

             # De-normalize
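The commented-out block above is a debugging aid for the output/label shape mismatch it was written to catch. The same check can be written as a plain assertion; the (batch, horizon, nodes, features) layout and the sizes below are illustrative assumptions, not values from the repository:

import torch

batch, horizon, num_nodes, output_dim = 64, 24, 325, 1
output = torch.randn(batch, horizon, num_nodes, output_dim)  # what the model should return
target = torch.randn(batch, horizon, num_nodes, 6)           # raw target with extra features
label = target[..., :output_dim]                             # same slicing as in the trainer

assert output.shape == label.shape, f"Output {tuple(output.shape)} vs label {tuple(label.shape)}"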