Adapt GraphWaveNet

This commit is contained in:
czzhangheng 2025-12-03 12:05:02 +08:00
parent 140ead3975
commit a9313390ac
17 changed files with 448 additions and 603 deletions

View File

@ -6,40 +6,41 @@ basic:
seed: 2023 seed: 2023
data: data:
batch_size: 16 batch_size: 64
column_wise: false column_wise: false
days_per_week: 7 days_per_week: 7
horizon: 24 horizon: 24
input_dim: 6 input_dim: 6
lag: 24 lag: 24
normalizer: std normalizer: std
num_nodes: 12 num_nodes: 35
steps_per_day: 24 steps_per_day: 24
test_ratio: 0.2 test_ratio: 0.2
val_ratio: 0.2 val_ratio: 0.2
model: model:
addaptadj: true addaptadj: true
apt_size: 10
aptinit: null aptinit: null
batch_size: 16 batch_size: 64
blocks: 4 blocks: 4
dilation_channels: 32 dilation_channels: 32
dropout: 0.3 dropout: 0.3
do_graph_conv: True
end_channels: 512 end_channels: 512
gcn_bool: true gcn_bool: true
in_dim: 2 in_dim: 1
input_dim: 6 input_dim: 1
kernel_size: 2 kernel_size: 2
layers: 2 layers: 4
num_nodes: 12 num_nodes: 35
out_dim: 12 out_dim: 24
output_dim: 6
residual_channels: 32 residual_channels: 32
skip_channels: 256 skip_channels: 256
supports: null supports: null
train: train:
batch_size: 16 batch_size: 64
debug: false debug: false
early_stop: true early_stop: true
early_stop_patience: 15 early_stop_patience: 15
@ -54,7 +55,7 @@ train:
mae_thresh: 0.0 mae_thresh: 0.0
mape_thresh: 0.0 mape_thresh: 0.0
max_grad_norm: 5 max_grad_norm: 5
output_dim: 6 output_dim: 1
plot: false plot: false
real_value: true real_value: true
weight_decay: 0 weight_decay: 0

View File

@ -20,24 +20,26 @@ data:
model: model:
addaptadj: true addaptadj: true
apt_size: 10
aptinit: null aptinit: null
batch_size: 32 batch_size: 16
blocks: 4 blocks: 4
dilation_channels: 32 dilation_channels: 32
dropout: 0.3 dropout: 0.3
do_graph_conv: True
end_channels: 512 end_channels: 512
gcn_bool: true gcn_bool: true
in_dim: 2 in_dim: 1
input_dim: 1 input_dim: 1
kernel_size: 2 kernel_size: 2
layers: 2 layers: 4
num_nodes: 1024 num_nodes: 1024
out_dim: 12 out_dim: 24
output_dim: 1
residual_channels: 32 residual_channels: 32
skip_channels: 256 skip_channels: 256
supports: null supports: null
train: train:
batch_size: 32 batch_size: 32
debug: false debug: false

View File

@ -20,20 +20,21 @@ data:
model: model:
addaptadj: true addaptadj: true
apt_size: 10
aptinit: null aptinit: null
batch_size: 32 batch_size: 32
blocks: 4 blocks: 4
dilation_channels: 32 dilation_channels: 32
dropout: 0.3 dropout: 0.3
do_graph_conv: True
end_channels: 512 end_channels: 512
gcn_bool: true gcn_bool: true
in_dim: 2 in_dim: 1
input_dim: 1 input_dim: 1
kernel_size: 2 kernel_size: 2
layers: 2 layers: 4
num_nodes: 1024 num_nodes: 1024
out_dim: 12 out_dim: 24
output_dim: 1
residual_channels: 32 residual_channels: 32
skip_channels: 256 skip_channels: 256
supports: null supports: null

View File

@ -6,7 +6,7 @@ basic:
seed: 2023 seed: 2023
data: data:
batch_size: 16 batch_size: 64
column_wise: false column_wise: false
days_per_week: 7 days_per_week: 7
horizon: 24 horizon: 24
@ -20,26 +20,27 @@ data:
model: model:
addaptadj: true addaptadj: true
apt_size: 10
aptinit: null aptinit: null
batch_size: 16 batch_size: 64
blocks: 4 blocks: 4
dilation_channels: 32 dilation_channels: 32
dropout: 0.3 dropout: 0.3
do_graph_conv: True
end_channels: 512 end_channels: 512
gcn_bool: true gcn_bool: true
in_dim: 2 in_dim: 1
input_dim: 1 input_dim: 1
kernel_size: 2 kernel_size: 2
layers: 2 layers: 4
num_nodes: 207 num_nodes: 207
out_dim: 12 out_dim: 24
output_dim: 1
residual_channels: 32 residual_channels: 32
skip_channels: 256 skip_channels: 256
supports: null supports: null
train: train:
batch_size: 16 batch_size: 64
debug: false debug: false
early_stop: true early_stop: true
early_stop_patience: 15 early_stop_patience: 15

View File

@ -20,20 +20,21 @@ data:
model: model:
addaptadj: true addaptadj: true
apt_size: 10
aptinit: null aptinit: null
batch_size: 32 batch_size: 32
blocks: 4 blocks: 4
dilation_channels: 32 dilation_channels: 32
dropout: 0.3 dropout: 0.3
do_graph_conv: True
end_channels: 512 end_channels: 512
gcn_bool: true gcn_bool: true
in_dim: 2 in_dim: 1
input_dim: 1 input_dim: 1
kernel_size: 2 kernel_size: 2
layers: 2 layers: 4
num_nodes: 128 num_nodes: 128
out_dim: 12 out_dim: 24
output_dim: 1
residual_channels: 32 residual_channels: 32
skip_channels: 256 skip_channels: 256
supports: null supports: null

View File

@ -6,7 +6,7 @@ basic:
seed: 2023 seed: 2023
data: data:
batch_size: 32 batch_size: 16
column_wise: false column_wise: false
days_per_week: 7 days_per_week: 7
horizon: 24 horizon: 24
@ -20,26 +20,27 @@ data:
model: model:
addaptadj: true addaptadj: true
apt_size: 10
aptinit: null aptinit: null
batch_size: 32 batch_size: 16
blocks: 4 blocks: 4
dilation_channels: 32 dilation_channels: 32
dropout: 0.3 dropout: 0.3
do_graph_conv: True
end_channels: 512 end_channels: 512
gcn_bool: true gcn_bool: true
in_dim: 2 in_dim: 1
input_dim: 1 input_dim: 1
kernel_size: 2 kernel_size: 2
layers: 2 layers: 4
num_nodes: 128 num_nodes: 128
out_dim: 12 out_dim: 24
output_dim: 1
residual_channels: 32 residual_channels: 32
skip_channels: 256 skip_channels: 256
supports: null supports: null
train: train:
batch_size: 32 batch_size: 16
debug: false debug: false
early_stop: true early_stop: true
early_stop_patience: 15 early_stop_patience: 15

61
config/GWN/PEMS-BAY.yaml Normal file
View File

@ -0,0 +1,61 @@
basic:
dataset: PEMS-BAY
device: cuda:0
mode: train
model: GWN
seed: 2023
data:
batch_size: 64
column_wise: false
days_per_week: 7
horizon: 24
input_dim: 1
lag: 24
normalizer: std
num_nodes: 325
steps_per_day: 288
test_ratio: 0.2
val_ratio: 0.2
model:
addaptadj: true
apt_size: 10
aptinit: null
batch_size: 64
blocks: 4
dilation_channels: 32
dropout: 0.3
do_graph_conv: True
end_channels: 512
gcn_bool: true
in_dim: 1
input_dim: 1
kernel_size: 2
layers: 4
num_nodes: 325
out_dim: 24
residual_channels: 32
skip_channels: 256
supports: null
train:
batch_size: 64
debug: false
early_stop: true
early_stop_patience: 15
epochs: 300
grad_norm: false
log_step: 1000
loss_func: mae
lr_decay: false
lr_decay_rate: 0.3
lr_decay_step: 5,20,40,70
lr_init: 0.003
mae_thresh: 0.0
mape_thresh: 0.0
max_grad_norm: 5
output_dim: 1
plot: false
real_value: true
weight_decay: 0
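For reference, a minimal sketch of consuming this new config (assumptions: PyYAML is available, gwnet accepts the flattened model dict plus a device entry as the rewritten model code below suggests, and the import path is hypothetical):

# Load config/GWN/PEMS-BAY.yaml and assemble the argument dict for gwnet.
import yaml

with open("config/GWN/PEMS-BAY.yaml", "r", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

model_args = dict(cfg["model"])
model_args["device"] = cfg["basic"]["device"]  # nodevec parameters are created on this device
print(model_args["num_nodes"], model_args["in_dim"], model_args["out_dim"])  # 325 1 24
# from model.GWN import gwnet   # hypothetical import path
# net = gwnet(model_args)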

View File

@ -27,7 +27,7 @@ model:
dropout: 0.3 dropout: 0.3
end_channels: 512 end_channels: 512
gcn_bool: true gcn_bool: true
in_dim: 2 in_dim: 3
input_dim: 1 input_dim: 1
kernel_size: 2 kernel_size: 2
layers: 2 layers: 2

View File

@ -27,7 +27,7 @@ model:
dropout: 0.3 dropout: 0.3
end_channels: 512 end_channels: 512
gcn_bool: true gcn_bool: true
in_dim: 2 in_dim: 1
input_dim: 1 input_dim: 1
kernel_size: 2 kernel_size: 2
layers: 2 layers: 2

View File

@ -27,7 +27,7 @@ model:
dropout: 0.3 dropout: 0.3
end_channels: 512 end_channels: 512
gcn_bool: true gcn_bool: true
in_dim: 2 in_dim: 3
input_dim: 1 input_dim: 1
kernel_size: 2 kernel_size: 2
layers: 2 layers: 2

View File

@ -27,7 +27,7 @@ model:
dropout: 0.3 dropout: 0.3
end_channels: 512 end_channels: 512
gcn_bool: true gcn_bool: true
in_dim: 2 in_dim: 3
input_dim: 1 input_dim: 1
kernel_size: 2 kernel_size: 2
layers: 2 layers: 2

View File

@ -20,20 +20,21 @@ data:
model: model:
addaptadj: true addaptadj: true
apt_size: 10
aptinit: null aptinit: null
batch_size: 64 batch_size: 32
blocks: 4 blocks: 4
dilation_channels: 32 dilation_channels: 32
dropout: 0.3 dropout: 0.3
do_graph_conv: True
end_channels: 512 end_channels: 512
gcn_bool: true gcn_bool: true
in_dim: 2 in_dim: 1
input_dim: 1 input_dim: 1
kernel_size: 2 kernel_size: 2
layers: 2 layers: 4
num_nodes: 137 num_nodes: 137
out_dim: 12 out_dim: 24
output_dim: 1
residual_channels: 32 residual_channels: 32
skip_channels: 256 skip_channels: 256
supports: null supports: null

View File

@ -1,234 +0,0 @@
#!/usr/bin/env python3
import os
from collections import defaultdict
from ruamel.yaml import YAML
from ruamel.yaml.comments import CommentedMap
yaml = YAML()
yaml.preserve_quotes = True
yaml.indent(mapping=2, sequence=4, offset=2)
# Allowed data keys
DATA_ALLOWED_KEYS = {
"lag",
"horizon",
"num_nodes",
"steps_per_day",
"days_per_week",
"test_ratio",
"val_ratio",
"batch_size",
"input_dim",
"column_wise",
"normalizer",
}
# Global default values
GLOBAL_DEFAULTS = {
"lag": 24,
"horizon": 24,
"num_nodes": 1,
"steps_per_day": 24,
"days_per_week": 7,
"test_ratio": 0.2,
"val_ratio": 0.2,
"batch_size": 16,
"input_dim": 1,
"column_wise": False,
"normalizer": "std",
}
# Global default values for train
GLOBAL_TRAIN_DEFAULTS = {
"output_dim": 1
}
def load_yaml(path):
try:
with open(path, "r", encoding="utf-8") as f:
return yaml.load(f)
except Exception:
return None
def collect_dataset_defaults(base="."):
"""
Collect the per-dataset default values for the data keys and for train.output_dim.
"""
data_defaults = defaultdict(dict)
train_output_defaults = dict()
for root, _, files in os.walk(base):
for name in files:
if not (name.endswith(".yaml") or name.endswith(".yml")):
continue
path = os.path.join(root, name)
cm = load_yaml(path)
if not isinstance(cm, CommentedMap):
continue
basic = cm.get("basic")
if not isinstance(basic, dict):
continue
dataset = basic.get("dataset")
if dataset is None:
continue
ds = str(dataset)
# data default values
data_sec = cm.get("data")
if isinstance(data_sec, dict):
for key in DATA_ALLOWED_KEYS:
if key not in data_defaults[ds] and key in data_sec and data_sec[key] is not None:
data_defaults[ds][key] = data_sec[key]
# train.output_dim default value
train_sec = cm.get("train")
if isinstance(train_sec, dict):
val = train_sec.get("output_dim")
if val is not None and ds not in train_output_defaults:
train_output_defaults[ds] = val
return data_defaults, train_output_defaults
def ensure_basic_seed(cm: CommentedMap, path: str):
if "basic" not in cm or not isinstance(cm["basic"], dict):
cm["basic"] = CommentedMap()
basic = cm["basic"]
if "seed" not in basic:
basic["seed"] = 2023
print(f"[ADD] {path}: basic.seed = 2023")
def fill_data_defaults(cm: CommentedMap, data_defaults: dict, path: str):
if "data" not in cm or not isinstance(cm["data"], dict):
cm["data"] = CommentedMap()
data_sec = cm["data"]
basic = cm.get("basic", {})
dataset = str(basic.get("dataset")) if basic and "dataset" in basic else None
for key in sorted(DATA_ALLOWED_KEYS):
if key in data_sec and data_sec[key] is not None:
continue
if dataset and dataset in data_defaults and key in data_defaults[dataset]:
chosen = data_defaults[dataset][key]
src = f"default_from_dataset[{dataset}]"
else:
chosen = GLOBAL_DEFAULTS[key]
src = "GLOBAL_DEFAULTS"
data_sec[key] = chosen
print(f"[FILL] {path}: data.{key} <- {src} ({repr(chosen)})")
def merge_test_log_into_train(cm: CommentedMap, path: str):
"""
Merge keys from the test and log sections into train, then delete the test and log sections.
Also make sure train.debug exists.
"""
train_sec = cm.setdefault("train", CommentedMap())
for section in ["test", "log"]:
if section in cm and isinstance(cm[section], dict):
for k, v in cm[section].items():
if k not in train_sec:
train_sec[k] = v
print(f"[MERGE] {path}: train.{k} <- {section}.{k} ({repr(v)})")
del cm[section]
print(f"[DEL] {path}: deleted section '{section}'")
# train.debug
if "debug" not in train_sec:
train_sec["debug"] = False
print(f"[ADD] {path}: train.debug = False")
def fill_train_output_dim(cm: CommentedMap, train_output_defaults: dict, path: str):
train_sec = cm.setdefault("train", CommentedMap())
if "output_dim" not in train_sec or train_sec["output_dim"] is None:
basic = cm.get("basic", {})
dataset = str(basic.get("dataset")) if basic and "dataset" in basic else None
if dataset and dataset in train_output_defaults:
val = train_output_defaults[dataset]
src = f"default_from_dataset[{dataset}]"
else:
val = GLOBAL_TRAIN_DEFAULTS["output_dim"]
src = "GLOBAL_TRAIN_DEFAULTS"
train_sec["output_dim"] = val
print(f"[FILL] {path}: train.output_dim <- {src} ({val})")
def sync_train_batch_size(cm: CommentedMap, path: str):
"""
If train.batch_size and data.batch_size disagree, data.batch_size takes precedence.
"""
data_sec = cm.get("data", {})
train_sec = cm.get("train", {})
data_bs = data_sec.get("batch_size")
train_bs = train_sec.get("batch_size")
if data_bs is not None and train_bs != data_bs:
train_sec["batch_size"] = data_bs
print(f"[SYNC] {path}: train.batch_size corrected to match data.batch_size ({data_bs})")
def sort_subkeys_and_insert_blanklines(cm: CommentedMap):
for sec in list(cm.keys()):
if isinstance(cm[sec], dict):
sorted_cm = CommentedMap()
for k in sorted(cm[sec].keys()):
sorted_cm[k] = cm[sec][k]
cm[sec] = sorted_cm
keys = list(cm.keys())
for i, k in enumerate(keys):
if i == 0:
try:
cm.yaml_set_comment_before_after_key(k, before=None)
except Exception:
pass
else:
try:
cm.yaml_set_comment_before_after_key(k, before="\n")
except Exception:
pass
def process_all(base="."):
print(">> Collecting dataset defaults ...")
data_defaults, train_output_defaults = collect_dataset_defaults(base)
print(">> Collected data defaults per dataset:")
for ds, kv in data_defaults.items():
print(f" - {ds}: {kv}")
print(">> Collected train.output_dim defaults per dataset:")
for ds, val in train_output_defaults.items():
print(f" - {ds}: output_dim = {val}")
for root, _, files in os.walk(base):
for name in files:
if not (name.endswith(".yaml") or name.endswith(".yml")):
continue
path = os.path.join(root, name)
cm = load_yaml(path)
if not isinstance(cm, CommentedMap):
print(f"[SKIP] {path}: top-level not mapping or load failed")
continue
ensure_basic_seed(cm, path)
fill_data_defaults(cm, data_defaults, path)
merge_test_log_into_train(cm, path)
fill_train_output_dim(cm, train_output_defaults, path)
sync_train_batch_size(cm, path) # <-- newly added logic
sort_subkeys_and_insert_blanklines(cm)
try:
with open(path, "w", encoding="utf-8") as f:
yaml.dump(cm, f)
print(f"[OK] Written: {path}")
except Exception as e:
print(f"[ERROR] Write failed {path}: {e}")
if __name__ == "__main__":
process_all(".")

View File

@ -1,53 +1,35 @@
import torch, torch.nn as nn, torch.nn.functional as F import torch
import torch.nn as nn
from torch.nn import BatchNorm2d, Conv1d, Conv2d, ModuleList, Parameter
import torch.nn.functional as F
def nconv(x, A):
"""Multiply x by adjacency matrix along source node axis"""
return torch.einsum('ncvl,vw->ncwl', (x, A)).contiguous()
class nconv(nn.Module): class GraphConvNet(nn.Module):
"""
Implements the graph convolution operation
Uses einsum to apply the adjacency matrix as a matrix product
"""
def forward(self, x, A):
return torch.einsum("ncvl,vw->ncwl", (x, A)).contiguous()
class linear(nn.Module):
"""
Linear transformation layer
Implements the linear map with a 1x1 convolution
"""
def __init__(self, c_in, c_out):
super().__init__()
self.mlp = nn.Conv2d(c_in, c_out, 1)
def forward(self, x):
return self.mlp(x)
class gcn(nn.Module):
"""
Graph convolutional network layer
Implements higher-order graph convolution and supports multiple adjacency matrices
"""
def __init__(self, c_in, c_out, dropout, support_len=3, order=2): def __init__(self, c_in, c_out, dropout, support_len=3, order=2):
super().__init__() super().__init__()
self.nconv = nconv()
c_in = (order * support_len + 1) * c_in c_in = (order * support_len + 1) * c_in
self.mlp, self.dropout, self.order = linear(c_in, c_out), dropout, order self.final_conv = Conv2d(c_in, c_out, (1, 1), padding=(0, 0), stride=(1, 1), bias=True)
self.dropout = dropout
self.order = order
def forward(self, x, support): def forward(self, x, support: list):
out = [x] out = [x]
for a in support: for a in support:
x1 = self.nconv(x, a) x1 = nconv(x, a)
out.append(x1) out.append(x1)
for _ in range(2, self.order + 1): for k in range(2, self.order + 1):
x1 = self.nconv(x1, a) x2 = nconv(x1, a)
out.append(x1) out.append(x2)
return F.dropout( x1 = x2
self.mlp(torch.cat(out, dim=1)), self.dropout, training=self.training
) h = torch.cat(out, dim=1)
h = self.final_conv(h)
h = F.dropout(h, self.dropout, training=self.training)
return h
class gwnet(nn.Module): class gwnet(nn.Module):
@ -59,126 +41,121 @@ class gwnet(nn.Module):
def __init__(self, args): def __init__(self, args):
super().__init__() super().__init__()
# Initialize basic parameters # Initialize basic parameters
self.dropout, self.blocks, self.layers = ( self.dropout = args["dropout"]
args["dropout"], self.blocks = args["blocks"]
args["blocks"], self.layers = args["layers"]
args["layers"], self.do_graph_conv = args.get("do_graph_conv", True)
) self.cat_feat_gc = args.get("cat_feat_gc", False)
self.gcn_bool, self.addaptadj = args["gcn_bool"], args["addaptadj"] self.addaptadj = args.get("addaptadj", True)
supports = None
aptinit = args.get("aptinit", None)
in_dim = args.get("in_dim")
out_dim = args.get("out_dim")
residual_channels = args.get("residual_channels")
dilation_channels = args.get("dilation_channels")
skip_channels = args.get("skip_channels")
end_channels = args.get("end_channels")
kernel_size = args.get("kernel_size")
apt_size = args.get("apt_size", 10)
# Initialize the various conv layers and modules
self.filter_convs, self.gate_convs = nn.ModuleList(), nn.ModuleList()
self.residual_convs, self.skip_convs, self.bn, self.gconv = (
nn.ModuleList(),
nn.ModuleList(),
nn.ModuleList(),
nn.ModuleList(),
)
self.start_conv = nn.Conv2d(args["in_dim"], args["residual_channels"], 1)
self.supports = args.get("supports", None)
# Compute the receptive field if self.cat_feat_gc:
self.start_conv = nn.Conv2d(in_channels=1, # hard code to avoid errors
out_channels=residual_channels,
kernel_size=(1, 1))
self.cat_feature_conv = nn.Conv2d(in_channels=in_dim - 1,
out_channels=residual_channels,
kernel_size=(1, 1))
else:
self.start_conv = nn.Conv2d(in_channels=in_dim,
out_channels=residual_channels,
kernel_size=(1, 1))
self.fixed_supports = supports or []
receptive_field = 1 receptive_field = 1
self.supports_len = len(self.supports) if self.supports is not None else 0
# If using the adaptive adjacency matrix, initialize its parameters self.supports_len = len(self.fixed_supports)
if self.gcn_bool and self.addaptadj: if self.do_graph_conv and self.addaptadj:
aptinit = args.get("aptinit", None)
if aptinit is None: if aptinit is None:
if self.supports is None: nodevecs = torch.randn(args["num_nodes"], apt_size), torch.randn(apt_size, args["num_nodes"])
self.supports = []
self.nodevec1 = nn.Parameter(
torch.randn(args["num_nodes"], 10, device=args["device"])
)
self.nodevec2 = nn.Parameter(
torch.randn(10, args["num_nodes"], device=args["device"])
)
self.supports_len += 1
else: else:
if self.supports is None: nodevecs = self.svd_init(args["num_nodes"], apt_size, aptinit)
self.supports = [] self.supports_len += 1
m, p, n = torch.svd(aptinit) self.nodevec1, self.nodevec2 = [Parameter(n.to(args["device"]), requires_grad=True) for n in nodevecs]
initemb1 = torch.mm(m[:, :10], torch.diag(p[:10] ** 0.5))
initemb2 = torch.mm(torch.diag(p[:10] ** 0.5), n[:, :10].t())
self.nodevec1 = nn.Parameter(initemb1)
self.nodevec2 = nn.Parameter(initemb2)
self.supports_len += 1
# Read the model hyperparameters depth = list(range(self.blocks * self.layers))
ks, res, dil, skip, endc, out_dim = (
args["kernel_size"],
args["residual_channels"],
args["dilation_channels"],
args["skip_channels"],
args["end_channels"],
args["out_dim"],
)
# Build the model layers # 1x1 convolution for residual and skip connections (slightly different see docstring)
self.residual_convs = ModuleList([Conv2d(dilation_channels, residual_channels, (1, 1)) for _ in depth])
self.skip_convs = ModuleList([Conv2d(dilation_channels, skip_channels, (1, 1)) for _ in depth])
self.bn = ModuleList([BatchNorm2d(residual_channels) for _ in depth])
self.graph_convs = ModuleList([GraphConvNet(dilation_channels, residual_channels, self.dropout, support_len=self.supports_len)
for _ in depth])
self.filter_convs = ModuleList()
self.gate_convs = ModuleList()
for b in range(self.blocks): for b in range(self.blocks):
add_scope, new_dil = ks - 1, 1 additional_scope = kernel_size - 1
D = 1 # dilation
for i in range(self.layers): for i in range(self.layers):
# Add temporal conv layers # dilated convolutions
self.filter_convs.append(nn.Conv2d(res, dil, (1, ks), dilation=new_dil)) self.filter_convs.append(Conv2d(residual_channels, dilation_channels, (1, kernel_size), dilation=D))
self.gate_convs.append(nn.Conv2d(res, dil, (1, ks), dilation=new_dil)) self.gate_convs.append(Conv2d(residual_channels, dilation_channels, (1, kernel_size), dilation=D))
self.residual_convs.append(nn.Conv2d(dil, res, 1)) D *= 2
self.skip_convs.append(nn.Conv2d(dil, skip, 1)) receptive_field += additional_scope
self.bn.append(nn.BatchNorm2d(res)) additional_scope *= 2
new_dil *= 2
receptive_field += add_scope
add_scope *= 2
if self.gcn_bool:
self.gconv.append(
gcn(dil, res, args["dropout"], support_len=self.supports_len)
)
# Output layers
self.end_conv_1 = nn.Conv2d(skip, endc, 1)
self.end_conv_2 = nn.Conv2d(endc, out_dim, 1)
self.receptive_field = receptive_field self.receptive_field = receptive_field
self.end_conv_1 = Conv2d(skip_channels, end_channels, (1, 1), bias=True)
self.end_conv_2 = Conv2d(end_channels, out_dim, (1, 1), bias=True)
def forward(self, input): def forward(self, input):
""" x = input[..., 0:1].transpose(1, 3)
Forward pass # Input shape is (bs, features, n_nodes, n_timesteps)
Implements the model's inference in_len = x.size(3)
""" if in_len < self.receptive_field:
# Data preprocessing x = nn.functional.pad(x, (self.receptive_field - in_len, 0, 0, 0))
input = input[..., 0:2].transpose(1, 3) if self.cat_feat_gc:
input = F.pad(input, (1, 0, 0, 0)) f1, f2 = x[:, [0]], x[:, 1:]
in_len = input.size(3) x1 = self.start_conv(f1)
x = ( x2 = F.leaky_relu(self.cat_feature_conv(f2))
F.pad(input, (self.receptive_field - in_len, 0, 0, 0)) x = x1 + x2
if in_len < self.receptive_field else:
else input x = self.start_conv(x)
) skip = 0
adjacency_matrices = self.fixed_supports
# Initial convolution # calculate the current adaptive adj matrix once per iteration
x, skip, new_supports = self.start_conv(x), 0, None if self.addaptadj:
# If using the adaptive adjacency matrix, compute the new adjacency
if self.gcn_bool and self.addaptadj and self.supports is not None:
adp = F.softmax(F.relu(torch.mm(self.nodevec1, self.nodevec2)), dim=1) adp = F.softmax(F.relu(torch.mm(self.nodevec1, self.nodevec2)), dim=1)
new_supports = self.supports + [adp] adjacency_matrices = self.fixed_supports + [adp]
# Forward pass through the main layers # WaveNet layers
for i in range(self.blocks * self.layers): for i in range(self.blocks * self.layers):
residual = x residual = x
# Temporal convolution # dilated convolution
f = self.filter_convs[i](residual).tanh() filter = torch.tanh(self.filter_convs[i](residual))
g = self.gate_convs[i](residual).sigmoid() gate = torch.sigmoid(self.gate_convs[i](residual))
x = f * g x = filter * gate
s = self.skip_convs[i](x) # parametrized skip connection
skip = ( s = self.skip_convs[i](x) # what are we skipping??
skip[:, :, :, -s.size(3) :] if isinstance(skip, torch.Tensor) else 0 try: # if i > 0 this works
) + s skip = skip[:, :, :, -s.size(3):] # TODO(SS): Mean/Max Pool?
except:
skip = 0
skip = s + skip
if i == (self.blocks * self.layers - 1): # last X getting ignored anyway
break
# Graph convolution if self.do_graph_conv:
if self.gcn_bool and self.supports is not None: graph_out = self.graph_convs[i](x, adjacency_matrices)
x = self.gconv[i](x, new_supports if self.addaptadj else self.supports) x = x + graph_out if self.cat_feat_gc else graph_out
else: else:
x = self.residual_convs[i](x) x = self.residual_convs[i](x)
x = x + residual[:, :, :, -x.size(3) :] x = x + residual[:, :, :, -x.size(3):] # TODO(SS): Mean/Max Pool?
x = self.bn[i](x) x = self.bn[i](x)
# Output layer processing x = F.relu(skip) # ignore last X?
return self.end_conv_2(F.relu(self.end_conv_1(F.relu(skip)))) x = F.relu(self.end_conv_1(x))
x = self.end_conv_2(x) # downsample to (bs, seq_length, 207, nfeatures)
# x = x.transpose(1, 3)
return x
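A note on the layers: 2 -> 4 change in the configs above: with kernel_size 2, the dilated stack's receptive field now exceeds the lag of 24, so forward() always left-pads the input and the temporal axis collapses to a single step before end_conv_2 maps channels to out_dim. A small illustrative helper (not repository code) that mirrors the receptive-field loop above:

def receptive_field(kernel_size: int, blocks: int, layers: int) -> int:
    # Mirrors: receptive_field += additional_scope; additional_scope *= 2
    rf = 1
    for _ in range(blocks):
        scope = kernel_size - 1
        for _ in range(layers):
            rf += scope
            scope *= 2
    return rf

print(receptive_field(2, 4, 4))  # 61 with the updated configs (blocks=4, layers=4)
print(receptive_field(2, 4, 2))  # 13 with the previous layers: 2 setting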

View File

@ -1,97 +1,98 @@
import torch import torch, torch.nn as nn, torch.nn.functional as F
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import sys
class nconv(nn.Module): class nconv(nn.Module):
def __init__(self): """
super(nconv, self).__init__() Implements the graph convolution operation
Uses einsum to apply the adjacency matrix as a matrix product
"""
def forward(self, x, A): def forward(self, x, A):
x = torch.einsum("ncvl,vw->ncwl", (x, A)) return torch.einsum("ncvl,vw->ncwl", (x, A)).contiguous()
return x.contiguous()
class linear(nn.Module): class linear(nn.Module):
"""
Linear transformation layer
Implements the linear map with a 1x1 convolution
"""
def __init__(self, c_in, c_out): def __init__(self, c_in, c_out):
super(linear, self).__init__() super().__init__()
self.mlp = torch.nn.Conv2d( self.mlp = nn.Conv2d(c_in, c_out, 1)
c_in, c_out, kernel_size=(1, 1), padding=(0, 0), stride=(1, 1), bias=True
)
def forward(self, x): def forward(self, x):
return self.mlp(x) return self.mlp(x)
class gcn(nn.Module): class gcn(nn.Module):
"""
Graph convolutional network layer
Implements higher-order graph convolution and supports multiple adjacency matrices
"""
def __init__(self, c_in, c_out, dropout, support_len=3, order=2): def __init__(self, c_in, c_out, dropout, support_len=3, order=2):
super(gcn, self).__init__() super().__init__()
self.nconv = nconv() self.nconv = nconv()
c_in = (order * support_len + 1) * c_in c_in = (order * support_len + 1) * c_in
self.mlp = linear(c_in, c_out) self.mlp, self.dropout, self.order = linear(c_in, c_out), dropout, order
self.dropout = dropout
self.order = order
def forward(self, x, support): def forward(self, x, support):
out = [x] out = [x]
for a in support: for a in support:
x1 = self.nconv(x, a) x1 = self.nconv(x, a)
out.append(x1) out.append(x1)
for k in range(2, self.order + 1): for _ in range(2, self.order + 1):
x2 = self.nconv(x1, a) x1 = self.nconv(x1, a)
out.append(x2) out.append(x1)
x1 = x2 return F.dropout(
self.mlp(torch.cat(out, dim=1)), self.dropout, training=self.training
h = torch.cat(out, dim=1) )
h = self.mlp(h)
h = F.dropout(h, self.dropout, training=self.training)
return h
class gwnet(nn.Module): class gwnet(nn.Module):
"""
Main class of the Graph WaveNet model
Combines graph convolution and temporal convolution for spatio-temporal forecasting
"""
def __init__(self, args): def __init__(self, args):
super(gwnet, self).__init__() super().__init__()
self.dropout = args["dropout"] # Initialize basic parameters
self.blocks = args["blocks"] self.dropout, self.blocks, self.layers = (
self.layers = args["layers"] args["dropout"],
self.gcn_bool = args["gcn_bool"] args["blocks"],
self.addaptadj = args["addaptadj"] args["layers"],
self.filter_convs = nn.ModuleList()
self.gate_convs = nn.ModuleList()
self.residual_convs = nn.ModuleList()
self.skip_convs = nn.ModuleList()
self.bn = nn.ModuleList()
self.gconv = nn.ModuleList()
self.start_conv = nn.Conv2d(
in_channels=args["in_dim"],
out_channels=args["residual_channels"],
kernel_size=(1, 1),
) )
self.gcn_bool, self.addaptadj = args["gcn_bool"], args["addaptadj"]
# Initialize the various conv layers and modules
self.filter_convs, self.gate_convs = nn.ModuleList(), nn.ModuleList()
self.residual_convs, self.skip_convs, self.bn, self.gconv = (
nn.ModuleList(),
nn.ModuleList(),
nn.ModuleList(),
nn.ModuleList(),
)
self.start_conv = nn.Conv2d(args["in_dim"], args["residual_channels"], 1)
self.supports = args.get("supports", None) self.supports = args.get("supports", None)
# Compute the receptive field
receptive_field = 1 receptive_field = 1
self.supports_len = len(self.supports) if self.supports is not None else 0
self.supports_len = 0 # If using the adaptive adjacency matrix, initialize its parameters
if self.supports is not None:
self.supports_len += len(self.supports)
if self.gcn_bool and self.addaptadj: if self.gcn_bool and self.addaptadj:
aptinit = args.get("aptinit", None) aptinit = args.get("aptinit", None)
if aptinit is None: if aptinit is None:
if self.supports is None: if self.supports is None:
self.supports = [] self.supports = []
self.nodevec1 = nn.Parameter( self.nodevec1 = nn.Parameter(
torch.randn(args["num_nodes"], 10).to(args["device"]), torch.randn(args["num_nodes"], 10, device=args["device"])
requires_grad=True, )
).to(args["device"])
self.nodevec2 = nn.Parameter( self.nodevec2 = nn.Parameter(
torch.randn(10, args["num_nodes"]).to(args["device"]), torch.randn(10, args["num_nodes"], device=args["device"])
requires_grad=True, )
).to(args["device"])
self.supports_len += 1 self.supports_len += 1
else: else:
if self.supports is None: if self.supports is None:
@ -99,156 +100,85 @@ class gwnet(nn.Module):
m, p, n = torch.svd(aptinit) m, p, n = torch.svd(aptinit)
initemb1 = torch.mm(m[:, :10], torch.diag(p[:10] ** 0.5)) initemb1 = torch.mm(m[:, :10], torch.diag(p[:10] ** 0.5))
initemb2 = torch.mm(torch.diag(p[:10] ** 0.5), n[:, :10].t()) initemb2 = torch.mm(torch.diag(p[:10] ** 0.5), n[:, :10].t())
self.nodevec1 = nn.Parameter(initemb1, requires_grad=True).to( self.nodevec1 = nn.Parameter(initemb1)
args["device"] self.nodevec2 = nn.Parameter(initemb2)
)
self.nodevec2 = nn.Parameter(initemb2, requires_grad=True).to(
args["device"]
)
self.supports_len += 1 self.supports_len += 1
kernel_size = args["kernel_size"] # Read the model hyperparameters
residual_channels = args["residual_channels"] ks, res, dil, skip, endc, out_dim = (
dilation_channels = args["dilation_channels"] args["kernel_size"],
kernel_size = args["kernel_size"] args["residual_channels"],
skip_channels = args["skip_channels"] args["dilation_channels"],
end_channels = args["end_channels"] args["skip_channels"],
out_dim = args["out_dim"] args["end_channels"],
dropout = args["dropout"] args["out_dim"],
)
# Build the model layers
for b in range(self.blocks): for b in range(self.blocks):
additional_scope = kernel_size - 1 add_scope, new_dil = ks - 1, 1
new_dilation = 1
for i in range(self.layers): for i in range(self.layers):
# dilated convolutions # Add temporal conv layers
self.filter_convs.append( self.filter_convs.append(nn.Conv2d(res, dil, (1, ks), dilation=new_dil))
nn.Conv2d( self.gate_convs.append(nn.Conv2d(res, dil, (1, ks), dilation=new_dil))
in_channels=residual_channels, self.residual_convs.append(nn.Conv2d(dil, res, 1))
out_channels=dilation_channels, self.skip_convs.append(nn.Conv2d(dil, skip, 1))
kernel_size=(1, kernel_size), self.bn.append(nn.BatchNorm2d(res))
dilation=new_dilation, new_dil *= 2
) receptive_field += add_scope
) add_scope *= 2
self.gate_convs.append(
nn.Conv2d(
in_channels=residual_channels,
out_channels=dilation_channels,
kernel_size=(1, kernel_size),
dilation=new_dilation,
)
)
# 1x1 convolution for residual connection
self.residual_convs.append(
nn.Conv2d(
in_channels=dilation_channels,
out_channels=residual_channels,
kernel_size=(1, 1),
)
)
# 1x1 convolution for skip connection
self.skip_convs.append(
nn.Conv2d(
in_channels=dilation_channels,
out_channels=skip_channels,
kernel_size=(1, 1),
)
)
self.bn.append(nn.BatchNorm2d(residual_channels))
new_dilation *= 2
receptive_field += additional_scope
additional_scope *= 2
if self.gcn_bool: if self.gcn_bool:
self.gconv.append( self.gconv.append(
gcn( gcn(dil, res, args["dropout"], support_len=self.supports_len)
dilation_channels,
residual_channels,
dropout,
support_len=self.supports_len,
)
) )
self.end_conv_1 = nn.Conv2d( # Output layers
in_channels=skip_channels, self.end_conv_1 = nn.Conv2d(skip, endc, 1)
out_channels=end_channels, self.end_conv_2 = nn.Conv2d(endc, out_dim, 1)
kernel_size=(1, 1),
bias=True,
)
self.end_conv_2 = nn.Conv2d(
in_channels=end_channels,
out_channels=out_dim,
kernel_size=(1, 1),
bias=True,
)
self.receptive_field = receptive_field self.receptive_field = receptive_field
def forward(self, input): def forward(self, input):
input = input[..., 0:2] """
input = input.transpose(1, 3) Forward pass
input = nn.functional.pad(input, (1, 0, 0, 0)) Implements the model's inference
"""
# Data preprocessing
input = input[..., 0:2].transpose(1, 3)
input = F.pad(input, (1, 0, 0, 0))
in_len = input.size(3) in_len = input.size(3)
if in_len < self.receptive_field: x = (
x = nn.functional.pad(input, (self.receptive_field - in_len, 0, 0, 0)) F.pad(input, (self.receptive_field - in_len, 0, 0, 0))
else: if in_len < self.receptive_field
x = input else input
x = self.start_conv(x) )
skip = 0
# calculate the current adaptive adj matrix once per iteration # Initial convolution
new_supports = None x, skip, new_supports = self.start_conv(x), 0, None
# If using the adaptive adjacency matrix, compute the new adjacency
if self.gcn_bool and self.addaptadj and self.supports is not None: if self.gcn_bool and self.addaptadj and self.supports is not None:
adp = F.softmax(F.relu(torch.mm(self.nodevec1, self.nodevec2)), dim=1) adp = F.softmax(F.relu(torch.mm(self.nodevec1, self.nodevec2)), dim=1)
new_supports = self.supports + [adp] new_supports = self.supports + [adp]
# WaveNet layers # Forward pass through the main layers
for i in range(self.blocks * self.layers): for i in range(self.blocks * self.layers):
# |----------------------------------------| *residual*
# | |
# | |-- conv -- tanh --| |
# -> dilate -|----| * ----|-- 1x1 -- + --> *input*
# |-- conv -- sigm --| |
# 1x1
# |
# ---------------------------------------> + -------------> *skip*
# (dilation, init_dilation) = self.dilations[i]
# residual = dilation_func(x, dilation, init_dilation, i)
residual = x residual = x
# dilated convolution # Temporal convolution
filter = self.filter_convs[i](residual) f = self.filter_convs[i](residual).tanh()
filter = torch.tanh(filter) g = self.gate_convs[i](residual).sigmoid()
gate = self.gate_convs[i](residual) x = f * g
gate = torch.sigmoid(gate) s = self.skip_convs[i](x)
x = filter * gate skip = (
skip[:, :, :, -s.size(3) :] if isinstance(skip, torch.Tensor) else 0
# parametrized skip connection ) + s
s = x
s = self.skip_convs[i](s)
try:
skip = skip[:, :, :, -s.size(3) :]
except:
skip = 0
skip = s + skip
# Graph convolution
if self.gcn_bool and self.supports is not None: if self.gcn_bool and self.supports is not None:
if self.addaptadj: x = self.gconv[i](x, new_supports if self.addaptadj else self.supports)
x = self.gconv[i](x, new_supports)
else:
x = self.gconv[i](x, self.supports)
else: else:
x = self.residual_convs[i](x) x = self.residual_convs[i](x)
x = x + residual[:, :, :, -x.size(3) :] x = x + residual[:, :, :, -x.size(3) :]
x = self.bn[i](x) x = self.bn[i](x)
x = F.relu(skip) # Output layer processing
x = F.relu(self.end_conv_1(x)) return self.end_conv_2(F.relu(self.end_conv_1(F.relu(skip))))
x = self.end_conv_2(x)
return x
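Both variants of gwnet form the self-adaptive adjacency matrix the same way: the product of two learned node embeddings is passed through ReLU and row-normalized with softmax, then appended to the fixed supports. A standalone sketch (E1 and E2 are illustrative stand-ins for nodevec1/nodevec2; sizes follow the METR-LA config and the apt_size: 10 default):

import torch
import torch.nn.functional as F

num_nodes, apt_size = 207, 10
E1 = torch.randn(num_nodes, apt_size)   # plays the role of nodevec1
E2 = torch.randn(apt_size, num_nodes)   # plays the role of nodevec2

adp = F.softmax(F.relu(torch.mm(E1, E2)), dim=1)  # N x N, rows sum to ~1
x = torch.randn(8, 32, num_nodes, 13)             # (batch, channels, nodes, time)
x1 = torch.einsum("ncvl,vw->ncwl", (x, adp))      # same einsum as nconv above
print(adp.shape, x1.shape)                        # (207, 207), (8, 32, 207, 13)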

95
run_tests.sh Executable file
View File

@ -0,0 +1,95 @@
#!/bin/bash
# Default model name and dataset list
MODEL_NAME="GWN"
DATASETS=(
"METR-LA"
"PEMS-BAY"
"NYCBike-InFlow"
"NYCBike-OutFlow"
"AirQuality"
"SolarEnergy"
)
# Initialize result counters
success_count=0
failure_count=0
missing_count=0
total_count=0
success_datasets=()
failure_datasets=()
missing_datasets=()
# Allow command-line arguments to override the defaults
if [ $# -gt 0 ]; then
MODEL_NAME=$1
# If more arguments were passed, use them as the dataset list
if [ $# -gt 1 ]; then
DATASETS=(${@:2})
fi
fi
echo "使用模型: $MODEL_NAME"
echo "数据集列表: ${DATASETS[*]}"
echo "开始测试..."
echo ""
# Test each dataset in turn
for dataset in "${DATASETS[@]}"; do
total_count=$((total_count + 1))
# Build the config file path
CONFIG_PATH="config/${MODEL_NAME}/${dataset}.yaml"
echo "测试数据集: $dataset"
echo "使用配置文件: $CONFIG_PATH"
# 检查配置文件是否存在
if [ ! -f "$CONFIG_PATH" ]; then
echo "错误: 配置文件 $CONFIG_PATH 不存在!"
missing_count=$((missing_count + 1))
missing_datasets+=("$dataset")
echo "----------------------------------------"
continue
fi
# Run the test command and capture its output
echo "Running: python run.py --config $CONFIG_PATH"
output=$(python run.py --config "$CONFIG_PATH" 2>&1)
# With no explicit success marker in the output, fall back to the exit code
if [ $? -eq 0 ]; then
echo "数据集 $dataset 测试成功! (基于退出码)"
success_count=$((success_count + 1))
success_datasets+=("$dataset")
else
echo "数据集 $dataset 测试失败! (基于退出码)"
failure_count=$((failure_count + 1))
failure_datasets+=("$dataset")
fi
echo "----------------------------------------"
done
# Print summary
echo "======================================="
echo "测试总结"
echo "======================================="
echo "总数据集数量: $total_count"
echo "成功数量: $success_count"
echo "失败数量: $failure_count"
echo "缺失配置文件数量: $missing_count"
if [ ${#success_datasets[@]} -gt 0 ]; then
echo "成功的数据集: ${success_datasets[*]}"
fi
if [ ${#failure_datasets[@]} -gt 0 ]; then
echo "失败的数据集: ${failure_datasets[*]}"
fi
if [ ${#missing_datasets[@]} -gt 0 ]; then
echo "缺失配置的数据集: ${missing_datasets[*]}"
fi
echo "======================================="
echo "所有测试完成!"

View File

@ -177,6 +177,14 @@ class Trainer:
# Forward pass # Forward pass
label = target[..., : self.args["output_dim"]] label = target[..., : self.args["output_dim"]]
output = self.model(data).to(self.device) output = self.model(data).to(self.device)
# if output.shape != label.shape:
# import sys
# print(f"[Wrong]: Output shape: {output.shape}, Label shape: {label.shape}")
# sys.exit(1)
# else:
# import sys
# print(f"[Right]: Output shape: {output.shape}, Label shape: {label.shape}")
# sys.exit(0)
loss = self.loss(output, label) loss = self.loss(output, label)
# De-normalize # De-normalize
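The commented-out shape check above is what the config changes target: with output_dim 1 and out_dim equal to the horizon of 24, GWN's output lines up with the sliced label. A hedged sketch of the expected shapes (it assumes target is laid out as (batch, horizon, nodes, features), which the slicing above implies; numbers follow the METR-LA config):

# Illustrative only: compares the shapes the Trainer would see, without running the model.
B, N, horizon, output_dim, out_dim = 64, 207, 24, 1, 24
output_shape = (B, out_dim, N, 1)           # gwnet pads lag 24 up to its receptive field (61),
                                            # so the temporal axis collapses to 1 before end_conv_2
label_shape = (B, horizon, N, output_dim)   # target[..., :output_dim]
assert output_shape == label_shape          # both (64, 24, 207, 1)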