Adapt GraphWaveNet
This commit is contained in:
parent 140ead3975
commit a9313390ac
@@ -6,40 +6,41 @@ basic:
   seed: 2023
 
 data:
-  batch_size: 16
+  batch_size: 64
   column_wise: false
   days_per_week: 7
   horizon: 24
   input_dim: 6
   lag: 24
   normalizer: std
-  num_nodes: 12
+  num_nodes: 35
   steps_per_day: 24
   test_ratio: 0.2
   val_ratio: 0.2
 
 model:
   addaptadj: true
   apt_size: 10
   aptinit: null
-  batch_size: 16
+  batch_size: 64
   blocks: 4
   dilation_channels: 32
   dropout: 0.3
   do_graph_conv: True
   end_channels: 512
   gcn_bool: true
-  in_dim: 2
-  input_dim: 6
+  in_dim: 1
+  input_dim: 1
   kernel_size: 2
-  layers: 2
-  num_nodes: 12
-  out_dim: 12
-  output_dim: 6
+  layers: 4
+  num_nodes: 35
+  out_dim: 24
   residual_channels: 32
   skip_channels: 256
   supports: null
 
 train:
-  batch_size: 16
+  batch_size: 64
   debug: false
   early_stop: true
   early_stop_patience: 15
@@ -54,7 +55,7 @@ train:
   mae_thresh: 0.0
   mape_thresh: 0.0
   max_grad_norm: 5
-  output_dim: 6
+  output_dim: 1
   plot: false
   real_value: true
   weight_decay: 0
@@ -20,24 +20,26 @@ data:
 
 model:
   addaptadj: true
   apt_size: 10
   aptinit: null
-  batch_size: 32
+  batch_size: 16
   blocks: 4
   dilation_channels: 32
   dropout: 0.3
   do_graph_conv: True
   end_channels: 512
   gcn_bool: true
-  in_dim: 2
+  in_dim: 1
   input_dim: 1
   kernel_size: 2
-  layers: 2
+  layers: 4
   num_nodes: 1024
-  out_dim: 12
-  output_dim: 1
+  out_dim: 24
   residual_channels: 32
   skip_channels: 256
   supports: null
 
 train:
   batch_size: 32
   debug: false
@@ -20,20 +20,21 @@ data:
 
 model:
   addaptadj: true
   apt_size: 10
   aptinit: null
   batch_size: 32
   blocks: 4
   dilation_channels: 32
   dropout: 0.3
   do_graph_conv: True
   end_channels: 512
   gcn_bool: true
-  in_dim: 2
+  in_dim: 1
   input_dim: 1
   kernel_size: 2
-  layers: 2
+  layers: 4
   num_nodes: 1024
-  out_dim: 12
-  output_dim: 1
+  out_dim: 24
   residual_channels: 32
   skip_channels: 256
   supports: null
@@ -6,7 +6,7 @@ basic:
   seed: 2023
 
 data:
-  batch_size: 16
+  batch_size: 64
   column_wise: false
   days_per_week: 7
   horizon: 24
@@ -20,26 +20,27 @@ data:
 
 model:
   addaptadj: true
   apt_size: 10
   aptinit: null
-  batch_size: 16
+  batch_size: 64
   blocks: 4
   dilation_channels: 32
   dropout: 0.3
   do_graph_conv: True
   end_channels: 512
   gcn_bool: true
-  in_dim: 2
+  in_dim: 1
   input_dim: 1
   kernel_size: 2
-  layers: 2
+  layers: 4
   num_nodes: 207
-  out_dim: 12
-  output_dim: 1
+  out_dim: 24
   residual_channels: 32
   skip_channels: 256
   supports: null
 
 train:
-  batch_size: 16
+  batch_size: 64
   debug: false
   early_stop: true
   early_stop_patience: 15
@@ -20,20 +20,21 @@ data:
 
 model:
   addaptadj: true
   apt_size: 10
   aptinit: null
   batch_size: 32
   blocks: 4
   dilation_channels: 32
   dropout: 0.3
   do_graph_conv: True
   end_channels: 512
   gcn_bool: true
-  in_dim: 2
+  in_dim: 1
   input_dim: 1
   kernel_size: 2
-  layers: 2
+  layers: 4
   num_nodes: 128
-  out_dim: 12
-  output_dim: 1
+  out_dim: 24
   residual_channels: 32
   skip_channels: 256
   supports: null
@@ -6,7 +6,7 @@ basic:
   seed: 2023
 
 data:
-  batch_size: 32
+  batch_size: 16
   column_wise: false
   days_per_week: 7
   horizon: 24
@@ -20,26 +20,27 @@ data:
 
 model:
   addaptadj: true
   apt_size: 10
   aptinit: null
-  batch_size: 32
+  batch_size: 16
   blocks: 4
   dilation_channels: 32
   dropout: 0.3
   do_graph_conv: True
   end_channels: 512
   gcn_bool: true
-  in_dim: 2
+  in_dim: 1
   input_dim: 1
   kernel_size: 2
-  layers: 2
+  layers: 4
   num_nodes: 128
-  out_dim: 12
-  output_dim: 1
+  out_dim: 24
   residual_channels: 32
   skip_channels: 256
   supports: null
 
 train:
-  batch_size: 32
+  batch_size: 16
   debug: false
   early_stop: true
   early_stop_patience: 15
@@ -0,0 +1,61 @@
basic:
  dataset: PEMS-BAY
  device: cuda:0
  mode: train
  model: GWN
  seed: 2023

data:
  batch_size: 64
  column_wise: false
  days_per_week: 7
  horizon: 24
  input_dim: 1
  lag: 24
  normalizer: std
  num_nodes: 325
  steps_per_day: 288
  test_ratio: 0.2
  val_ratio: 0.2

model:
  addaptadj: true
  apt_size: 10
  aptinit: null
  batch_size: 64
  blocks: 4
  dilation_channels: 32
  dropout: 0.3
  do_graph_conv: True
  end_channels: 512
  gcn_bool: true
  in_dim: 1
  input_dim: 1
  kernel_size: 2
  layers: 4
  num_nodes: 325
  out_dim: 24
  residual_channels: 32
  skip_channels: 256
  supports: null

train:
  batch_size: 64
  debug: false
  early_stop: true
  early_stop_patience: 15
  epochs: 300
  grad_norm: false
  log_step: 1000
  loss_func: mae
  lr_decay: false
  lr_decay_rate: 0.3
  lr_decay_step: 5,20,40,70
  lr_init: 0.003
  mae_thresh: 0.0
  mape_thresh: 0.0
  max_grad_norm: 5
  output_dim: 1
  plot: false
  real_value: true
  weight_decay: 0
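Note: the recurring theme in these config diffs is keeping the duplicated keys consistent across sections: model.out_dim is raised to match data.horizon (24), model.num_nodes is corrected to the dataset's node count, and batch_size is synchronized between data, model, and train. A minimal sanity check along those lines is sketched below; the use of PyYAML and the file path are illustrative assumptions, not part of this commit.

import yaml  # assumption: PyYAML is available; the repo may load configs differently

# hypothetical path, mirroring the new file above
with open("config/GWN/PEMS-BAY.yaml", "r", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

data, model, train = cfg["data"], cfg["model"], cfg["train"]
# the GWN adaptation predicts out_dim future steps for every node,
# so these keys are expected to mirror the data section
assert model["out_dim"] == data["horizon"]
assert model["num_nodes"] == data["num_nodes"]
assert data["batch_size"] == model["batch_size"] == train["batch_size"]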
@@ -27,7 +27,7 @@ model:
   dropout: 0.3
   end_channels: 512
   gcn_bool: true
-  in_dim: 2
+  in_dim: 3
   input_dim: 1
   kernel_size: 2
   layers: 2
@@ -27,7 +27,7 @@ model:
   dropout: 0.3
   end_channels: 512
   gcn_bool: true
-  in_dim: 2
+  in_dim: 1
   input_dim: 1
   kernel_size: 2
   layers: 2
@@ -27,7 +27,7 @@ model:
   dropout: 0.3
   end_channels: 512
   gcn_bool: true
-  in_dim: 2
+  in_dim: 3
   input_dim: 1
   kernel_size: 2
   layers: 2
@@ -27,7 +27,7 @@ model:
   dropout: 0.3
   end_channels: 512
   gcn_bool: true
-  in_dim: 2
+  in_dim: 3
   input_dim: 1
   kernel_size: 2
   layers: 2
@@ -20,20 +20,21 @@ data:
 
 model:
   addaptadj: true
   apt_size: 10
   aptinit: null
-  batch_size: 64
+  batch_size: 32
   blocks: 4
   dilation_channels: 32
   dropout: 0.3
   do_graph_conv: True
   end_channels: 512
   gcn_bool: true
-  in_dim: 2
+  in_dim: 1
   input_dim: 1
   kernel_size: 2
-  layers: 2
+  layers: 4
   num_nodes: 137
-  out_dim: 12
-  output_dim: 1
+  out_dim: 24
   residual_channels: 32
   skip_channels: 256
   supports: null
config/tmp.py (234 lines, deleted)
@@ -1,234 +0,0 @@
#!/usr/bin/env python3
import os
from collections import defaultdict
from ruamel.yaml import YAML
from ruamel.yaml.comments import CommentedMap

yaml = YAML()
yaml.preserve_quotes = True
yaml.indent(mapping=2, sequence=4, offset=2)

# Allowed data keys
DATA_ALLOWED_KEYS = {
    "lag",
    "horizon",
    "num_nodes",
    "steps_per_day",
    "days_per_week",
    "test_ratio",
    "val_ratio",
    "batch_size",
    "input_dim",
    "column_wise",
    "normalizer",
}

# Global defaults
GLOBAL_DEFAULTS = {
    "lag": 24,
    "horizon": 24,
    "num_nodes": 1,
    "steps_per_day": 24,
    "days_per_week": 7,
    "test_ratio": 0.2,
    "val_ratio": 0.2,
    "batch_size": 16,
    "input_dim": 1,
    "column_wise": False,
    "normalizer": "std",
}

# Global train defaults
GLOBAL_TRAIN_DEFAULTS = {
    "output_dim": 1
}


def load_yaml(path):
    try:
        with open(path, "r", encoding="utf-8") as f:
            return yaml.load(f)
    except Exception:
        return None


def collect_dataset_defaults(base="."):
    """
    Collect the default value of each data key per dataset, plus the
    default for train.output_dim.
    """
    data_defaults = defaultdict(dict)
    train_output_defaults = dict()

    for root, _, files in os.walk(base):
        for name in files:
            if not (name.endswith(".yaml") or name.endswith(".yml")):
                continue
            path = os.path.join(root, name)
            cm = load_yaml(path)
            if not isinstance(cm, CommentedMap):
                continue
            basic = cm.get("basic")
            if not isinstance(basic, dict):
                continue
            dataset = basic.get("dataset")
            if dataset is None:
                continue
            ds = str(dataset)

            # data defaults
            data_sec = cm.get("data")
            if isinstance(data_sec, dict):
                for key in DATA_ALLOWED_KEYS:
                    if key not in data_defaults[ds] and key in data_sec and data_sec[key] is not None:
                        data_defaults[ds][key] = data_sec[key]

            # train.output_dim default
            train_sec = cm.get("train")
            if isinstance(train_sec, dict):
                val = train_sec.get("output_dim")
                if val is not None and ds not in train_output_defaults:
                    train_output_defaults[ds] = val

    return data_defaults, train_output_defaults


def ensure_basic_seed(cm: CommentedMap, path: str):
    if "basic" not in cm or not isinstance(cm["basic"], dict):
        cm["basic"] = CommentedMap()
    basic = cm["basic"]
    if "seed" not in basic:
        basic["seed"] = 2023
        print(f"[ADD] {path}: basic.seed = 2023")


def fill_data_defaults(cm: CommentedMap, data_defaults: dict, path: str):
    if "data" not in cm or not isinstance(cm["data"], dict):
        cm["data"] = CommentedMap()
    data_sec = cm["data"]

    basic = cm.get("basic", {})
    dataset = str(basic.get("dataset")) if basic and "dataset" in basic else None

    for key in sorted(DATA_ALLOWED_KEYS):
        if key in data_sec and data_sec[key] is not None:
            continue
        if dataset and dataset in data_defaults and key in data_defaults[dataset]:
            chosen = data_defaults[dataset][key]
            src = f"default_from_dataset[{dataset}]"
        else:
            chosen = GLOBAL_DEFAULTS[key]
            src = "GLOBAL_DEFAULTS"
        data_sec[key] = chosen
        print(f"[FILL] {path}: data.{key} <- {src} ({repr(chosen)})")


def merge_test_log_into_train(cm: CommentedMap, path: str):
    """
    Merge the keys of the test and log sections into train, then delete
    test and log. Also make sure train.debug exists.
    """
    train_sec = cm.setdefault("train", CommentedMap())

    for section in ["test", "log"]:
        if section in cm and isinstance(cm[section], dict):
            for k, v in cm[section].items():
                if k not in train_sec:
                    train_sec[k] = v
                    print(f"[MERGE] {path}: train.{k} <- {section}.{k} ({repr(v)})")
            del cm[section]
            print(f"[DEL] {path}: deleted section '{section}'")

    # train.debug
    if "debug" not in train_sec:
        train_sec["debug"] = False
        print(f"[ADD] {path}: train.debug = False")


def fill_train_output_dim(cm: CommentedMap, train_output_defaults: dict, path: str):
    train_sec = cm.setdefault("train", CommentedMap())
    if "output_dim" not in train_sec or train_sec["output_dim"] is None:
        basic = cm.get("basic", {})
        dataset = str(basic.get("dataset")) if basic and "dataset" in basic else None
        if dataset and dataset in train_output_defaults:
            val = train_output_defaults[dataset]
            src = f"default_from_dataset[{dataset}]"
        else:
            val = GLOBAL_TRAIN_DEFAULTS["output_dim"]
            src = "GLOBAL_TRAIN_DEFAULTS"
        train_sec["output_dim"] = val
        print(f"[FILL] {path}: train.output_dim <- {src} ({val})")


def sync_train_batch_size(cm: CommentedMap, path: str):
    """
    If train.batch_size differs from data.batch_size, data.batch_size wins.
    """
    data_sec = cm.get("data", {})
    train_sec = cm.get("train", {})
    data_bs = data_sec.get("batch_size")
    train_bs = train_sec.get("batch_size")

    if data_bs is not None and train_bs != data_bs:
        train_sec["batch_size"] = data_bs
        print(f"[SYNC] {path}: train.batch_size corrected to match data.batch_size ({data_bs})")


def sort_subkeys_and_insert_blanklines(cm: CommentedMap):
    for sec in list(cm.keys()):
        if isinstance(cm[sec], dict):
            sorted_cm = CommentedMap()
            for k in sorted(cm[sec].keys()):
                sorted_cm[k] = cm[sec][k]
            cm[sec] = sorted_cm

    keys = list(cm.keys())
    for i, k in enumerate(keys):
        if i == 0:
            try:
                cm.yaml_set_comment_before_after_key(k, before=None)
            except Exception:
                pass
        else:
            try:
                cm.yaml_set_comment_before_after_key(k, before="\n")
            except Exception:
                pass


def process_all(base="."):
    print(">> Collecting dataset defaults ...")
    data_defaults, train_output_defaults = collect_dataset_defaults(base)
    print(">> Collected data defaults per dataset:")
    for ds, kv in data_defaults.items():
        print(f"  - {ds}: {kv}")
    print(">> Collected train.output_dim defaults per dataset:")
    for ds, val in train_output_defaults.items():
        print(f"  - {ds}: output_dim = {val}")

    for root, _, files in os.walk(base):
        for name in files:
            if not (name.endswith(".yaml") or name.endswith(".yml")):
                continue
            path = os.path.join(root, name)
            cm = load_yaml(path)
            if not isinstance(cm, CommentedMap):
                print(f"[SKIP] {path}: top-level not mapping or load failed")
                continue

            ensure_basic_seed(cm, path)
            fill_data_defaults(cm, data_defaults, path)
            merge_test_log_into_train(cm, path)
            fill_train_output_dim(cm, train_output_defaults, path)
            sync_train_batch_size(cm, path)  # <-- newly added logic
            sort_subkeys_and_insert_blanklines(cm)

            try:
                with open(path, "w", encoding="utf-8") as f:
                    yaml.dump(cm, f)
                print(f"[OK] Written: {path}")
            except Exception as e:
                print(f"[ERROR] Write failed {path}: {e}")


if __name__ == "__main__":
    process_all(".")
@@ -1,53 +1,35 @@
-import torch, torch.nn as nn, torch.nn.functional as F
+import torch
+import torch.nn as nn
+from torch.nn import BatchNorm2d, Conv1d, Conv2d, ModuleList, Parameter
+import torch.nn.functional as F


+def nconv(x, A):
+    """Multiply x by adjacency matrix along source node axis"""
+    return torch.einsum('ncvl,vw->ncwl', (x, A)).contiguous()


-class nconv(nn.Module):
-    """
-    Graph convolution operator: applies the adjacency matrix via einsum.
-    """
-
-    def forward(self, x, A):
-        return torch.einsum("ncvl,vw->ncwl", (x, A)).contiguous()
-
-
-class linear(nn.Module):
-    """
-    Linear layer implemented as a 1x1 convolution.
-    """
-
-    def __init__(self, c_in, c_out):
-        super().__init__()
-        self.mlp = nn.Conv2d(c_in, c_out, 1)
-
-    def forward(self, x):
-        return self.mlp(x)
-
-
-class gcn(nn.Module):
-    """
-    Graph convolution layer: higher-order graph convolution over multiple adjacency matrices.
-    """
-
+class GraphConvNet(nn.Module):
     def __init__(self, c_in, c_out, dropout, support_len=3, order=2):
         super().__init__()
-        self.nconv = nconv()
         c_in = (order * support_len + 1) * c_in
-        self.mlp, self.dropout, self.order = linear(c_in, c_out), dropout, order
+        self.final_conv = Conv2d(c_in, c_out, (1, 1), padding=(0, 0), stride=(1, 1), bias=True)
+        self.dropout = dropout
+        self.order = order

-    def forward(self, x, support):
+    def forward(self, x, support: list):
         out = [x]
         for a in support:
-            x1 = self.nconv(x, a)
+            x1 = nconv(x, a)
             out.append(x1)
-            for _ in range(2, self.order + 1):
-                x1 = self.nconv(x1, a)
-                out.append(x1)
-        return F.dropout(
-            self.mlp(torch.cat(out, dim=1)), self.dropout, training=self.training
-        )
+            for k in range(2, self.order + 1):
+                x2 = nconv(x1, a)
+                out.append(x2)
+                x1 = x2
+
+        h = torch.cat(out, dim=1)
+        h = self.final_conv(h)
+        h = F.dropout(h, self.dropout, training=self.training)
+        return h


 class gwnet(nn.Module):
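Note: GraphConvNet above takes a tensor of shape (batch, channels, nodes, time) and a list of node-by-node supports; with order=2 and support_len=1 it concatenates 1 + 2*1 = 3 channel groups before the final 1x1 convolution. A minimal usage sketch with illustrative shapes, not taken from the commit:

import torch

# GraphConvNet is the class introduced in the hunk above (assumed importable)
gc = GraphConvNet(c_in=32, c_out=32, dropout=0.3, support_len=1, order=2)
x = torch.randn(8, 32, 325, 13)                  # (batch, channels, nodes, time)
A = torch.softmax(torch.randn(325, 325), dim=1)  # a single row-normalized support
y = gc(x, [A])                                   # -> torch.Size([8, 32, 325, 13])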
@@ -59,126 +41,121 @@ class gwnet(nn.Module):
     def __init__(self, args):
         super().__init__()
-        # initialize the basic hyperparameters
-        self.dropout, self.blocks, self.layers = (
-            args["dropout"],
-            args["blocks"],
-            args["layers"],
-        )
-        self.gcn_bool, self.addaptadj = args["gcn_bool"], args["addaptadj"]
+        self.dropout = args["dropout"]
+        self.blocks = args["blocks"]
+        self.layers = args["layers"]
+        self.do_graph_conv = args.get("do_graph_conv", True)
+        self.cat_feat_gc = args.get("cat_feat_gc", False)
+        self.addaptadj = args.get("addaptadj", True)
+        supports = None
+        aptinit = args.get("aptinit", None)
+        in_dim = args.get("in_dim")
+        out_dim = args.get("out_dim")
+        residual_channels = args.get("residual_channels")
+        dilation_channels = args.get("dilation_channels")
+        skip_channels = args.get("skip_channels")
+        end_channels = args.get("end_channels")
+        kernel_size = args.get("kernel_size")
+        apt_size = args.get("apt_size", 10)

-        # set up the convolution layers and modules
-        self.filter_convs, self.gate_convs = nn.ModuleList(), nn.ModuleList()
-        self.residual_convs, self.skip_convs, self.bn, self.gconv = (
-            nn.ModuleList(),
-            nn.ModuleList(),
-            nn.ModuleList(),
-            nn.ModuleList(),
-        )
-        self.start_conv = nn.Conv2d(args["in_dim"], args["residual_channels"], 1)
-        self.supports = args.get("supports", None)
-
-        # compute the receptive field
+        if self.cat_feat_gc:
+            self.start_conv = nn.Conv2d(in_channels=1,  # hard code to avoid errors
+                                        out_channels=residual_channels,
+                                        kernel_size=(1, 1))
+            self.cat_feature_conv = nn.Conv2d(in_channels=in_dim - 1,
+                                              out_channels=residual_channels,
+                                              kernel_size=(1, 1))
+        else:
+            self.start_conv = nn.Conv2d(in_channels=in_dim,
+                                        out_channels=residual_channels,
+                                        kernel_size=(1, 1))
+
+        self.fixed_supports = supports or []
         receptive_field = 1
-        self.supports_len = len(self.supports) if self.supports is not None else 0

-        # if the adaptive adjacency matrix is used, initialize its parameters
-        if self.gcn_bool and self.addaptadj:
-            aptinit = args.get("aptinit", None)
+        self.supports_len = len(self.fixed_supports)
+        if self.do_graph_conv and self.addaptadj:
             if aptinit is None:
-                if self.supports is None:
-                    self.supports = []
-                self.nodevec1 = nn.Parameter(
-                    torch.randn(args["num_nodes"], 10, device=args["device"])
-                )
-                self.nodevec2 = nn.Parameter(
-                    torch.randn(10, args["num_nodes"], device=args["device"])
-                )
-                self.supports_len += 1
+                nodevecs = torch.randn(args["num_nodes"], apt_size), torch.randn(apt_size, args["num_nodes"])
             else:
-                if self.supports is None:
-                    self.supports = []
-                m, p, n = torch.svd(aptinit)
-                initemb1 = torch.mm(m[:, :10], torch.diag(p[:10] ** 0.5))
-                initemb2 = torch.mm(torch.diag(p[:10] ** 0.5), n[:, :10].t())
-                self.nodevec1 = nn.Parameter(initemb1)
-                self.nodevec2 = nn.Parameter(initemb2)
-                self.supports_len += 1
+                nodevecs = self.svd_init(args["num_nodes"], apt_size, aptinit)
+            self.supports_len += 1
+            self.nodevec1, self.nodevec2 = [Parameter(n.to(args["device"]), requires_grad=True) for n in nodevecs]

-        # fetch the model hyperparameters
-        ks, res, dil, skip, endc, out_dim = (
-            args["kernel_size"],
-            args["residual_channels"],
-            args["dilation_channels"],
-            args["skip_channels"],
-            args["end_channels"],
-            args["out_dim"],
-        )
+        depth = list(range(self.blocks * self.layers))

-        # build the model layers
+        # 1x1 convolution for residual and skip connections (slightly different see docstring)
+        self.residual_convs = ModuleList([Conv2d(dilation_channels, residual_channels, (1, 1)) for _ in depth])
+        self.skip_convs = ModuleList([Conv2d(dilation_channels, skip_channels, (1, 1)) for _ in depth])
+        self.bn = ModuleList([BatchNorm2d(residual_channels) for _ in depth])
+        self.graph_convs = ModuleList([GraphConvNet(dilation_channels, residual_channels, self.dropout, support_len=self.supports_len)
+                                       for _ in depth])
+
+        self.filter_convs = ModuleList()
+        self.gate_convs = ModuleList()
         for b in range(self.blocks):
-            add_scope, new_dil = ks - 1, 1
+            additional_scope = kernel_size - 1
+            D = 1  # dilation
             for i in range(self.layers):
-                # add the temporal convolution layers
-                self.filter_convs.append(nn.Conv2d(res, dil, (1, ks), dilation=new_dil))
-                self.gate_convs.append(nn.Conv2d(res, dil, (1, ks), dilation=new_dil))
-                self.residual_convs.append(nn.Conv2d(dil, res, 1))
-                self.skip_convs.append(nn.Conv2d(dil, skip, 1))
-                self.bn.append(nn.BatchNorm2d(res))
-                new_dil *= 2
-                receptive_field += add_scope
-                add_scope *= 2
-                if self.gcn_bool:
-                    self.gconv.append(
-                        gcn(dil, res, args["dropout"], support_len=self.supports_len)
-                    )
-
-        # output layers
-        self.end_conv_1 = nn.Conv2d(skip, endc, 1)
-        self.end_conv_2 = nn.Conv2d(endc, out_dim, 1)
+                # dilated convolutions
+                self.filter_convs.append(Conv2d(residual_channels, dilation_channels, (1, kernel_size), dilation=D))
+                self.gate_convs.append(Conv2d(residual_channels, dilation_channels, (1, kernel_size), dilation=D))
+                D *= 2
+                receptive_field += additional_scope
+                additional_scope *= 2
         self.receptive_field = receptive_field
+
+        self.end_conv_1 = Conv2d(skip_channels, end_channels, (1, 1), bias=True)
+        self.end_conv_2 = Conv2d(end_channels, out_dim, (1, 1), bias=True)

     def forward(self, input):
-        """
-        Forward pass of the model.
-        """
-        # preprocess the input
-        input = input[..., 0:2].transpose(1, 3)
-        input = F.pad(input, (1, 0, 0, 0))
-        in_len = input.size(3)
-        x = (
-            F.pad(input, (self.receptive_field - in_len, 0, 0, 0))
-            if in_len < self.receptive_field
-            else input
-        )
-
-        # initial convolution
-        x, skip, new_supports = self.start_conv(x), 0, None
-
-        # if the adaptive adjacency matrix is used, compute the new adjacency matrix
-        if self.gcn_bool and self.addaptadj and self.supports is not None:
+        x = input[..., 0:1].transpose(1, 3)
+        # Input shape is (bs, features, n_nodes, n_timesteps)
+        in_len = x.size(3)
+        if in_len < self.receptive_field:
+            x = nn.functional.pad(x, (self.receptive_field - in_len, 0, 0, 0))
+        if self.cat_feat_gc:
+            f1, f2 = x[:, [0]], x[:, 1:]
+            x1 = self.start_conv(f1)
+            x2 = F.leaky_relu(self.cat_feature_conv(f2))
+            x = x1 + x2
+        else:
+            x = self.start_conv(x)
+        skip = 0
+        adjacency_matrices = self.fixed_supports
+        # calculate the current adaptive adj matrix once per iteration
+        if self.addaptadj:
             adp = F.softmax(F.relu(torch.mm(self.nodevec1, self.nodevec2)), dim=1)
-            new_supports = self.supports + [adp]
+            adjacency_matrices = self.fixed_supports + [adp]

-        # forward through the main network layers
+        # WaveNet layers
         for i in range(self.blocks * self.layers):
             residual = x
-            # temporal convolution
-            f = self.filter_convs[i](residual).tanh()
-            g = self.gate_convs[i](residual).sigmoid()
-            x = f * g
-            s = self.skip_convs[i](x)
-            skip = (
-                skip[:, :, :, -s.size(3) :] if isinstance(skip, torch.Tensor) else 0
-            ) + s
+            # dilated convolution
+            filter = torch.tanh(self.filter_convs[i](residual))
+            gate = torch.sigmoid(self.gate_convs[i](residual))
+            x = filter * gate
+            # parametrized skip connection
+            s = self.skip_convs[i](x)  # what are we skipping??
+            try:  # if i > 0 this works
+                skip = skip[:, :, :, -s.size(3):]  # TODO(SS): Mean/Max Pool?
+            except:
+                skip = 0
+            skip = s + skip
+            if i == (self.blocks * self.layers - 1):  # last X getting ignored anyway
+                break

-            # graph convolution
-            if self.gcn_bool and self.supports is not None:
-                x = self.gconv[i](x, new_supports if self.addaptadj else self.supports)
+            if self.do_graph_conv:
+                graph_out = self.graph_convs[i](x, adjacency_matrices)
+                x = x + graph_out if self.cat_feat_gc else graph_out
             else:
                 x = self.residual_convs[i](x)
-            x = x + residual[:, :, :, -x.size(3) :]
+            x = x + residual[:, :, :, -x.size(3):]  # TODO(SS): Mean/Max Pool?
             x = self.bn[i](x)

-        # output layers
-        return self.end_conv_2(F.relu(self.end_conv_1(F.relu(skip))))
+        x = F.relu(skip)  # ignore last X?
+        x = F.relu(self.end_conv_1(x))
+        x = self.end_conv_2(x)  # downsample to (bs, seq_length, 207, nfeatures)
+        # x = x.transpose(1, 3)
+        return x
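Note: the adaptive adjacency used in the forward pass above is just the row-softmax of relu(nodevec1 @ nodevec2), recomputed once per call and appended to the fixed supports. A standalone sketch of that step with illustrative sizes:

import torch
import torch.nn.functional as F

num_nodes, apt_size = 325, 10          # illustrative; matches the PEMS-BAY config above
nodevec1 = torch.randn(num_nodes, apt_size, requires_grad=True)
nodevec2 = torch.randn(apt_size, num_nodes, requires_grad=True)

# same expression as in gwnet.forward: each row of adp sums to 1
adp = F.softmax(F.relu(torch.mm(nodevec1, nodevec2)), dim=1)
adjacency_matrices = [] + [adp]        # fixed supports (none here) plus the learned one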
@@ -1,97 +1,98 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from torch.autograd import Variable
-import sys
+import torch, torch.nn as nn, torch.nn.functional as F


 class nconv(nn.Module):
-    def __init__(self):
-        super(nconv, self).__init__()
+    """
+    Graph convolution operator: applies the adjacency matrix via einsum.
+    """

     def forward(self, x, A):
-        x = torch.einsum("ncvl,vw->ncwl", (x, A))
-        return x.contiguous()
+        return torch.einsum("ncvl,vw->ncwl", (x, A)).contiguous()


 class linear(nn.Module):
+    """
+    Linear layer implemented as a 1x1 convolution.
+    """
+
     def __init__(self, c_in, c_out):
-        super(linear, self).__init__()
-        self.mlp = torch.nn.Conv2d(
-            c_in, c_out, kernel_size=(1, 1), padding=(0, 0), stride=(1, 1), bias=True
-        )
+        super().__init__()
+        self.mlp = nn.Conv2d(c_in, c_out, 1)

     def forward(self, x):
         return self.mlp(x)


 class gcn(nn.Module):
+    """
+    Graph convolution layer: higher-order graph convolution over multiple adjacency matrices.
+    """
+
     def __init__(self, c_in, c_out, dropout, support_len=3, order=2):
-        super(gcn, self).__init__()
+        super().__init__()
         self.nconv = nconv()
         c_in = (order * support_len + 1) * c_in
-        self.mlp = linear(c_in, c_out)
-        self.dropout = dropout
-        self.order = order
+        self.mlp, self.dropout, self.order = linear(c_in, c_out), dropout, order

     def forward(self, x, support):
         out = [x]
         for a in support:
             x1 = self.nconv(x, a)
             out.append(x1)
-            for k in range(2, self.order + 1):
-                x2 = self.nconv(x1, a)
-                out.append(x2)
-                x1 = x2
-
-        h = torch.cat(out, dim=1)
-        h = self.mlp(h)
-        h = F.dropout(h, self.dropout, training=self.training)
-        return h
+            for _ in range(2, self.order + 1):
+                x1 = self.nconv(x1, a)
+                out.append(x1)
+        return F.dropout(
+            self.mlp(torch.cat(out, dim=1)), self.dropout, training=self.training
+        )


 class gwnet(nn.Module):
+    """
+    Main Graph WaveNet model: combines graph convolution with temporal
+    convolution for spatio-temporal forecasting.
+    """
+
     def __init__(self, args):
-        super(gwnet, self).__init__()
-        self.dropout = args["dropout"]
-        self.blocks = args["blocks"]
-        self.layers = args["layers"]
-        self.gcn_bool = args["gcn_bool"]
-        self.addaptadj = args["addaptadj"]
-
-        self.filter_convs = nn.ModuleList()
-        self.gate_convs = nn.ModuleList()
-        self.residual_convs = nn.ModuleList()
-        self.skip_convs = nn.ModuleList()
-        self.bn = nn.ModuleList()
-        self.gconv = nn.ModuleList()
-
-        self.start_conv = nn.Conv2d(
-            in_channels=args["in_dim"],
-            out_channels=args["residual_channels"],
-            kernel_size=(1, 1),
+        super().__init__()
+        # initialize the basic hyperparameters
+        self.dropout, self.blocks, self.layers = (
+            args["dropout"],
+            args["blocks"],
+            args["layers"],
         )
+        self.gcn_bool, self.addaptadj = args["gcn_bool"], args["addaptadj"]
+
+        # set up the convolution layers and modules
+        self.filter_convs, self.gate_convs = nn.ModuleList(), nn.ModuleList()
+        self.residual_convs, self.skip_convs, self.bn, self.gconv = (
+            nn.ModuleList(),
+            nn.ModuleList(),
+            nn.ModuleList(),
+            nn.ModuleList(),
+        )
+        self.start_conv = nn.Conv2d(args["in_dim"], args["residual_channels"], 1)
+        self.supports = args.get("supports", None)

+        # compute the receptive field
         receptive_field = 1
+        self.supports_len = len(self.supports) if self.supports is not None else 0

-        self.supports_len = 0
-        if self.supports is not None:
-            self.supports_len += len(self.supports)
-
+        # if the adaptive adjacency matrix is used, initialize its parameters
         if self.gcn_bool and self.addaptadj:
             aptinit = args.get("aptinit", None)
             if aptinit is None:
                 if self.supports is None:
                     self.supports = []
                 self.nodevec1 = nn.Parameter(
-                    torch.randn(args["num_nodes"], 10).to(args["device"]),
-                    requires_grad=True,
-                ).to(args["device"])
+                    torch.randn(args["num_nodes"], 10, device=args["device"])
+                )
                 self.nodevec2 = nn.Parameter(
-                    torch.randn(10, args["num_nodes"]).to(args["device"]),
-                    requires_grad=True,
-                ).to(args["device"])
+                    torch.randn(10, args["num_nodes"], device=args["device"])
+                )
                 self.supports_len += 1
             else:
                 if self.supports is None:
@@ -99,156 +100,85 @@ class gwnet(nn.Module):
                 m, p, n = torch.svd(aptinit)
                 initemb1 = torch.mm(m[:, :10], torch.diag(p[:10] ** 0.5))
                 initemb2 = torch.mm(torch.diag(p[:10] ** 0.5), n[:, :10].t())
-                self.nodevec1 = nn.Parameter(initemb1, requires_grad=True).to(args["device"])
-                self.nodevec2 = nn.Parameter(initemb2, requires_grad=True).to(args["device"])
+                self.nodevec1 = nn.Parameter(initemb1)
+                self.nodevec2 = nn.Parameter(initemb2)
                 self.supports_len += 1

-        kernel_size = args["kernel_size"]
-        residual_channels = args["residual_channels"]
-        dilation_channels = args["dilation_channels"]
-        kernel_size = args["kernel_size"]
-        skip_channels = args["skip_channels"]
-        end_channels = args["end_channels"]
-        out_dim = args["out_dim"]
-        dropout = args["dropout"]
+        # fetch the model hyperparameters
+        ks, res, dil, skip, endc, out_dim = (
+            args["kernel_size"],
+            args["residual_channels"],
+            args["dilation_channels"],
+            args["skip_channels"],
+            args["end_channels"],
+            args["out_dim"],
+        )

+        # build the model layers
         for b in range(self.blocks):
-            additional_scope = kernel_size - 1
-            new_dilation = 1
+            add_scope, new_dil = ks - 1, 1
             for i in range(self.layers):
-                # dilated convolutions
-                self.filter_convs.append(nn.Conv2d(in_channels=residual_channels, out_channels=dilation_channels, kernel_size=(1, kernel_size), dilation=new_dilation))
-                self.gate_convs.append(nn.Conv2d(in_channels=residual_channels, out_channels=dilation_channels, kernel_size=(1, kernel_size), dilation=new_dilation))
-                # 1x1 convolution for residual connection
-                self.residual_convs.append(nn.Conv2d(in_channels=dilation_channels, out_channels=residual_channels, kernel_size=(1, 1)))
-                # 1x1 convolution for skip connection
-                self.skip_convs.append(nn.Conv2d(in_channels=dilation_channels, out_channels=skip_channels, kernel_size=(1, 1)))
-                self.bn.append(nn.BatchNorm2d(residual_channels))
-                new_dilation *= 2
-                receptive_field += additional_scope
-                additional_scope *= 2
+                # add the temporal convolution layers
+                self.filter_convs.append(nn.Conv2d(res, dil, (1, ks), dilation=new_dil))
+                self.gate_convs.append(nn.Conv2d(res, dil, (1, ks), dilation=new_dil))
+                self.residual_convs.append(nn.Conv2d(dil, res, 1))
+                self.skip_convs.append(nn.Conv2d(dil, skip, 1))
+                self.bn.append(nn.BatchNorm2d(res))
+                new_dil *= 2
+                receptive_field += add_scope
+                add_scope *= 2
                 if self.gcn_bool:
                     self.gconv.append(
-                        gcn(dilation_channels, residual_channels, dropout, support_len=self.supports_len)
+                        gcn(dil, res, args["dropout"], support_len=self.supports_len)
                     )

-        self.end_conv_1 = nn.Conv2d(in_channels=skip_channels, out_channels=end_channels, kernel_size=(1, 1), bias=True)
-        self.end_conv_2 = nn.Conv2d(in_channels=end_channels, out_channels=out_dim, kernel_size=(1, 1), bias=True)
+        # output layers
+        self.end_conv_1 = nn.Conv2d(skip, endc, 1)
+        self.end_conv_2 = nn.Conv2d(endc, out_dim, 1)
         self.receptive_field = receptive_field

     def forward(self, input):
-        input = input[..., 0:2]
-        input = input.transpose(1, 3)
-        input = nn.functional.pad(input, (1, 0, 0, 0))
+        """
+        Forward pass of the model.
+        """
+        # preprocess the input
+        input = input[..., 0:2].transpose(1, 3)
+        input = F.pad(input, (1, 0, 0, 0))
         in_len = input.size(3)
-        if in_len < self.receptive_field:
-            x = nn.functional.pad(input, (self.receptive_field - in_len, 0, 0, 0))
-        else:
-            x = input
-        x = self.start_conv(x)
-        skip = 0
+        x = (
+            F.pad(input, (self.receptive_field - in_len, 0, 0, 0))
+            if in_len < self.receptive_field
+            else input
+        )

-        # calculate the current adaptive adj matrix once per iteration
-        new_supports = None
+        # initial convolution
+        x, skip, new_supports = self.start_conv(x), 0, None
+
+        # if the adaptive adjacency matrix is used, compute the new adjacency matrix
         if self.gcn_bool and self.addaptadj and self.supports is not None:
             adp = F.softmax(F.relu(torch.mm(self.nodevec1, self.nodevec2)), dim=1)
             new_supports = self.supports + [adp]

-        # WaveNet layers
+        # forward through the main network layers
         for i in range(self.blocks * self.layers):
-            # |----------------------------------------| *residual*
-            # |                                         |
-            # |    |-- conv -- tanh --|                 |
-            # -> dilate -|----|  * ----|-- 1x1 -- + -->  *input*
-            #      |-- conv -- sigm --|     |
-            #                              1x1
-            #                               |
-            # ---------------------------------------> + ------------->  *skip*
-
-            # (dilation, init_dilation) = self.dilations[i]
-
-            # residual = dilation_func(x, dilation, init_dilation, i)
             residual = x
-            # dilated convolution
-            filter = self.filter_convs[i](residual)
-            filter = torch.tanh(filter)
-            gate = self.gate_convs[i](residual)
-            gate = torch.sigmoid(gate)
-            x = filter * gate
-
-            # parametrized skip connection
-            s = x
-            s = self.skip_convs[i](s)
-            try:
-                skip = skip[:, :, :, -s.size(3) :]
-            except:
-                skip = 0
-            skip = s + skip
+            # temporal convolution
+            f = self.filter_convs[i](residual).tanh()
+            g = self.gate_convs[i](residual).sigmoid()
+            x = f * g
+            s = self.skip_convs[i](x)
+            skip = (
+                skip[:, :, :, -s.size(3) :] if isinstance(skip, torch.Tensor) else 0
+            ) + s

+            # graph convolution
             if self.gcn_bool and self.supports is not None:
-                if self.addaptadj:
-                    x = self.gconv[i](x, new_supports)
-                else:
-                    x = self.gconv[i](x, self.supports)
+                x = self.gconv[i](x, new_supports if self.addaptadj else self.supports)
             else:
                 x = self.residual_convs[i](x)

             x = x + residual[:, :, :, -x.size(3) :]

             x = self.bn[i](x)

-        x = F.relu(skip)
-        x = F.relu(self.end_conv_1(x))
-        x = self.end_conv_2(x)
-        return x
+        # output layers
+        return self.end_conv_2(F.relu(self.end_conv_1(F.relu(skip))))
@@ -0,0 +1,95 @@
#!/bin/bash

# Default model name and dataset list
MODEL_NAME="GWN"
DATASETS=(
    "METR-LA"
    "PEMS-BAY"
    "NYCBike-InFlow"
    "NYCBike-OutFlow"
    "AirQuality"
    "SolarEnergy"
)

# Initialize counters
success_count=0
failure_count=0
missing_count=0
total_count=0
success_datasets=()
failure_datasets=()
missing_datasets=()

# Allow command-line arguments to override the defaults
if [ $# -gt 0 ]; then
    MODEL_NAME=$1
    # Any further arguments are used as the dataset list
    if [ $# -gt 1 ]; then
        DATASETS=(${@:2})
    fi
fi

echo "Model: $MODEL_NAME"
echo "Datasets: ${DATASETS[*]}"
echo "Starting tests..."
echo ""

# Test each dataset in turn
for dataset in "${DATASETS[@]}"; do
    total_count=$((total_count + 1))
    # Build the config file path
    CONFIG_PATH="config/${MODEL_NAME}/${dataset}.yaml"

    echo "Testing dataset: $dataset"
    echo "Using config file: $CONFIG_PATH"

    # Check that the config file exists
    if [ ! -f "$CONFIG_PATH" ]; then
        echo "Error: config file $CONFIG_PATH does not exist!"
        missing_count=$((missing_count + 1))
        missing_datasets+=("$dataset")
        echo "----------------------------------------"
        continue
    fi

    # Run the test command and capture its output
    echo "Running: python run.py --config $CONFIG_PATH"
    output=$(python run.py --config "$CONFIG_PATH" 2>&1)

    # If no explicit marker is found, fall back to the exit code
    if [ $? -eq 0 ]; then
        echo "Dataset $dataset passed! (based on exit code)"
        success_count=$((success_count + 1))
        success_datasets+=("$dataset")
    else
        echo "Dataset $dataset failed! (based on exit code)"
        failure_count=$((failure_count + 1))
        failure_datasets+=("$dataset")
    fi

    echo "----------------------------------------"
done

# Summary
echo "======================================="
echo "Test summary"
echo "======================================="
echo "Total datasets: $total_count"
echo "Passed: $success_count"
echo "Failed: $failure_count"
echo "Missing config files: $missing_count"

if [ ${#success_datasets[@]} -gt 0 ]; then
    echo "Passed datasets: ${success_datasets[*]}"
fi

if [ ${#failure_datasets[@]} -gt 0 ]; then
    echo "Failed datasets: ${failure_datasets[*]}"
fi

if [ ${#missing_datasets[@]} -gt 0 ]; then
    echo "Datasets with missing configs: ${missing_datasets[*]}"
fi

echo "======================================="
echo "All tests finished!"
@@ -177,6 +177,14 @@ class Trainer:
             # forward pass
             label = target[..., : self.args["output_dim"]]
             output = self.model(data).to(self.device)
+            # if output.shape != label.shape:
+            #     import sys
+            #     print(f"[Wrong]: Output shape: {output.shape}, Label shape: {label.shape}")
+            #     sys.exit(1)
+            # else:
+            #     import sys
+            #     print(f"[Right]: Output shape: {output.shape}, Label shape: {label.shape}")
+            #     sys.exit(0)
             loss = self.loss(output, label)

             # de-normalize
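Note: the commented-out block added here is a throwaway debugging aid for the output/label shape mismatch that the config changes above address (out_dim now equals the horizon). A non-exiting equivalent could be a plain assertion; this is a sketch, not part of the commit:

# hypothetical replacement for the commented-out debug block above
assert output.shape == label.shape, (
    f"output {tuple(output.shape)} vs label {tuple(label.shape)}: "
    "check model.out_dim against data.horizon and train.output_dim"
)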