From a9313390ac3062be453f52d416c74d33ca80fff4 Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Wed, 3 Dec 2025 12:05:02 +0800 Subject: [PATCH] =?UTF-8?q?=E9=80=82=E9=85=8DGraphWaveNet?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config/GWN/AirQuality.yaml | 23 +-- config/GWN/BJTaxi-InFlow.yaml | 12 +- config/GWN/BJTaxi-OutFlow.yaml | 9 +- config/GWN/METR-LA.yaml | 15 +- config/GWN/NYCBike-InFlow.yaml | 9 +- config/GWN/NYCBike-OutFlow.yaml | 15 +- config/GWN/PEMS-BAY.yaml | 61 +++++++ config/GWN/PEMSD3.yaml | 2 +- config/GWN/PEMSD4.yaml | 2 +- config/GWN/PEMSD7.yaml | 2 +- config/GWN/PEMSD8.yaml | 2 +- config/GWN/SolarEnergy.yaml | 11 +- config/tmp.py | 234 -------------------------- model/GWN/GraphWaveNet.py | 261 +++++++++++++--------------- model/GWN/GraphWaveNet_bk.py | 290 ++++++++++++-------------------- run_tests.sh | 95 +++++++++++ trainer/Trainer.py | 8 + 17 files changed, 448 insertions(+), 603 deletions(-) create mode 100644 config/GWN/PEMS-BAY.yaml delete mode 100644 config/tmp.py create mode 100755 run_tests.sh diff --git a/config/GWN/AirQuality.yaml b/config/GWN/AirQuality.yaml index 786219f..c8e57d5 100644 --- a/config/GWN/AirQuality.yaml +++ b/config/GWN/AirQuality.yaml @@ -6,40 +6,41 @@ basic: seed: 2023 data: - batch_size: 16 + batch_size: 64 column_wise: false days_per_week: 7 horizon: 24 input_dim: 6 lag: 24 normalizer: std - num_nodes: 12 + num_nodes: 35 steps_per_day: 24 test_ratio: 0.2 val_ratio: 0.2 model: addaptadj: true + apt_size: 10 aptinit: null - batch_size: 16 + batch_size: 64 blocks: 4 dilation_channels: 32 dropout: 0.3 + do_graph_conv: True end_channels: 512 gcn_bool: true - in_dim: 2 - input_dim: 6 + in_dim: 1 + input_dim: 1 kernel_size: 2 - layers: 2 - num_nodes: 12 - out_dim: 12 - output_dim: 6 + layers: 4 + num_nodes: 35 + out_dim: 24 residual_channels: 32 skip_channels: 256 supports: null train: - batch_size: 16 + batch_size: 64 debug: false early_stop: true early_stop_patience: 15 @@ -54,7 +55,7 @@ train: mae_thresh: 0.0 mape_thresh: 0.0 max_grad_norm: 5 - output_dim: 6 + output_dim: 1 plot: false real_value: true weight_decay: 0 diff --git a/config/GWN/BJTaxi-InFlow.yaml b/config/GWN/BJTaxi-InFlow.yaml index f2f10c8..8f4de85 100644 --- a/config/GWN/BJTaxi-InFlow.yaml +++ b/config/GWN/BJTaxi-InFlow.yaml @@ -20,24 +20,26 @@ data: model: addaptadj: true + apt_size: 10 aptinit: null - batch_size: 32 + batch_size: 16 blocks: 4 dilation_channels: 32 dropout: 0.3 + do_graph_conv: True end_channels: 512 gcn_bool: true - in_dim: 2 + in_dim: 1 input_dim: 1 kernel_size: 2 - layers: 2 + layers: 4 num_nodes: 1024 - out_dim: 12 - output_dim: 1 + out_dim: 24 residual_channels: 32 skip_channels: 256 supports: null + train: batch_size: 32 debug: false diff --git a/config/GWN/BJTaxi-OutFlow.yaml b/config/GWN/BJTaxi-OutFlow.yaml index cef9af4..f86270e 100644 --- a/config/GWN/BJTaxi-OutFlow.yaml +++ b/config/GWN/BJTaxi-OutFlow.yaml @@ -20,20 +20,21 @@ data: model: addaptadj: true + apt_size: 10 aptinit: null batch_size: 32 blocks: 4 dilation_channels: 32 dropout: 0.3 + do_graph_conv: True end_channels: 512 gcn_bool: true - in_dim: 2 + in_dim: 1 input_dim: 1 kernel_size: 2 - layers: 2 + layers: 4 num_nodes: 1024 - out_dim: 12 - output_dim: 1 + out_dim: 24 residual_channels: 32 skip_channels: 256 supports: null diff --git a/config/GWN/METR-LA.yaml b/config/GWN/METR-LA.yaml index 9ffb5d1..ef38574 100644 --- a/config/GWN/METR-LA.yaml +++ b/config/GWN/METR-LA.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 16 + 
batch_size: 64 column_wise: false days_per_week: 7 horizon: 24 @@ -20,26 +20,27 @@ data: model: addaptadj: true + apt_size: 10 aptinit: null - batch_size: 16 + batch_size: 64 blocks: 4 dilation_channels: 32 dropout: 0.3 + do_graph_conv: True end_channels: 512 gcn_bool: true - in_dim: 2 + in_dim: 1 input_dim: 1 kernel_size: 2 - layers: 2 + layers: 4 num_nodes: 207 - out_dim: 12 - output_dim: 1 + out_dim: 24 residual_channels: 32 skip_channels: 256 supports: null train: - batch_size: 16 + batch_size: 64 debug: false early_stop: true early_stop_patience: 15 diff --git a/config/GWN/NYCBike-InFlow.yaml b/config/GWN/NYCBike-InFlow.yaml index c536802..a85e36c 100644 --- a/config/GWN/NYCBike-InFlow.yaml +++ b/config/GWN/NYCBike-InFlow.yaml @@ -20,20 +20,21 @@ data: model: addaptadj: true + apt_size: 10 aptinit: null batch_size: 32 blocks: 4 dilation_channels: 32 dropout: 0.3 + do_graph_conv: True end_channels: 512 gcn_bool: true - in_dim: 2 + in_dim: 1 input_dim: 1 kernel_size: 2 - layers: 2 + layers: 4 num_nodes: 128 - out_dim: 12 - output_dim: 1 + out_dim: 24 residual_channels: 32 skip_channels: 256 supports: null diff --git a/config/GWN/NYCBike-OutFlow.yaml b/config/GWN/NYCBike-OutFlow.yaml index c67790b..3ef3c8f 100644 --- a/config/GWN/NYCBike-OutFlow.yaml +++ b/config/GWN/NYCBike-OutFlow.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 32 + batch_size: 16 column_wise: false days_per_week: 7 horizon: 24 @@ -20,26 +20,27 @@ data: model: addaptadj: true + apt_size: 10 aptinit: null - batch_size: 32 + batch_size: 16 blocks: 4 dilation_channels: 32 dropout: 0.3 + do_graph_conv: True end_channels: 512 gcn_bool: true - in_dim: 2 + in_dim: 1 input_dim: 1 kernel_size: 2 - layers: 2 + layers: 4 num_nodes: 128 - out_dim: 12 - output_dim: 1 + out_dim: 24 residual_channels: 32 skip_channels: 256 supports: null train: - batch_size: 32 + batch_size: 16 debug: false early_stop: true early_stop_patience: 15 diff --git a/config/GWN/PEMS-BAY.yaml b/config/GWN/PEMS-BAY.yaml new file mode 100644 index 0000000..3dc7acd --- /dev/null +++ b/config/GWN/PEMS-BAY.yaml @@ -0,0 +1,61 @@ +basic: + dataset: PEMS-BAY + device: cuda:0 + mode: train + model: GWN + seed: 2023 + +data: + batch_size: 64 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 325 + steps_per_day: 288 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + addaptadj: true + apt_size: 10 + aptinit: null + batch_size: 64 + blocks: 4 + dilation_channels: 32 + dropout: 0.3 + do_graph_conv: True + end_channels: 512 + gcn_bool: true + in_dim: 1 + input_dim: 1 + kernel_size: 2 + layers: 4 + num_nodes: 325 + out_dim: 24 + residual_channels: 32 + skip_channels: 256 + supports: null + +train: + batch_size: 64 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 300 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: false + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: 0.0 + mape_thresh: 0.0 + max_grad_norm: 5 + output_dim: 1 + plot: false + real_value: true + weight_decay: 0 diff --git a/config/GWN/PEMSD3.yaml b/config/GWN/PEMSD3.yaml index 9e75da7..9194d3d 100755 --- a/config/GWN/PEMSD3.yaml +++ b/config/GWN/PEMSD3.yaml @@ -27,7 +27,7 @@ model: dropout: 0.3 end_channels: 512 gcn_bool: true - in_dim: 2 + in_dim: 3 input_dim: 1 kernel_size: 2 layers: 2 diff --git a/config/GWN/PEMSD4.yaml b/config/GWN/PEMSD4.yaml index 5435727..ab6f18e 100755 --- a/config/GWN/PEMSD4.yaml +++ b/config/GWN/PEMSD4.yaml @@ -27,7 +27,7 @@ model: dropout: 0.3 
end_channels: 512 gcn_bool: true - in_dim: 2 + in_dim: 1 input_dim: 1 kernel_size: 2 layers: 2 diff --git a/config/GWN/PEMSD7.yaml b/config/GWN/PEMSD7.yaml index 7330998..4d82415 100755 --- a/config/GWN/PEMSD7.yaml +++ b/config/GWN/PEMSD7.yaml @@ -27,7 +27,7 @@ model: dropout: 0.3 end_channels: 512 gcn_bool: true - in_dim: 2 + in_dim: 3 input_dim: 1 kernel_size: 2 layers: 2 diff --git a/config/GWN/PEMSD8.yaml b/config/GWN/PEMSD8.yaml index cebe500..26d0de8 100755 --- a/config/GWN/PEMSD8.yaml +++ b/config/GWN/PEMSD8.yaml @@ -27,7 +27,7 @@ model: dropout: 0.3 end_channels: 512 gcn_bool: true - in_dim: 2 + in_dim: 3 input_dim: 1 kernel_size: 2 layers: 2 diff --git a/config/GWN/SolarEnergy.yaml b/config/GWN/SolarEnergy.yaml index afdce7a..cd1d043 100644 --- a/config/GWN/SolarEnergy.yaml +++ b/config/GWN/SolarEnergy.yaml @@ -20,20 +20,21 @@ data: model: addaptadj: true + apt_size: 10 aptinit: null - batch_size: 64 + batch_size: 32 blocks: 4 dilation_channels: 32 dropout: 0.3 + do_graph_conv: True end_channels: 512 gcn_bool: true - in_dim: 2 + in_dim: 1 input_dim: 1 kernel_size: 2 - layers: 2 + layers: 4 num_nodes: 137 - out_dim: 12 - output_dim: 1 + out_dim: 24 residual_channels: 32 skip_channels: 256 supports: null diff --git a/config/tmp.py b/config/tmp.py deleted file mode 100644 index 17cbe0b..0000000 --- a/config/tmp.py +++ /dev/null @@ -1,234 +0,0 @@ -#!/usr/bin/env python3 -import os -from collections import defaultdict -from ruamel.yaml import YAML -from ruamel.yaml.comments import CommentedMap - -yaml = YAML() -yaml.preserve_quotes = True -yaml.indent(mapping=2, sequence=4, offset=2) - -# 允许的 data keys -DATA_ALLOWED_KEYS = { - "lag", - "horizon", - "num_nodes", - "steps_per_day", - "days_per_week", - "test_ratio", - "val_ratio", - "batch_size", - "input_dim", - "column_wise", - "normalizer", -} - -# 全局默认值 -GLOBAL_DEFAULTS = { - "lag": 24, - "horizon": 24, - "num_nodes": 1, - "steps_per_day": 24, - "days_per_week": 7, - "test_ratio": 0.2, - "val_ratio": 0.2, - "batch_size": 16, - "input_dim": 1, - "column_wise": False, - "normalizer": "std", -} - -# train全局默认值 -GLOBAL_TRAIN_DEFAULTS = { - "output_dim": 1 -} - - -def load_yaml(path): - try: - with open(path, "r", encoding="utf-8") as f: - return yaml.load(f) - except Exception: - return None - - -def collect_dataset_defaults(base="."): - """ - 收集每个数据集 data 的 key 默认值,以及 train.output_dim 默认值 - """ - data_defaults = defaultdict(dict) - train_output_defaults = dict() - - for root, _, files in os.walk(base): - for name in files: - if not (name.endswith(".yaml") or name.endswith(".yml")): - continue - path = os.path.join(root, name) - cm = load_yaml(path) - if not isinstance(cm, CommentedMap): - continue - basic = cm.get("basic") - if not isinstance(basic, dict): - continue - dataset = basic.get("dataset") - if dataset is None: - continue - ds = str(dataset) - - # data 默认值 - data_sec = cm.get("data") - if isinstance(data_sec, dict): - for key in DATA_ALLOWED_KEYS: - if key not in data_defaults[ds] and key in data_sec and data_sec[key] is not None: - data_defaults[ds][key] = data_sec[key] - - # train.output_dim 默认值 - train_sec = cm.get("train") - if isinstance(train_sec, dict): - val = train_sec.get("output_dim") - if val is not None and ds not in train_output_defaults: - train_output_defaults[ds] = val - - return data_defaults, train_output_defaults - - -def ensure_basic_seed(cm: CommentedMap, path: str): - if "basic" not in cm or not isinstance(cm["basic"], dict): - cm["basic"] = CommentedMap() - basic = cm["basic"] - if "seed" not in basic: 
- basic["seed"] = 2023 - print(f"[ADD] {path}: basic.seed = 2023") - - -def fill_data_defaults(cm: CommentedMap, data_defaults: dict, path: str): - if "data" not in cm or not isinstance(cm["data"], dict): - cm["data"] = CommentedMap() - data_sec = cm["data"] - - basic = cm.get("basic", {}) - dataset = str(basic.get("dataset")) if basic and "dataset" in basic else None - - for key in sorted(DATA_ALLOWED_KEYS): - if key in data_sec and data_sec[key] is not None: - continue - if dataset and dataset in data_defaults and key in data_defaults[dataset]: - chosen = data_defaults[dataset][key] - src = f"default_from_dataset[{dataset}]" - else: - chosen = GLOBAL_DEFAULTS[key] - src = "GLOBAL_DEFAULTS" - data_sec[key] = chosen - print(f"[FILL] {path}: data.{key} <- {src} ({repr(chosen)})") - - -def merge_test_log_into_train(cm: CommentedMap, path: str): - """ - 将 test 和 log 的 key 合并到 train,并删除 test 和 log - 同时确保 train.debug 存在 - """ - train_sec = cm.setdefault("train", CommentedMap()) - - for section in ["test", "log"]: - if section in cm and isinstance(cm[section], dict): - for k, v in cm[section].items(): - if k not in train_sec: - train_sec[k] = v - print(f"[MERGE] {path}: train.{k} <- {section}.{k} ({repr(v)})") - del cm[section] - print(f"[DEL] {path}: deleted section '{section}'") - - # train.debug - if "debug" not in train_sec: - train_sec["debug"] = False - print(f"[ADD] {path}: train.debug = False") - - -def fill_train_output_dim(cm: CommentedMap, train_output_defaults: dict, path: str): - train_sec = cm.setdefault("train", CommentedMap()) - if "output_dim" not in train_sec or train_sec["output_dim"] is None: - basic = cm.get("basic", {}) - dataset = str(basic.get("dataset")) if basic and "dataset" in basic else None - if dataset and dataset in train_output_defaults: - val = train_output_defaults[dataset] - src = f"default_from_dataset[{dataset}]" - else: - val = GLOBAL_TRAIN_DEFAULTS["output_dim"] - src = "GLOBAL_TRAIN_DEFAULTS" - train_sec["output_dim"] = val - print(f"[FILL] {path}: train.output_dim <- {src} ({val})") - - -def sync_train_batch_size(cm: CommentedMap, path: str): - """ - 如果 train.batch_size 与 data.batch_size 不一致,以 data 为准 - """ - data_sec = cm.get("data", {}) - train_sec = cm.get("train", {}) - data_bs = data_sec.get("batch_size") - train_bs = train_sec.get("batch_size") - - if data_bs is not None and train_bs != data_bs: - train_sec["batch_size"] = data_bs - print(f"[SYNC] {path}: train.batch_size corrected to match data.batch_size ({data_bs})") - - -def sort_subkeys_and_insert_blanklines(cm: CommentedMap): - for sec in list(cm.keys()): - if isinstance(cm[sec], dict): - sorted_cm = CommentedMap() - for k in sorted(cm[sec].keys()): - sorted_cm[k] = cm[sec][k] - cm[sec] = sorted_cm - - keys = list(cm.keys()) - for i, k in enumerate(keys): - if i == 0: - try: - cm.yaml_set_comment_before_after_key(k, before=None) - except Exception: - pass - else: - try: - cm.yaml_set_comment_before_after_key(k, before="\n") - except Exception: - pass - - -def process_all(base="."): - print(">> Collecting dataset defaults ...") - data_defaults, train_output_defaults = collect_dataset_defaults(base) - print(">> Collected data defaults per dataset:") - for ds, kv in data_defaults.items(): - print(f" - {ds}: {kv}") - print(">> Collected train.output_dim defaults per dataset:") - for ds, val in train_output_defaults.items(): - print(f" - {ds}: output_dim = {val}") - - for root, _, files in os.walk(base): - for name in files: - if not (name.endswith(".yaml") or name.endswith(".yml")): - continue - 
path = os.path.join(root, name) - cm = load_yaml(path) - if not isinstance(cm, CommentedMap): - print(f"[SKIP] {path}: top-level not mapping or load failed") - continue - - ensure_basic_seed(cm, path) - fill_data_defaults(cm, data_defaults, path) - merge_test_log_into_train(cm, path) - fill_train_output_dim(cm, train_output_defaults, path) - sync_train_batch_size(cm, path) # <-- 新增逻辑 - sort_subkeys_and_insert_blanklines(cm) - - try: - with open(path, "w", encoding="utf-8") as f: - yaml.dump(cm, f) - print(f"[OK] Written: {path}") - except Exception as e: - print(f"[ERROR] Write failed {path}: {e}") - - -if __name__ == "__main__": - process_all(".") diff --git a/model/GWN/GraphWaveNet.py b/model/GWN/GraphWaveNet.py index 6f290f5..5bece37 100755 --- a/model/GWN/GraphWaveNet.py +++ b/model/GWN/GraphWaveNet.py @@ -1,53 +1,35 @@ -import torch, torch.nn as nn, torch.nn.functional as F +import torch +import torch.nn as nn +from torch.nn import BatchNorm2d, Conv1d, Conv2d, ModuleList, Parameter +import torch.nn.functional as F + +def nconv(x, A): + """Multiply x by adjacency matrix along source node axis""" + return torch.einsum('ncvl,vw->ncwl', (x, A)).contiguous() -class nconv(nn.Module): - """ - 图卷积操作的实现类 - 使用einsum进行矩阵运算,实现图卷积操作 - """ - - def forward(self, x, A): - return torch.einsum("ncvl,vw->ncwl", (x, A)).contiguous() - - -class linear(nn.Module): - """ - 线性变换层 - 使用1x1卷积实现线性变换 - """ - - def __init__(self, c_in, c_out): - super().__init__() - self.mlp = nn.Conv2d(c_in, c_out, 1) - - def forward(self, x): - return self.mlp(x) - - -class gcn(nn.Module): - """ - 图卷积网络层 - 实现高阶图卷积操作,支持多阶邻接矩阵 - """ - +class GraphConvNet(nn.Module): def __init__(self, c_in, c_out, dropout, support_len=3, order=2): super().__init__() - self.nconv = nconv() c_in = (order * support_len + 1) * c_in - self.mlp, self.dropout, self.order = linear(c_in, c_out), dropout, order + self.final_conv = Conv2d(c_in, c_out, (1, 1), padding=(0, 0), stride=(1, 1), bias=True) + self.dropout = dropout + self.order = order - def forward(self, x, support): + def forward(self, x, support: list): out = [x] for a in support: - x1 = self.nconv(x, a) + x1 = nconv(x, a) out.append(x1) - for _ in range(2, self.order + 1): - x1 = self.nconv(x1, a) - out.append(x1) - return F.dropout( - self.mlp(torch.cat(out, dim=1)), self.dropout, training=self.training - ) + for k in range(2, self.order + 1): + x2 = nconv(x1, a) + out.append(x2) + x1 = x2 + + h = torch.cat(out, dim=1) + h = self.final_conv(h) + h = F.dropout(h, self.dropout, training=self.training) + return h class gwnet(nn.Module): @@ -59,126 +41,121 @@ class gwnet(nn.Module): def __init__(self, args): super().__init__() # 初始化基本参数 - self.dropout, self.blocks, self.layers = ( - args["dropout"], - args["blocks"], - args["layers"], - ) - self.gcn_bool, self.addaptadj = args["gcn_bool"], args["addaptadj"] + self.dropout = args["dropout"] + self.blocks = args["blocks"] + self.layers = args["layers"] + self.do_graph_conv = args.get("do_graph_conv", True) + self.cat_feat_gc = args.get("cat_feat_gc", False) + self.addaptadj = args.get("addaptadj", True) + supports = None + aptinit = args.get("aptinit", None) + in_dim = args.get("in_dim") + out_dim = args.get("out_dim") + residual_channels = args.get("residual_channels") + dilation_channels = args.get("dilation_channels") + skip_channels = args.get("skip_channels") + end_channels = args.get("end_channels") + kernel_size = args.get("kernel_size") + apt_size = args.get("apt_size", 10) - # 初始化各种卷积层和模块 - self.filter_convs, self.gate_convs = 
nn.ModuleList(), nn.ModuleList() - self.residual_convs, self.skip_convs, self.bn, self.gconv = ( - nn.ModuleList(), - nn.ModuleList(), - nn.ModuleList(), - nn.ModuleList(), - ) - self.start_conv = nn.Conv2d(args["in_dim"], args["residual_channels"], 1) - self.supports = args.get("supports", None) - # 计算感受野 + if self.cat_feat_gc: + self.start_conv = nn.Conv2d(in_channels=1, # hard code to avoid errors + out_channels=residual_channels, + kernel_size=(1, 1)) + self.cat_feature_conv = nn.Conv2d(in_channels=in_dim - 1, + out_channels=residual_channels, + kernel_size=(1, 1)) + else: + self.start_conv = nn.Conv2d(in_channels=in_dim, + out_channels=residual_channels, + kernel_size=(1, 1)) + + self.fixed_supports = supports or [] receptive_field = 1 - self.supports_len = len(self.supports) if self.supports is not None else 0 - # 如果使用自适应邻接矩阵,初始化相关参数 - if self.gcn_bool and self.addaptadj: - aptinit = args.get("aptinit", None) + self.supports_len = len(self.fixed_supports) + if self.do_graph_conv and self.addaptadj: if aptinit is None: - if self.supports is None: - self.supports = [] - self.nodevec1 = nn.Parameter( - torch.randn(args["num_nodes"], 10, device=args["device"]) - ) - self.nodevec2 = nn.Parameter( - torch.randn(10, args["num_nodes"], device=args["device"]) - ) - self.supports_len += 1 + nodevecs = torch.randn(args["num_nodes"], apt_size), torch.randn(apt_size, args["num_nodes"]) else: - if self.supports is None: - self.supports = [] - m, p, n = torch.svd(aptinit) - initemb1 = torch.mm(m[:, :10], torch.diag(p[:10] ** 0.5)) - initemb2 = torch.mm(torch.diag(p[:10] ** 0.5), n[:, :10].t()) - self.nodevec1 = nn.Parameter(initemb1) - self.nodevec2 = nn.Parameter(initemb2) - self.supports_len += 1 + nodevecs = self.svd_init(args["num_nodes"], apt_size, aptinit) + self.supports_len += 1 + self.nodevec1, self.nodevec2 = [Parameter(n.to(args["device"]), requires_grad=True) for n in nodevecs] - # 获取模型参数 - ks, res, dil, skip, endc, out_dim = ( - args["kernel_size"], - args["residual_channels"], - args["dilation_channels"], - args["skip_channels"], - args["end_channels"], - args["out_dim"], - ) + depth = list(range(self.blocks * self.layers)) - # 构建模型层 + # 1x1 convolution for residual and skip connections (slightly different see docstring) + self.residual_convs = ModuleList([Conv2d(dilation_channels, residual_channels, (1, 1)) for _ in depth]) + self.skip_convs = ModuleList([Conv2d(dilation_channels, skip_channels, (1, 1)) for _ in depth]) + self.bn = ModuleList([BatchNorm2d(residual_channels) for _ in depth]) + self.graph_convs = ModuleList([GraphConvNet(dilation_channels, residual_channels, self.dropout, support_len=self.supports_len) + for _ in depth]) + + self.filter_convs = ModuleList() + self.gate_convs = ModuleList() for b in range(self.blocks): - add_scope, new_dil = ks - 1, 1 + additional_scope = kernel_size - 1 + D = 1 # dilation for i in range(self.layers): - # 添加时间卷积层 - self.filter_convs.append(nn.Conv2d(res, dil, (1, ks), dilation=new_dil)) - self.gate_convs.append(nn.Conv2d(res, dil, (1, ks), dilation=new_dil)) - self.residual_convs.append(nn.Conv2d(dil, res, 1)) - self.skip_convs.append(nn.Conv2d(dil, skip, 1)) - self.bn.append(nn.BatchNorm2d(res)) - new_dil *= 2 - receptive_field += add_scope - add_scope *= 2 - if self.gcn_bool: - self.gconv.append( - gcn(dil, res, args["dropout"], support_len=self.supports_len) - ) - - # 输出层 - self.end_conv_1 = nn.Conv2d(skip, endc, 1) - self.end_conv_2 = nn.Conv2d(endc, out_dim, 1) + # dilated convolutions + 
self.filter_convs.append(Conv2d(residual_channels, dilation_channels, (1, kernel_size), dilation=D)) + self.gate_convs.append(Conv2d(residual_channels, dilation_channels, (1, kernel_size), dilation=D)) + D *= 2 + receptive_field += additional_scope + additional_scope *= 2 self.receptive_field = receptive_field + self.end_conv_1 = Conv2d(skip_channels, end_channels, (1, 1), bias=True) + self.end_conv_2 = Conv2d(end_channels, out_dim, (1, 1), bias=True) + def forward(self, input): - """ - 前向传播函数 - 实现模型的推理过程 - """ - # 数据预处理 - input = input[..., 0:2].transpose(1, 3) - input = F.pad(input, (1, 0, 0, 0)) - in_len = input.size(3) - x = ( - F.pad(input, (self.receptive_field - in_len, 0, 0, 0)) - if in_len < self.receptive_field - else input - ) - - # 初始卷积 - x, skip, new_supports = self.start_conv(x), 0, None - - # 如果使用自适应邻接矩阵,计算新的邻接矩阵 - if self.gcn_bool and self.addaptadj and self.supports is not None: + x = input[..., 0:1].transpose(1, 3) + # Input shape is (bs, features, n_nodes, n_timesteps) + in_len = x.size(3) + if in_len < self.receptive_field: + x = nn.functional.pad(x, (self.receptive_field - in_len, 0, 0, 0)) + if self.cat_feat_gc: + f1, f2 = x[:, [0]], x[:, 1:] + x1 = self.start_conv(f1) + x2 = F.leaky_relu(self.cat_feature_conv(f2)) + x = x1 + x2 + else: + x = self.start_conv(x) + skip = 0 + adjacency_matrices = self.fixed_supports + # calculate the current adaptive adj matrix once per iteration + if self.addaptadj: adp = F.softmax(F.relu(torch.mm(self.nodevec1, self.nodevec2)), dim=1) - new_supports = self.supports + [adp] + adjacency_matrices = self.fixed_supports + [adp] - # 主网络层的前向传播 + # WaveNet layers for i in range(self.blocks * self.layers): residual = x - # 时间卷积操作 - f = self.filter_convs[i](residual).tanh() - g = self.gate_convs[i](residual).sigmoid() - x = f * g - s = self.skip_convs[i](x) - skip = ( - skip[:, :, :, -s.size(3) :] if isinstance(skip, torch.Tensor) else 0 - ) + s + # dilated convolution + filter = torch.tanh(self.filter_convs[i](residual)) + gate = torch.sigmoid(self.gate_convs[i](residual)) + x = filter * gate + # parametrized skip connection + s = self.skip_convs[i](x) # what are we skipping?? + try: # if i > 0 this works + skip = skip[:, :, :, -s.size(3):] # TODO(SS): Mean/Max Pool? + except: + skip = 0 + skip = s + skip + if i == (self.blocks * self.layers - 1): # last X getting ignored anyway + break - # 图卷积操作 - if self.gcn_bool and self.supports is not None: - x = self.gconv[i](x, new_supports if self.addaptadj else self.supports) + if self.do_graph_conv: + graph_out = self.graph_convs[i](x, adjacency_matrices) + x = x + graph_out if self.cat_feat_gc else graph_out else: x = self.residual_convs[i](x) - x = x + residual[:, :, :, -x.size(3) :] + x = x + residual[:, :, :, -x.size(3):] # TODO(SS): Mean/Max Pool? x = self.bn[i](x) - # 输出层处理 - return self.end_conv_2(F.relu(self.end_conv_1(F.relu(skip)))) + x = F.relu(skip) # ignore last X? 
+ x = F.relu(self.end_conv_1(x)) + x = self.end_conv_2(x) # downsample to (bs, seq_length, 207, nfeatures) + # x = x.transpose(1, 3) + return x diff --git a/model/GWN/GraphWaveNet_bk.py b/model/GWN/GraphWaveNet_bk.py index 19308d4..6f290f5 100755 --- a/model/GWN/GraphWaveNet_bk.py +++ b/model/GWN/GraphWaveNet_bk.py @@ -1,97 +1,98 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch.autograd import Variable -import sys +import torch, torch.nn as nn, torch.nn.functional as F class nconv(nn.Module): - def __init__(self): - super(nconv, self).__init__() + """ + 图卷积操作的实现类 + 使用einsum进行矩阵运算,实现图卷积操作 + """ def forward(self, x, A): - x = torch.einsum("ncvl,vw->ncwl", (x, A)) - return x.contiguous() + return torch.einsum("ncvl,vw->ncwl", (x, A)).contiguous() class linear(nn.Module): + """ + 线性变换层 + 使用1x1卷积实现线性变换 + """ + def __init__(self, c_in, c_out): - super(linear, self).__init__() - self.mlp = torch.nn.Conv2d( - c_in, c_out, kernel_size=(1, 1), padding=(0, 0), stride=(1, 1), bias=True - ) + super().__init__() + self.mlp = nn.Conv2d(c_in, c_out, 1) def forward(self, x): return self.mlp(x) class gcn(nn.Module): + """ + 图卷积网络层 + 实现高阶图卷积操作,支持多阶邻接矩阵 + """ + def __init__(self, c_in, c_out, dropout, support_len=3, order=2): - super(gcn, self).__init__() + super().__init__() self.nconv = nconv() c_in = (order * support_len + 1) * c_in - self.mlp = linear(c_in, c_out) - self.dropout = dropout - self.order = order + self.mlp, self.dropout, self.order = linear(c_in, c_out), dropout, order def forward(self, x, support): out = [x] for a in support: x1 = self.nconv(x, a) out.append(x1) - for k in range(2, self.order + 1): - x2 = self.nconv(x1, a) - out.append(x2) - x1 = x2 - - h = torch.cat(out, dim=1) - h = self.mlp(h) - h = F.dropout(h, self.dropout, training=self.training) - return h + for _ in range(2, self.order + 1): + x1 = self.nconv(x1, a) + out.append(x1) + return F.dropout( + self.mlp(torch.cat(out, dim=1)), self.dropout, training=self.training + ) class gwnet(nn.Module): + """ + Graph WaveNet模型的主类 + 结合了图卷积网络和时序卷积网络,用于时空预测任务 + """ + def __init__(self, args): - super(gwnet, self).__init__() - self.dropout = args["dropout"] - self.blocks = args["blocks"] - self.layers = args["layers"] - self.gcn_bool = args["gcn_bool"] - self.addaptadj = args["addaptadj"] - - self.filter_convs = nn.ModuleList() - self.gate_convs = nn.ModuleList() - self.residual_convs = nn.ModuleList() - self.skip_convs = nn.ModuleList() - self.bn = nn.ModuleList() - self.gconv = nn.ModuleList() - - self.start_conv = nn.Conv2d( - in_channels=args["in_dim"], - out_channels=args["residual_channels"], - kernel_size=(1, 1), + super().__init__() + # 初始化基本参数 + self.dropout, self.blocks, self.layers = ( + args["dropout"], + args["blocks"], + args["layers"], ) + self.gcn_bool, self.addaptadj = args["gcn_bool"], args["addaptadj"] + + # 初始化各种卷积层和模块 + self.filter_convs, self.gate_convs = nn.ModuleList(), nn.ModuleList() + self.residual_convs, self.skip_convs, self.bn, self.gconv = ( + nn.ModuleList(), + nn.ModuleList(), + nn.ModuleList(), + nn.ModuleList(), + ) + self.start_conv = nn.Conv2d(args["in_dim"], args["residual_channels"], 1) self.supports = args.get("supports", None) + # 计算感受野 receptive_field = 1 + self.supports_len = len(self.supports) if self.supports is not None else 0 - self.supports_len = 0 - if self.supports is not None: - self.supports_len += len(self.supports) - + # 如果使用自适应邻接矩阵,初始化相关参数 if self.gcn_bool and self.addaptadj: aptinit = args.get("aptinit", None) if aptinit is None: if self.supports is 
None: self.supports = [] self.nodevec1 = nn.Parameter( - torch.randn(args["num_nodes"], 10).to(args["device"]), - requires_grad=True, - ).to(args["device"]) + torch.randn(args["num_nodes"], 10, device=args["device"]) + ) self.nodevec2 = nn.Parameter( - torch.randn(10, args["num_nodes"]).to(args["device"]), - requires_grad=True, - ).to(args["device"]) + torch.randn(10, args["num_nodes"], device=args["device"]) + ) self.supports_len += 1 else: if self.supports is None: @@ -99,156 +100,85 @@ class gwnet(nn.Module): m, p, n = torch.svd(aptinit) initemb1 = torch.mm(m[:, :10], torch.diag(p[:10] ** 0.5)) initemb2 = torch.mm(torch.diag(p[:10] ** 0.5), n[:, :10].t()) - self.nodevec1 = nn.Parameter(initemb1, requires_grad=True).to( - args["device"] - ) - self.nodevec2 = nn.Parameter(initemb2, requires_grad=True).to( - args["device"] - ) + self.nodevec1 = nn.Parameter(initemb1) + self.nodevec2 = nn.Parameter(initemb2) self.supports_len += 1 - kernel_size = args["kernel_size"] - residual_channels = args["residual_channels"] - dilation_channels = args["dilation_channels"] - kernel_size = args["kernel_size"] - skip_channels = args["skip_channels"] - end_channels = args["end_channels"] - out_dim = args["out_dim"] - dropout = args["dropout"] + # 获取模型参数 + ks, res, dil, skip, endc, out_dim = ( + args["kernel_size"], + args["residual_channels"], + args["dilation_channels"], + args["skip_channels"], + args["end_channels"], + args["out_dim"], + ) + # 构建模型层 for b in range(self.blocks): - additional_scope = kernel_size - 1 - new_dilation = 1 + add_scope, new_dil = ks - 1, 1 for i in range(self.layers): - # dilated convolutions - self.filter_convs.append( - nn.Conv2d( - in_channels=residual_channels, - out_channels=dilation_channels, - kernel_size=(1, kernel_size), - dilation=new_dilation, - ) - ) - - self.gate_convs.append( - nn.Conv2d( - in_channels=residual_channels, - out_channels=dilation_channels, - kernel_size=(1, kernel_size), - dilation=new_dilation, - ) - ) - - # 1x1 convolution for residual connection - self.residual_convs.append( - nn.Conv2d( - in_channels=dilation_channels, - out_channels=residual_channels, - kernel_size=(1, 1), - ) - ) - - # 1x1 convolution for skip connection - self.skip_convs.append( - nn.Conv2d( - in_channels=dilation_channels, - out_channels=skip_channels, - kernel_size=(1, 1), - ) - ) - self.bn.append(nn.BatchNorm2d(residual_channels)) - new_dilation *= 2 - receptive_field += additional_scope - additional_scope *= 2 + # 添加时间卷积层 + self.filter_convs.append(nn.Conv2d(res, dil, (1, ks), dilation=new_dil)) + self.gate_convs.append(nn.Conv2d(res, dil, (1, ks), dilation=new_dil)) + self.residual_convs.append(nn.Conv2d(dil, res, 1)) + self.skip_convs.append(nn.Conv2d(dil, skip, 1)) + self.bn.append(nn.BatchNorm2d(res)) + new_dil *= 2 + receptive_field += add_scope + add_scope *= 2 if self.gcn_bool: self.gconv.append( - gcn( - dilation_channels, - residual_channels, - dropout, - support_len=self.supports_len, - ) + gcn(dil, res, args["dropout"], support_len=self.supports_len) ) - self.end_conv_1 = nn.Conv2d( - in_channels=skip_channels, - out_channels=end_channels, - kernel_size=(1, 1), - bias=True, - ) - - self.end_conv_2 = nn.Conv2d( - in_channels=end_channels, - out_channels=out_dim, - kernel_size=(1, 1), - bias=True, - ) - + # 输出层 + self.end_conv_1 = nn.Conv2d(skip, endc, 1) + self.end_conv_2 = nn.Conv2d(endc, out_dim, 1) self.receptive_field = receptive_field def forward(self, input): - input = input[..., 0:2] - input = input.transpose(1, 3) - input = nn.functional.pad(input, (1, 
0, 0, 0)) + """ + 前向传播函数 + 实现模型的推理过程 + """ + # 数据预处理 + input = input[..., 0:2].transpose(1, 3) + input = F.pad(input, (1, 0, 0, 0)) in_len = input.size(3) - if in_len < self.receptive_field: - x = nn.functional.pad(input, (self.receptive_field - in_len, 0, 0, 0)) - else: - x = input - x = self.start_conv(x) - skip = 0 + x = ( + F.pad(input, (self.receptive_field - in_len, 0, 0, 0)) + if in_len < self.receptive_field + else input + ) - # calculate the current adaptive adj matrix once per iteration - new_supports = None + # 初始卷积 + x, skip, new_supports = self.start_conv(x), 0, None + + # 如果使用自适应邻接矩阵,计算新的邻接矩阵 if self.gcn_bool and self.addaptadj and self.supports is not None: adp = F.softmax(F.relu(torch.mm(self.nodevec1, self.nodevec2)), dim=1) new_supports = self.supports + [adp] - # WaveNet layers + # 主网络层的前向传播 for i in range(self.blocks * self.layers): - # |----------------------------------------| *residual* - # | | - # | |-- conv -- tanh --| | - # -> dilate -|----| * ----|-- 1x1 -- + --> *input* - # |-- conv -- sigm --| | - # 1x1 - # | - # ---------------------------------------> + -------------> *skip* - - # (dilation, init_dilation) = self.dilations[i] - - # residual = dilation_func(x, dilation, init_dilation, i) residual = x - # dilated convolution - filter = self.filter_convs[i](residual) - filter = torch.tanh(filter) - gate = self.gate_convs[i](residual) - gate = torch.sigmoid(gate) - x = filter * gate - - # parametrized skip connection - - s = x - s = self.skip_convs[i](s) - try: - skip = skip[:, :, :, -s.size(3) :] - except: - skip = 0 - skip = s + skip + # 时间卷积操作 + f = self.filter_convs[i](residual).tanh() + g = self.gate_convs[i](residual).sigmoid() + x = f * g + s = self.skip_convs[i](x) + skip = ( + skip[:, :, :, -s.size(3) :] if isinstance(skip, torch.Tensor) else 0 + ) + s + # 图卷积操作 if self.gcn_bool and self.supports is not None: - if self.addaptadj: - x = self.gconv[i](x, new_supports) - else: - x = self.gconv[i](x, self.supports) + x = self.gconv[i](x, new_supports if self.addaptadj else self.supports) else: x = self.residual_convs[i](x) - x = x + residual[:, :, :, -x.size(3) :] - x = self.bn[i](x) - x = F.relu(skip) - x = F.relu(self.end_conv_1(x)) - x = self.end_conv_2(x) - return x + # 输出层处理 + return self.end_conv_2(F.relu(self.end_conv_1(F.relu(skip)))) diff --git a/run_tests.sh b/run_tests.sh new file mode 100755 index 0000000..a27a3bf --- /dev/null +++ b/run_tests.sh @@ -0,0 +1,95 @@ +#!/bin/bash + +# 设置默认模型名和数据集列表 +MODEL_NAME="GWN" +DATASETS=( + "METR-LA" + "PEMS-BAY" + "NYCBike-InFlow" + "NYCBike-OutFlow" + "AirQuality" + "SolarEnergy" +) + +# 初始化统计变量 +success_count=0 +failure_count=0 +missing_count=0 +total_count=0 +success_datasets=() +failure_datasets=() +missing_datasets=() + +# 检查是否有参数传入来覆盖默认值 +if [ $# -gt 0 ]; then + MODEL_NAME=$1 + # 如果传入了更多参数,使用它们作为数据集列表 + if [ $# -gt 1 ]; then + DATASETS=(${@:2}) + fi +fi + +echo "使用模型: $MODEL_NAME" +echo "数据集列表: ${DATASETS[*]}" +echo "开始测试..." +echo "" + +# 循环测试每个数据集 +for dataset in "${DATASETS[@]}"; do + total_count=$((total_count + 1)) + # 构建配置文件路径 + CONFIG_PATH="config/${MODEL_NAME}/${dataset}.yaml" + + echo "测试数据集: $dataset" + echo "使用配置文件: $CONFIG_PATH" + + # 检查配置文件是否存在 + if [ ! -f "$CONFIG_PATH" ]; then + echo "错误: 配置文件 $CONFIG_PATH 不存在!" 
+ missing_count=$((missing_count + 1)) + missing_datasets+=("$dataset") + echo "----------------------------------------" + continue + fi + + # 执行测试命令并捕获输出 + echo "执行: python run.py --config $CONFIG_PATH" + output=$(python run.py --config "$CONFIG_PATH" 2>&1) + + # 如果没有找到明确的标记,回退到检查退出码 + if [ $? -eq 0 ]; then + echo "数据集 $dataset 测试成功! (基于退出码)" + success_count=$((success_count + 1)) + success_datasets+=("$dataset") + else + echo "数据集 $dataset 测试失败! (基于退出码)" + failure_count=$((failure_count + 1)) + failure_datasets+=("$dataset") + fi + + echo "----------------------------------------" +done + +# 输出总结 +echo "=======================================" +echo "测试总结" +echo "=======================================" +echo "总数据集数量: $total_count" +echo "成功数量: $success_count" +echo "失败数量: $failure_count" +echo "缺失配置文件数量: $missing_count" + +if [ ${#success_datasets[@]} -gt 0 ]; then + echo "成功的数据集: ${success_datasets[*]}" +fi + +if [ ${#failure_datasets[@]} -gt 0 ]; then + echo "失败的数据集: ${failure_datasets[*]}" +fi + +if [ ${#missing_datasets[@]} -gt 0 ]; then + echo "缺失配置的数据集: ${missing_datasets[*]}" +fi + +echo "=======================================" +echo "所有测试完成!" \ No newline at end of file diff --git a/trainer/Trainer.py b/trainer/Trainer.py index 85060f1..2bd7e6e 100755 --- a/trainer/Trainer.py +++ b/trainer/Trainer.py @@ -177,6 +177,14 @@ class Trainer: # 前向传播 label = target[..., : self.args["output_dim"]] output = self.model(data).to(self.device) + # if output.shape != label.shape: + # import sys + # print(f"[Wrong]: Output shape: {output.shape}, Label shape: {label.shape}") + # sys.exit(1) + # else: + # import sys + # print(f"[Right]: Output shape: {output.shape}, Label shape: {label.shape}") + # sys.exit(0) loss = self.loss(output, label) # 反归一化
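
A note on the Trainer.py hunk above (not part of the patch itself): the newly added commented-out block prints the output/label shapes and calls sys.exit, which is handy while adapting out_dim to the 24-step horizon but should not stay in the training loop. A minimal standalone sketch of the same check as a reusable assertion, assuming both tensors share the (batch, horizon, nodes, features) layout that the code comments suggest; the function name is illustrative, not code from this repository:

def check_forecast_shape(output, label):
    """Fail fast if the model output disagrees with the label tensor."""
    if output.shape != label.shape:
        raise ValueError(
            f"output shape {tuple(output.shape)} != label shape {tuple(label.shape)}; "
            "check model.out_dim against data.horizon and train.output_dim"
        )

Called just before loss = self.loss(output, label), this turns a silent broadcasting mismatch into an immediate, descriptive failure instead of a printed exit.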
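
A broader note on the adapted model and configs (again, not part of the patch): the rewritten GraphWaveNet.py builds one adaptive adjacency matrix per forward pass from the learned embeddings nodevec1 and nodevec2 and mixes the node axis with it through nconv's einsum, and the updated configs raise layers from 2 to 4, which enlarges the dilated-convolution receptive field well past lag: 24. The sketch below re-derives both from the patch, using the PEMS-BAY.yaml values (num_nodes: 325, apt_size: 10, blocks: 4, layers: 4, kernel_size: 2); the receptive_field helper and the batch size of 8 are illustrative stand-ins, not code from this repository.

import torch
import torch.nn.functional as F

num_nodes, apt_size = 325, 10                      # PEMS-BAY.yaml values
nodevec1 = torch.randn(num_nodes, apt_size)        # learned "source" embedding
nodevec2 = torch.randn(apt_size, num_nodes)        # learned "target" embedding

# forward(): adp = softmax(relu(E1 @ E2)) -> one dense, row-normalised graph
adp = F.softmax(F.relu(torch.mm(nodevec1, nodevec2)), dim=1)
print(adp.shape)                                   # torch.Size([325, 325])

# nconv(): mix the node axis of a (batch, channels, nodes, time) tensor
x = torch.randn(8, 32, num_nodes, 24)              # batch of 8 is arbitrary
x_mixed = torch.einsum("ncvl,vw->ncwl", x, adp)
print(x_mixed.shape)                               # torch.Size([8, 32, 325, 24])

# Receptive field of the dilated TCN stack, mirroring the loop in gwnet.__init__
def receptive_field(blocks, layers, kernel_size=2):
    rf = 1
    for _ in range(blocks):
        additional_scope = kernel_size - 1
        for _ in range(layers):
            rf += additional_scope
            additional_scope *= 2
    return rf

print(receptive_field(blocks=4, layers=2))         # 13 (the old layers: 2 configs)
print(receptive_field(blocks=4, layers=4))         # 61 (the new layers: 4 configs)

With layers: 4 the lag-24 input is left-padded to 61 steps (the in_len < self.receptive_field branch), and after the stack the temporal axis shrinks to length 1, so the out_dim: 24 channels produced by end_conv_2 line up with horizon: 24 as the prediction steps.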