From d8f4cc5825647c26f819040b5edbe4f99cf40ab0 Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Wed, 10 Dec 2025 21:08:20 +0800 Subject: [PATCH] =?UTF-8?q?=E7=AE=80=E5=8C=96trainer=EF=BC=8C=E4=BF=AE?= =?UTF-8?q?=E5=A4=8D=E8=AE=BE=E5=A4=87bug=EF=BC=8C=E6=96=B0=E5=A2=9E?= =?UTF-8?q?=E6=89=B9=E9=87=8F=E8=BF=90=E8=A1=8C=E8=84=9A=E6=9C=AC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config/args_parser.py | 36 +- generate_launch_configs.py | 134 - mypy.ini | 4 - run_tests.sh | 95 - test_results.txt | 5406 ------------------------------------ train.py | 63 + trainer/Trainer.py | 250 +- trainer/Trainer_bk.py | 420 +++ trainer/Trainer_old.py | 229 -- utils/initializer.py | 6 +- 10 files changed, 550 insertions(+), 6093 deletions(-) delete mode 100644 generate_launch_configs.py delete mode 100644 mypy.ini delete mode 100755 run_tests.sh delete mode 100644 test_results.txt create mode 100644 train.py create mode 100755 trainer/Trainer_bk.py delete mode 100755 trainer/Trainer_old.py diff --git a/config/args_parser.py b/config/args_parser.py index ebd7bda..256c1f7 100755 --- a/config/args_parser.py +++ b/config/args_parser.py @@ -15,39 +15,5 @@ def parse_args(): config = yaml.safe_load(file) else: raise ValueError("Configuration file path must be provided using --config") - - # Update configuration with command-line arguments - # Merge 'basic' configuration into the root dictionary - # config.update(config.get('basic', {})) - - # Add adaptive configuration based on external commands - if "data" in config and "type" in config["data"]: - config["data"]["type"] = config["basic"].get("dataset", config["data"]["type"]) - if "model" in config and "type" in config["model"]: - config["model"]["type"] = config["basic"].get("model", config["model"]["type"]) - if "model" in config and "rnn_units" in config["model"]: - config["model"]["rnn_units"] = config["basic"].get( - "rnn", config["model"]["rnn_units"] - ) - if "model" in config and "embed_dim" in config["model"]: - config["model"]["embed_dim"] = config["basic"].get( - "emb", config["model"]["embed_dim"] - ) - if "data" in config and "sample" in config["data"]: - config["data"]["sample"] = config["basic"].get( - "sample", config["data"]["sample"] - ) - if "train" in config and "device" in config["train"]: - config["train"]["device"] = config["basic"].get( - "device", config["train"]["device"] - ) - if "train" in config and "debug" in config["train"]: - config["train"]["debug"] = config["basic"].get( - "debug", config["train"]["debug"] - ) - if "cuda" in config: - config["cuda"] = config["basic"].get("cuda", config["cuda"]) - if "mode" in config: - config["mode"] = config["basic"].get("mode", config["mode"]) - + return config diff --git a/generate_launch_configs.py b/generate_launch_configs.py deleted file mode 100644 index 6477e16..0000000 --- a/generate_launch_configs.py +++ /dev/null @@ -1,134 +0,0 @@ -import os -import re - -# 配置路径 -CONFIG_DIR = "/user/czzhangheng/code/TrafficWheel/config" -LAUNCH_FILE = "/user/czzhangheng/code/TrafficWheel/.vscode/launch.json" - -# 遍历所有yaml文件 -def find_all_yaml_files(directory): - yaml_files = [] - for root, dirs, files in os.walk(directory): - for file in files: - if file.endswith(".yaml") and not file.startswith("BJTaxi"): - yaml_files.append(os.path.join(root, file)) - return yaml_files - -# 生成launch配置字符串 -def generate_launch_config_string(yaml_files): - config_strings = [] - - for file_path in yaml_files: - # 提取模型名和数据集名 - relative_path = os.path.relpath(file_path, CONFIG_DIR) - model_name = relative_path.split(os.sep)[0] - dataset_name = os.path.splitext(os.path.basename(file_path))[0] - - # 处理v2版本 - if "v2_" in dataset_name: - model_display_name = f"{model_name}_v2" - dataset_display_name = dataset_name.replace("v2_", "") - else: - model_display_name = model_name - dataset_display_name = dataset_name - - # 生成配置字符串 - config_string = f''' - {{ - "name": "{model_display_name}: {dataset_display_name}", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/{model_name}/{os.path.basename(file_path)}" - }}''' - - config_strings.append(config_string) - - return ",".join(config_strings) - -# 读取现有的launch.json文件,提取配置名称 -def get_existing_config_names(): - with open(LAUNCH_FILE, 'r') as f: - content = f.read() - - # 提取所有配置名称 - name_pattern = re.compile(r'"name"\s*:\s*"([^"]+)"') - matches = name_pattern.findall(content) - - return set(matches) - -# 生成新的配置,过滤掉已存在的 -def generate_new_configs(yaml_files, existing_names): - new_configs = [] - - for file_path in yaml_files: - # 提取模型名和数据集名 - relative_path = os.path.relpath(file_path, CONFIG_DIR) - model_name = relative_path.split(os.sep)[0] - dataset_name = os.path.splitext(os.path.basename(file_path))[0] - - # 处理v2版本 - if "v2_" in dataset_name: - model_display_name = f"{model_name}_v2" - dataset_display_name = dataset_name.replace("v2_", "") - else: - model_display_name = model_name - dataset_display_name = dataset_name - - # 生成配置名称 - config_name = f"{model_display_name}: {dataset_display_name}" - - # 如果配置不存在,则添加 - if config_name not in existing_names: - new_configs.append(file_path) - - return new_configs - -# 更新launch.json文件 -def update_launch_json(new_configs_string): - with open(LAUNCH_FILE, 'r') as f: - content = f.read() - - # 找到configurations数组的结束位置 - configs_end_match = re.search(r'\s*\]\s*\}', content) - if not configs_end_match: - return False - - # 插入新的配置 - insert_pos = configs_end_match.start() - new_content = content[:insert_pos] + new_configs_string + content[insert_pos:] - - # 保存文件 - with open(LAUNCH_FILE, 'w') as f: - f.write(new_content) - - return True - -# 主函数 -def main(): - # 查找所有yaml文件 - yaml_files = find_all_yaml_files(CONFIG_DIR) - - # 获取现有配置名称 - existing_names = get_existing_config_names() - - # 生成新的配置,过滤掉已存在的 - new_config_files = generate_new_configs(yaml_files, existing_names) - - if not new_config_files: - print("No new configurations to add") - return - - # 生成新的配置字符串 - new_configs_string = generate_launch_config_string(new_config_files) - - # 更新launch.json文件 - if update_launch_json(new_configs_string): - print(f"Added {len(new_config_files)} new launch configurations") - print(f"Total configurations: {len(existing_names) + len(new_config_files)}") - else: - print("Failed to update launch.json") - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/mypy.ini b/mypy.ini deleted file mode 100644 index c77f418..0000000 --- a/mypy.ini +++ /dev/null @@ -1,4 +0,0 @@ -[mypy] -explicit_package_bases = True -ignore_missing_imports = True -no_site_packages = True diff --git a/run_tests.sh b/run_tests.sh deleted file mode 100755 index ef700a1..0000000 --- a/run_tests.sh +++ /dev/null @@ -1,95 +0,0 @@ -#!/bin/bash - -# 设置默认模型名和数据集列表 -MODEL_NAME="STAEFormer" -DATASETS=( - "METR-LA" - "PEMS-BAY" - "NYCBike-InFlow" - "NYCBike-OutFlow" - "AirQuality" - "SolarEnergy" -) - -# 初始化统计变量 -success_count=0 -failure_count=0 -missing_count=0 -total_count=0 -success_datasets=() -failure_datasets=() -missing_datasets=() - -# 检查是否有参数传入来覆盖默认值 -if [ $# -gt 0 ]; then - MODEL_NAME=$1 - # 如果传入了更多参数,使用它们作为数据集列表 - if [ $# -gt 1 ]; then - DATASETS=(${@:2}) - fi -fi - -echo "使用模型: $MODEL_NAME" -echo "数据集列表: ${DATASETS[*]}" -echo "开始测试..." -echo "" - -# 循环测试每个数据集 -for dataset in "${DATASETS[@]}"; do - total_count=$((total_count + 1)) - # 构建配置文件路径 - CONFIG_PATH="config/${MODEL_NAME}/${dataset}.yaml" - - echo "测试数据集: $dataset" - echo "使用配置文件: $CONFIG_PATH" - - # 检查配置文件是否存在 - if [ ! -f "$CONFIG_PATH" ]; then - echo "错误: 配置文件 $CONFIG_PATH 不存在!" - missing_count=$((missing_count + 1)) - missing_datasets+=("$dataset") - echo "----------------------------------------" - continue - fi - - # 执行测试命令,同时捕获输出并显示在控制台上 - echo "执行: python run.py --config $CONFIG_PATH" - output=$(python run.py --config "$CONFIG_PATH" 2>&1 | tee /dev/tty) - - # 如果没有找到明确的标记,回退到检查退出码 - if [ $? -eq 0 ]; then - echo "数据集 $dataset 测试成功! (基于退出码)" - success_count=$((success_count + 1)) - success_datasets+=("$dataset") - else - echo "数据集 $dataset 测试失败! (基于退出码)" - failure_count=$((failure_count + 1)) - failure_datasets+=("$dataset") - fi - - echo "----------------------------------------" -done - -# 输出总结 -echo "=======================================" -echo "测试总结" -echo "=======================================" -echo "总数据集数量: $total_count" -echo "成功数量: $success_count" -echo "失败数量: $failure_count" -echo "缺失配置文件数量: $missing_count" - -if [ ${#success_datasets[@]} -gt 0 ]; then - echo "成功的数据集: ${success_datasets[*]}" -fi - -if [ ${#failure_datasets[@]} -gt 0 ]; then - echo "失败的数据集: ${failure_datasets[*]}" -fi - -if [ ${#missing_datasets[@]} -gt 0 ]; then - echo "缺失配置的数据集: ${missing_datasets[*]}" -fi - -echo "=======================================" -echo "所有测试完成!" \ No newline at end of file diff --git a/test_results.txt b/test_results.txt deleted file mode 100644 index 6116217..0000000 --- a/test_results.txt +++ /dev/null @@ -1,5406 +0,0 @@ -# 测试报告 - -## 测试概述 -- 测试时间: 2025-12-01 22:20:35 -- 总测试文件数: 252 -- 通过: 41 -- 失败: 0 -- 错误: 211 - -## 通过的配置文件 -- ✅ /user/czzhangheng/code/TrafficWheel/config/STID/NYCBike_Inflow.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/STID/METR-LA.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/STID/AirQuality.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/STID/PEMS-BAY.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/STID/NYCBike_Outflow.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/STID/SolarEnergy.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/STAEFormer/PEMSD4.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/STAEFormer/METR-LA.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/STAEFormer/PEMSD8.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/STAEFormer/PEMSD3.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/STAEFormer/PEMSD7.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/AEPSA/NYCBike-inflow.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/AEPSA/METR-LA.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/AEPSA/AirQuality.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/AEPSA/PEMS-BAY.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/AEPSA/SolarEnergy.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/AEPSA/v2_METR-LA.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/AEPSA/NYCBike-outflow.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/AEPSA/v2_SolarEnergy.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/TCN/NYCBike-InFlow.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/TCN/PEMSD4.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/TCN/METR-LA.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/TCN/NYCBike-OutFlow.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/TCN/PEMSD8.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/TCN/PEMSD3.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/TCN/SolarEnergy.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/TCN/PEMSD7.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/DDGCRN/METR-LA.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/DDGCRN/PEMSD8.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/DSANET/PEMSD4.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/DSANET/METR-LA.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/DSANET/PEMSD8.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/DSANET/PEMSD3.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/DSANET/PEMSD7.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/REPST/NYCBike-inflow.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/REPST/BeijingAirQuality(Deprecated).yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/REPST/METR-LA.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/REPST/AirQuality.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/REPST/PEMS-BAY.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/REPST/SolarEnergy.yaml -- ✅ /user/czzhangheng/code/TrafficWheel/config/REPST/NYCBike-outflow.yaml - -## 失败的配置文件 - -## 出错的配置文件 -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/EXPB/NYCBike-InFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/EXPB/PEMSD4.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/EXPB/METR-LA.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/EXPB/AirQuality.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/EXPB/NYCBike-OutFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/EXPB/SolarEnergy.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/TWDGCN/NYCBike-InFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/TWDGCN/PEMSD4.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/TWDGCN/METR-LA.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/TWDGCN/AirQuality.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/TWDGCN/NYCBike-OutFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/TWDGCN/PEMSD8.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/TWDGCN/PEMSD7(L).yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/TWDGCN/PEMSD3.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/TWDGCN/SolarEnergy.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/TWDGCN/Hainan.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/TWDGCN/PEMSD7.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/TWDGCN/PEMSD7(M).yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STSGCN/NYCBike-InFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STSGCN/PEMSD4.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STSGCN/METR-LA.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STSGCN/AirQuality.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STSGCN/NYCBike-OutFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STSGCN/PEMSD8.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STSGCN/PEMSD3.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STSGCN/SolarEnergy.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STSGCN/PEMSD7.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STID/NYCBike-InFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STID/PEMSD4.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STID/NYCBike-OutFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STAWnet/NYCBike-InFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STAWnet/PEMSD4.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STAWnet/METR-LA.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STAWnet/AirQuality.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STAWnet/NYCBike-OutFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STAWnet/PEMSD8.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STAWnet/PEMSD3.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STAWnet/SolarEnergy.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STAWnet/PEMSD7.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/DCRNN/NYCBike-InFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/DCRNN/PEMSD4.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/DCRNN/METR-LA.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/DCRNN/AirQuality.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/DCRNN/NYCBike-OutFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/DCRNN/PEMSD8.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/DCRNN/PEMSD3.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/DCRNN/SolarEnergy.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/DCRNN/PEMSD7.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STAEFormer/NYCBike-InFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STAEFormer/AirQuality.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STAEFormer/NYCBike-OutFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STAEFormer/SolarEnergy.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGODE/NYCBike-InFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGODE/PEMSD4.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGODE/METR-LA.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGODE/AirQuality.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGODE/NYCBike-OutFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGODE/PEMSD8.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGODE/PEMSD3.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGODE/SolarEnergy.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGODE/PEMSD7.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGNCDE/NYCBike-InFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGNCDE/PEMSD4.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGNCDE/METR-LA.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGNCDE/AirQuality.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGNCDE/NYCBike-OutFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGNCDE/PEMSD8.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGNCDE/PEMSD3.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGNCDE/SolarEnergy.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGNCDE/PEMSD7.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/AEPSA/NYCBike-InFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/AEPSA/NYCBike-OutFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/ST_SSL/NYCBike-InFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/ST_SSL/PEMSD4.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/ST_SSL/METR-LA.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/ST_SSL/AirQuality.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/ST_SSL/NYCBike-OutFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/ST_SSL/PEMSD8.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/ST_SSL/PEMSD3.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/ST_SSL/SolarEnergy.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/ST_SSL/PEMSD7.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/TCN/AirQuality.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/EXP/NYCBike-InFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/EXP/PEMSD4.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/EXP/METR-LA.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/EXP/AirQuality.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/EXP/SD.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/EXP/NYCBike-OutFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/EXP/PEMSD8.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/EXP/PEMSD3.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/EXP/SolarEnergy.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/EXP/PEMSD7.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/DDGCRN/NYCBike-InFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/DDGCRN/PEMSD4.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/DDGCRN/AirQuality.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/DDGCRN/NYCBike-OutFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/DDGCRN/PEMSD7(L).yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/DDGCRN/PEMSD3.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/DDGCRN/SolarEnergy.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/DDGCRN/Hainan.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/DDGCRN/PEMSD7.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/DDGCRN/PEMSD7(M).yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/DSANET/NYCBike-InFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/DSANET/AirQuality.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/DSANET/NYCBike-OutFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/DSANET/SolarEnergy.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STFGNN/NYCBike-InFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STFGNN/PEMSD4.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STFGNN/METR-LA.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STFGNN/AirQuality.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STFGNN/NYCBike-OutFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STFGNN/PEMSD8.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STFGNN/PEMSD3.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STFGNN/SolarEnergy.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STFGNN/PEMSD7.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/AGCRN/NYCBike-InFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/AGCRN/PEMSD4.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/AGCRN/METR-LA.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/AGCRN/AirQuality.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/AGCRN/NYCBike-OutFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/AGCRN/PEMSD8.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/AGCRN/PEMSD3.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/AGCRN/SolarEnergy.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/AGCRN/PEMSD7.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGNRDE/NYCBike-InFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGNRDE/PEMSD4.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGNRDE/METR-LA.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGNRDE/AirQuality.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGNRDE/NYCBike-OutFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGNRDE/PEMSD8.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGNRDE/PEMSD3.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGNRDE/SolarEnergy.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGNRDE/PEMSD7.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/REPST/PEMS-BAY_paper.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/REPST/NYCBike-InFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/REPST/NYCBike-OutFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STIDGCN/NYCBike-InFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STIDGCN/PEMSD4.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STIDGCN/METR-LA.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STIDGCN/AirQuality.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STIDGCN/NYCBike-OutFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STIDGCN/PEMSD8.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STIDGCN/PEMSD3.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STIDGCN/SolarEnergy.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STIDGCN/PEMSD7.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/PDG2SEQ/NYCBike-InFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/PDG2SEQ/PEMSD4.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/PDG2SEQ/METR-LA.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/PDG2SEQ/AirQuality.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/PDG2SEQ/NYCBike-OutFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/PDG2SEQ/PEMSD8.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/PDG2SEQ/PEMSD3.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/PDG2SEQ/SolarEnergy.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/PDG2SEQ/PEMSD7.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/NLT/NYCBike-InFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/NLT/PEMSD4.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/NLT/METR-LA.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/NLT/AirQuality.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/NLT/NYCBike-OutFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/NLT/PEMSD8.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/NLT/PEMSD3.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/NLT/SolarEnergy.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/NLT/PEMSD7.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/ARIMA/NYCBike-InFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/ARIMA/PEMSD4.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/ARIMA/METR-LA.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/ARIMA/AirQuality.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/ARIMA/NYCBike-OutFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/ARIMA/PEMSD8.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/ARIMA/PEMSD7(L).yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/ARIMA/PEMSD3.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/ARIMA/SolarEnergy.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/ARIMA/Hainan.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/ARIMA/PEMSD7.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/ARIMA/PEMSD7(M).yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STMLP/NYCBike-InFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STMLP/PEMSD4.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STMLP/METR-LA.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STMLP/AirQuality.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STMLP/NYCBike-OutFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STMLP/PEMSD8.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STMLP/PEMSD3.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STMLP/SolarEnergy.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STMLP/PEMSD7.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/MegaCRN/NYCBike-InFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/MegaCRN/PEMSD4.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/MegaCRN/METR-LA.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/MegaCRN/AirQuality.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/MegaCRN/NYCBike-OutFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/MegaCRN/PEMSD8.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/MegaCRN/PEMSD3.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/MegaCRN/SolarEnergy.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/MegaCRN/PEMSD7.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/GWN/NYCBike-InFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/GWN/PEMSD4.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/GWN/METR-LA.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/GWN/AirQuality.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/GWN/NYCBike-OutFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/GWN/PEMSD8.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/GWN/PEMSD3.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/GWN/SolarEnergy.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/GWN/PEMSD7.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGCN/NYCBike-InFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGCN/PEMSD4.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGCN/METR-LA.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGCN/AirQuality.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGCN/NYCBike-OutFlow.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGCN/PEMSD8.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGCN/PEMSD3.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGCN/SolarEnergy.yaml -- ⚠️ /user/czzhangheng/code/TrafficWheel/config/STGCN/PEMSD7.yaml - -## 详细输出 - -### PASSED - -#### /user/czzhangheng/code/TrafficWheel/config/STID/NYCBike_Inflow.yaml - -``` -模型参数量: 118040 -加载 NYCBike-InFlow 数据集中... -✓ Test passed: output shape torch.Size([64, 24, 128, 1]) matches label shape torch.Size([64, 24, 128, 1]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STID/METR-LA.yaml - -``` -模型参数量: 120568 -加载 METR-LA 数据集中... -✓ Test passed: output shape torch.Size([64, 24, 207, 1]) matches label shape torch.Size([64, 24, 207, 1]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STID/AirQuality.yaml - -``` -模型参数量: 115064 -加载 AirQuality 数据集中... -✓ Test passed: output shape torch.Size([64, 24, 35, 1]) matches label shape torch.Size([64, 24, 35, 1]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STID/PEMS-BAY.yaml - -``` -模型参数量: 124344 -加载 PEMS-BAY 数据集中... -✓ Test passed: output shape torch.Size([64, 24, 325, 1]) matches label shape torch.Size([64, 24, 325, 1]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STID/NYCBike_Outflow.yaml - -``` -模型参数量: 118040 -加载 NYCBike-OutFlow 数据集中... -✓ Test passed: output shape torch.Size([64, 24, 128, 1]) matches label shape torch.Size([64, 24, 128, 1]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STID/SolarEnergy.yaml - -``` -模型参数量: 118328 -加载 SolarEnergy 数据集中... -✓ Test passed: output shape torch.Size([64, 24, 137, 1]) matches label shape torch.Size([64, 24, 137, 1]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STAEFormer/PEMSD4.yaml - -``` -模型参数量: 1354932 -加载 PEMSD4 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/STAEFormer/PEMSD4/2025-12-01_21-52-10/run.log -✓ Test passed: output shape torch.Size([64, 12, 307, 1]) matches label shape torch.Size([64, 12, 307, 1]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STAEFormer/METR-LA.yaml - -``` -模型参数量: 1258932 -加载 METR-LA 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/STAEFormer/METR-LA/2025-12-01_21-52-24/run.log -✓ Test passed: output shape torch.Size([16, 12, 207, 1]) matches label shape torch.Size([16, 12, 207, 1]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STAEFormer/PEMSD8.yaml - -``` -模型参数量: 1223412 -加载 PEMSD8 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/STAEFormer/PEMSD8/2025-12-01_21-52-49/run.log -✓ Test passed: output shape torch.Size([64, 12, 170, 1]) matches label shape torch.Size([64, 12, 170, 1]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STAEFormer/PEMSD3.yaml - -``` -模型参数量: 1403892 -加载 PEMSD3 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/STAEFormer/PEMSD3/2025-12-01_21-53-06/run.log -✓ Test passed: output shape torch.Size([16, 12, 358, 1]) matches label shape torch.Size([16, 12, 358, 1]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STAEFormer/PEMSD7.yaml - -``` -模型参数量: 1907892 -加载 PEMSD7 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/STAEFormer/PEMSD7/2025-12-01_21-54-04/run.log -✓ Test passed: output shape torch.Size([16, 12, 883, 1]) matches label shape torch.Size([16, 12, 883, 1]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/AEPSA/NYCBike-inflow.yaml - -``` -模型参数量: 103504579 -加载 NYCBike-InFlow 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/AEPSA/NYCBike-InFlow/2025-12-01_21-55-58/run.log -✓ Test passed: output shape torch.Size([32, 24, 128, 1]) matches label shape torch.Size([32, 24, 128, 1]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/AEPSA/METR-LA.yaml - -``` -模型参数量: 103505369 -加载 METR-LA 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/AEPSA/METR-LA/2025-12-01_21-56-29/run.log -✓ Test passed: output shape torch.Size([16, 24, 207, 1]) matches label shape torch.Size([16, 24, 207, 1]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/AEPSA/AirQuality.yaml - -``` -模型参数量: 103503669 -加载 AirQuality 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/AEPSA/AirQuality/2025-12-01_21-56-40/run.log -✓ Test passed: output shape torch.Size([16, 24, 35, 6]) matches label shape torch.Size([16, 24, 35, 6]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/AEPSA/PEMS-BAY.yaml - -``` -模型参数量: 103506549 -加载 PEMS-BAY 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/AEPSA/PEMS-BAY/2025-12-01_21-57-30/run.log -✓ Test passed: output shape torch.Size([16, 24, 325, 1]) matches label shape torch.Size([16, 24, 325, 1]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/AEPSA/SolarEnergy.yaml - -``` -模型参数量: 103504669 -加载 SolarEnergy 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/AEPSA/SolarEnergy/2025-12-01_21-57-55/run.log -✓ Test passed: output shape torch.Size([64, 24, 137, 1]) matches label shape torch.Size([64, 24, 137, 1]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/AEPSA/v2_METR-LA.yaml - -``` -模型参数量: 103524820 -加载 METR-LA 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/AEPSA_v2/METR-LA/2025-12-01_21-58-18/run.log -✓ Test passed: output shape torch.Size([16, 24, 207, 1]) matches label shape torch.Size([16, 24, 207, 1]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/AEPSA/NYCBike-outflow.yaml - -``` -模型参数量: 103504579 -加载 NYCBike-OutFlow 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/AEPSA/NYCBike-OutFlow/2025-12-01_21-58-29/run.log -✓ Test passed: output shape torch.Size([32, 24, 128, 1]) matches label shape torch.Size([32, 24, 128, 1]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/AEPSA/v2_SolarEnergy.yaml - -``` -模型参数量: 103524120 -加载 SolarEnergy 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/AEPSA_v2/SolarEnergy/2025-12-01_21-58-54/run.log -✓ Test passed: output shape torch.Size([64, 24, 137, 1]) matches label shape torch.Size([64, 24, 137, 1]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/TCN/NYCBike-InFlow.yaml - -``` -模型参数量: 35873 -加载 NYCBike-InFlow 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/TCN/NYCBike-InFlow/2025-12-01_21-59-55/run.log -✓ Test passed: output shape torch.Size([32, 24, 128, 1]) matches label shape torch.Size([32, 24, 128, 1]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/TCN/PEMSD4.yaml - -``` -模型参数量: 35873 -加载 PEMSD4 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/TCN/PEMSD4/2025-12-01_22-00-07/run.log -✓ Test passed: output shape torch.Size([64, 12, 307, 1]) matches label shape torch.Size([64, 12, 307, 1]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/TCN/METR-LA.yaml - -``` -模型参数量: 35873 -加载 METR-LA 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/TCN/METR-LA/2025-12-01_22-00-28/run.log -✓ Test passed: output shape torch.Size([16, 24, 207, 1]) matches label shape torch.Size([16, 24, 207, 1]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/TCN/NYCBike-OutFlow.yaml - -``` -模型参数量: 35873 -加载 NYCBike-OutFlow 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/TCN/NYCBike-OutFlow/2025-12-01_22-00-44/run.log -✓ Test passed: output shape torch.Size([32, 24, 128, 1]) matches label shape torch.Size([32, 24, 128, 1]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/TCN/PEMSD8.yaml - -``` -模型参数量: 35873 -加载 PEMSD8 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/TCN/PEMSD8/2025-12-01_22-00-54/run.log -✓ Test passed: output shape torch.Size([64, 12, 170, 1]) matches label shape torch.Size([64, 12, 170, 1]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/TCN/PEMSD3.yaml - -``` -模型参数量: 35873 -加载 PEMSD3 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/TCN/PEMSD3/2025-12-01_22-01-10/run.log -✓ Test passed: output shape torch.Size([16, 12, 358, 1]) matches label shape torch.Size([16, 12, 358, 1]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/TCN/SolarEnergy.yaml - -``` -模型参数量: 35873 -加载 SolarEnergy 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/TCN/SolarEnergy/2025-12-01_22-01-33/run.log -✓ Test passed: output shape torch.Size([64, 24, 137, 1]) matches label shape torch.Size([64, 24, 137, 1]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/TCN/PEMSD7.yaml - -``` -模型参数量: 35873 -加载 PEMSD7 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/TCN/PEMSD7/2025-12-01_22-02-05/run.log -✓ Test passed: output shape torch.Size([16, 12, 883, 1]) matches label shape torch.Size([16, 12, 883, 1]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/DDGCRN/METR-LA.yaml - -``` -模型参数量: 671644 -加载 METR-LA 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/DDGCRN/METR-LA/2025-12-01_22-03-35/run.log -✓ Test passed: output shape torch.Size([64, 24, 207, 1]) matches label shape torch.Size([64, 24, 207, 1]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/DDGCRN/PEMSD8.yaml - -``` -模型参数量: 311759 -加载 PEMSD8 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/DDGCRN/PEMSD8/2025-12-01_22-03-57/run.log -✓ Test passed: output shape torch.Size([64, 12, 170, 1]) matches label shape torch.Size([64, 12, 170, 1]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/DSANET/PEMSD4.yaml - -``` -模型参数量: 37896712 -加载 PEMSD4 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/DSANET/PEMSD4/2025-12-01_22-04-52/run.log -✓ Test passed: output shape torch.Size([64, 12, 307, 1]) matches label shape torch.Size([64, 12, 307, 1]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/DSANET/METR-LA.yaml - -``` -模型参数量: 37896712 -加载 METR-LA 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/DSANET/METR-LA/2025-12-01_22-05-06/run.log -✓ Test passed: output shape torch.Size([16, 12, 207, 1]) matches label shape torch.Size([16, 12, 207, 1]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/DSANET/PEMSD8.yaml - -``` -模型参数量: 37896712 -加载 PEMSD8 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/DSANET/PEMSD8/2025-12-01_22-05-33/run.log -✓ Test passed: output shape torch.Size([64, 12, 170, 1]) matches label shape torch.Size([64, 12, 170, 1]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/DSANET/PEMSD3.yaml - -``` -模型参数量: 37896712 -加载 PEMSD3 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/DSANET/PEMSD3/2025-12-01_22-05-49/run.log -✓ Test passed: output shape torch.Size([16, 12, 358, 1]) matches label shape torch.Size([16, 12, 358, 1]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/DSANET/PEMSD7.yaml - -``` -模型参数量: 615304 -加载 PEMSD7 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/DSANET/PEMSD7/2025-12-01_22-06-48/run.log -✓ Test passed: output shape torch.Size([16, 12, 883, 1]) matches label shape torch.Size([16, 12, 883, 1]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/REPST/NYCBike-inflow.yaml - -``` -模型参数量: 103481647 -加载 NYCBike-InFlow 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/REPST/NYCBike-InFlow/2025-12-01_22-09-34/run.log -✓ Test passed: output shape torch.Size([16, 24, 128, 1]) matches label shape torch.Size([16, 24, 128, 1]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/REPST/BeijingAirQuality(Deprecated).yaml - -``` -模型参数量: 103815937 -加载 BeijingAirQuality 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/REPST/BeijingAirQuality/2025-12-01_22-09-59/run.log -✓ Test passed: output shape torch.Size([16, 24, 7, 3]) matches label shape torch.Size([16, 24, 7, 3]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/REPST/METR-LA.yaml - -``` -模型参数量: 103481647 -加载 METR-LA 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/REPST/METR-LA/2025-12-01_22-10-22/run.log -✓ Test passed: output shape torch.Size([16, 24, 207, 1]) matches label shape torch.Size([16, 24, 207, 1]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/REPST/AirQuality.yaml - -``` -模型参数量: 103815973 -加载 AirQuality 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/REPST/AirQuality/2025-12-01_22-10-33/run.log -✓ Test passed: output shape torch.Size([16, 24, 35, 3]) matches label shape torch.Size([16, 24, 35, 3]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/REPST/PEMS-BAY.yaml - -``` -模型参数量: 103481647 -加载 PEMS-BAY 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/REPST/PEMS-BAY/2025-12-01_22-11-23/run.log -✓ Test passed: output shape torch.Size([16, 24, 325, 1]) matches label shape torch.Size([16, 24, 325, 1]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/REPST/SolarEnergy.yaml - -``` -模型参数量: 103481647 -加载 SolarEnergy 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/REPST/SolarEnergy/2025-12-01_22-11-48/run.log -✓ Test passed: output shape torch.Size([64, 24, 137, 1]) matches label shape torch.Size([64, 24, 137, 1]) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/REPST/NYCBike-outflow.yaml - -``` -模型参数量: 103481647 -加载 NYCBike-OutFlow 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/REPST/NYCBike-OutFlow/2025-12-01_22-11-58/run.log -✓ Test passed: output shape torch.Size([16, 24, 128, 1]) matches label shape torch.Size([16, 24, 128, 1]) -``` - - -### FAILED - - -### ERROR - -#### /user/czzhangheng/code/TrafficWheel/config/EXPB/NYCBike-InFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^^^^ -AttributeError: 'NoneType' object has no attribute 'to' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/EXPB/PEMSD4.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^^^^ -AttributeError: 'NoneType' object has no attribute 'to' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/EXPB/METR-LA.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^^^^ -AttributeError: 'NoneType' object has no attribute 'to' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/EXPB/AirQuality.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^^^^ -AttributeError: 'NoneType' object has no attribute 'to' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/EXPB/NYCBike-OutFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^^^^ -AttributeError: 'NoneType' object has no attribute 'to' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/EXPB/SolarEnergy.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^^^^ -AttributeError: 'NoneType' object has no attribute 'to' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STSGCN/NYCBike-InFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 60, in model_selector - return STSGCN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STSGCN/STSGCN.py", line 295, in __init__ - self.adj = get_adj(args) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STSGCN/get_adj.py", line 10, in get_adj - match args["num_nodes"]: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STSGCN/PEMSD4.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 60, in model_selector - return STSGCN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STSGCN/STSGCN.py", line 295, in __init__ - self.adj = get_adj(args) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STSGCN/get_adj.py", line 10, in get_adj - match args["num_nodes"]: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STSGCN/METR-LA.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 60, in model_selector - return STSGCN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STSGCN/STSGCN.py", line 295, in __init__ - self.adj = get_adj(args) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STSGCN/get_adj.py", line 10, in get_adj - match args["num_nodes"]: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STSGCN/AirQuality.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 60, in model_selector - return STSGCN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STSGCN/STSGCN.py", line 295, in __init__ - self.adj = get_adj(args) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STSGCN/get_adj.py", line 10, in get_adj - match args["num_nodes"]: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STSGCN/NYCBike-OutFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 60, in model_selector - return STSGCN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STSGCN/STSGCN.py", line 295, in __init__ - self.adj = get_adj(args) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STSGCN/get_adj.py", line 10, in get_adj - match args["num_nodes"]: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STSGCN/PEMSD8.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 60, in model_selector - return STSGCN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STSGCN/STSGCN.py", line 295, in __init__ - self.adj = get_adj(args) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STSGCN/get_adj.py", line 10, in get_adj - match args["num_nodes"]: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STSGCN/PEMSD3.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 60, in model_selector - return STSGCN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STSGCN/STSGCN.py", line 295, in __init__ - self.adj = get_adj(args) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STSGCN/get_adj.py", line 10, in get_adj - match args["num_nodes"]: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STSGCN/SolarEnergy.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 60, in model_selector - return STSGCN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STSGCN/STSGCN.py", line 295, in __init__ - self.adj = get_adj(args) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STSGCN/get_adj.py", line 10, in get_adj - match args["num_nodes"]: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STSGCN/PEMSD7.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 60, in model_selector - return STSGCN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STSGCN/STSGCN.py", line 295, in __init__ - self.adj = get_adj(args) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STSGCN/get_adj.py", line 10, in get_adj - match args["num_nodes"]: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STID/NYCBike-InFlow.yaml - -``` -模型参数量: 146712 -加载 NYCBike-InFlow 数据集中... -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 22, in main - train_loader, val_loader, test_loader, scaler, *extra_data = get_dataloader( - ^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/loader_selector.py", line 20, in get_dataloader - return normal_loader(config, normalizer, single) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 14, in get_dataloader - x, y = _prepare_data_with_windows(data, args, single) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 43, in _prepare_data_with_windows - x = _add_time_features(x, time_features, args["lag"], args["horizon"], single, add_window_x) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 64, in _add_time_features - return np.concatenate([data, time_day, time_week], axis=-1) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -ValueError: all the input array dimensions except for the concatenation axis must match exactly, but along dimension 2, the array at index 0 has size 128 and the array at index 1 has size 1024 -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STID/PEMSD4.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 70, in model_selector - return STID(model_config) - ^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STID/STID.py", line 13, in __init__ - self.embed_dim = model_args["embed_dim"] - ~~~~~~~~~~^^^^^^^^^^^^^ -KeyError: 'embed_dim' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STID/NYCBike-OutFlow.yaml - -``` -模型参数量: 146712 -加载 NYCBike-OutFlow 数据集中... -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 22, in main - train_loader, val_loader, test_loader, scaler, *extra_data = get_dataloader( - ^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/loader_selector.py", line 20, in get_dataloader - return normal_loader(config, normalizer, single) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 14, in get_dataloader - x, y = _prepare_data_with_windows(data, args, single) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 43, in _prepare_data_with_windows - x = _add_time_features(x, time_features, args["lag"], args["horizon"], single, add_window_x) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 64, in _add_time_features - return np.concatenate([data, time_day, time_week], axis=-1) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -ValueError: all the input array dimensions except for the concatenation axis must match exactly, but along dimension 2, the array at index 0 has size 128 and the array at index 1 has size 1024 -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STAWnet/NYCBike-InFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 82, in model_selector - return STAWnet(model_config) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STAWnet/STAWnet.py", line 269, in __init__ - num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STAWnet/PEMSD4.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 82, in model_selector - return STAWnet(model_config) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STAWnet/STAWnet.py", line 269, in __init__ - num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STAWnet/METR-LA.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 82, in model_selector - return STAWnet(model_config) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STAWnet/STAWnet.py", line 269, in __init__ - num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STAWnet/AirQuality.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 82, in model_selector - return STAWnet(model_config) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STAWnet/STAWnet.py", line 269, in __init__ - num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STAWnet/NYCBike-OutFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 82, in model_selector - return STAWnet(model_config) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STAWnet/STAWnet.py", line 269, in __init__ - num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STAWnet/PEMSD8.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 82, in model_selector - return STAWnet(model_config) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STAWnet/STAWnet.py", line 269, in __init__ - num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STAWnet/PEMSD3.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 82, in model_selector - return STAWnet(model_config) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STAWnet/STAWnet.py", line 269, in __init__ - num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STAWnet/SolarEnergy.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 82, in model_selector - return STAWnet(model_config) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STAWnet/STAWnet.py", line 269, in __init__ - num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STAWnet/PEMSD7.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 82, in model_selector - return STAWnet(model_config) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STAWnet/STAWnet.py", line 269, in __init__ - num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/DCRNN/NYCBike-InFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 50, in model_selector - return DCRNNModel(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/DCRNN/dcrnn_model.py", line 123, in __init__ - adj_mx = get_adj(args) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj - match args['num_nodes']: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/DCRNN/PEMSD4.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 50, in model_selector - return DCRNNModel(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/DCRNN/dcrnn_model.py", line 123, in __init__ - adj_mx = get_adj(args) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj - match args['num_nodes']: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/DCRNN/METR-LA.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 50, in model_selector - return DCRNNModel(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/DCRNN/dcrnn_model.py", line 123, in __init__ - adj_mx = get_adj(args) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj - match args['num_nodes']: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/DCRNN/AirQuality.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 50, in model_selector - return DCRNNModel(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/DCRNN/dcrnn_model.py", line 123, in __init__ - adj_mx = get_adj(args) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj - match args['num_nodes']: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/DCRNN/NYCBike-OutFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 50, in model_selector - return DCRNNModel(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/DCRNN/dcrnn_model.py", line 123, in __init__ - adj_mx = get_adj(args) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj - match args['num_nodes']: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/DCRNN/PEMSD8.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 50, in model_selector - return DCRNNModel(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/DCRNN/dcrnn_model.py", line 123, in __init__ - adj_mx = get_adj(args) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj - match args['num_nodes']: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/DCRNN/PEMSD3.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 50, in model_selector - return DCRNNModel(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/DCRNN/dcrnn_model.py", line 123, in __init__ - adj_mx = get_adj(args) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj - match args['num_nodes']: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/DCRNN/SolarEnergy.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 50, in model_selector - return DCRNNModel(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/DCRNN/dcrnn_model.py", line 123, in __init__ - adj_mx = get_adj(args) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 29, in get_adj - return adj - ^^^ -UnboundLocalError: cannot access local variable 'adj' where it is not associated with a value -``` - -#### /user/czzhangheng/code/TrafficWheel/config/DCRNN/PEMSD7.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 50, in model_selector - return DCRNNModel(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/DCRNN/dcrnn_model.py", line 123, in __init__ - adj_mx = get_adj(args) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj - match args['num_nodes']: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STAEFormer/NYCBike-InFlow.yaml - -``` -模型参数量: 3086208 -加载 NYCBike-InFlow 数据集中... -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 22, in main - train_loader, val_loader, test_loader, scaler, *extra_data = get_dataloader( - ^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/loader_selector.py", line 20, in get_dataloader - return normal_loader(config, normalizer, single) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 14, in get_dataloader - x, y = _prepare_data_with_windows(data, args, single) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 43, in _prepare_data_with_windows - x = _add_time_features(x, time_features, args["lag"], args["horizon"], single, add_window_x) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 64, in _add_time_features - return np.concatenate([data, time_day, time_week], axis=-1) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -ValueError: all the input array dimensions except for the concatenation axis must match exactly, but along dimension 2, the array at index 0 has size 128 and the array at index 1 has size 1024 -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STAEFormer/AirQuality.yaml - -``` -模型参数量: 1624752 -加载 AirQuality 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/STAEFormer/AirQuality/2025-12-01_21-52-33/run.log -2025/12/01 21:52:33 - Experiment log path in: /user/czzhangheng/code/TrafficWheel/experiments/STAEFormer/AirQuality/2025-12-01_21-52-33 -2025/12/01 21:52:33 - Training process started - -Train Epoch 1: 0%| | 0/325 [00:00 - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 46, in main - trainer.train() - File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 266, in train - train_epoch_loss = self.train_epoch(epoch) - ^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 244, in train_epoch - return self._run_epoch(epoch, self.train_loader, "train") - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 179, in _run_epoch - output = self.model(data).to(self.device) - ^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl - return self._call_impl(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl - return forward_call(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STAEFormer/STAEFormer.py", line 195, in forward - x = self.input_proj(x) # (batch_size, in_steps, num_nodes, input_embedding_dim) - ^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl - return self._call_impl(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl - return forward_call(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/linear.py", line 125, in forward - return F.linear(input, self.weight, self.bias) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -RuntimeError: mat1 and mat2 shapes cannot be multiplied (13440x1 and 6x24) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STAEFormer/NYCBike-OutFlow.yaml - -``` -模型参数量: 3086208 -加载 NYCBike-OutFlow 数据集中... -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 22, in main - train_loader, val_loader, test_loader, scaler, *extra_data = get_dataloader( - ^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/loader_selector.py", line 20, in get_dataloader - return normal_loader(config, normalizer, single) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 14, in get_dataloader - x, y = _prepare_data_with_windows(data, args, single) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 43, in _prepare_data_with_windows - x = _add_time_features(x, time_features, args["lag"], args["horizon"], single, add_window_x) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 64, in _add_time_features - return np.concatenate([data, time_day, time_week], axis=-1) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -ValueError: all the input array dimensions except for the concatenation axis must match exactly, but along dimension 2, the array at index 0 has size 128 and the array at index 1 has size 1024 -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STAEFormer/SolarEnergy.yaml - -``` -模型参数量: 13296192 -加载 SolarEnergy 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/STAEFormer/SolarEnergy/2025-12-01_21-53-32/run.log -2025/12/01 21:53:32 - Experiment log path in: /user/czzhangheng/code/TrafficWheel/experiments/STAEFormer/SolarEnergy/2025-12-01_21-53-32 -2025/12/01 21:53:32 - Training process started - -Train Epoch 1: 0%| | 0/1970 [00:00 - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 46, in main - trainer.train() - File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 266, in train - train_epoch_loss = self.train_epoch(epoch) - ^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 244, in train_epoch - return self._run_epoch(epoch, self.train_loader, "train") - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 179, in _run_epoch - output = self.model(data).to(self.device) - ^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl - return self._call_impl(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl - return forward_call(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STAEFormer/STAEFormer.py", line 195, in forward - x = self.input_proj(x) # (batch_size, in_steps, num_nodes, input_embedding_dim) - ^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl - return self._call_impl(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl - return forward_call(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/linear.py", line 125, in forward - return F.linear(input, self.weight, self.bias) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -RuntimeError: mat1 and mat2 shapes cannot be multiplied (52608x1 and 137x24) -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STGODE/NYCBike-InFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 62, in model_selector - return ODEGCN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STGODE/STGODE.py", line 149, in __init__ - num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STGODE/PEMSD4.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 62, in model_selector - return ODEGCN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STGODE/STGODE.py", line 149, in __init__ - num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STGODE/METR-LA.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 62, in model_selector - return ODEGCN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STGODE/STGODE.py", line 149, in __init__ - num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STGODE/AirQuality.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 62, in model_selector - return ODEGCN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STGODE/STGODE.py", line 149, in __init__ - num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STGODE/NYCBike-OutFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 62, in model_selector - return ODEGCN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STGODE/STGODE.py", line 149, in __init__ - num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STGODE/PEMSD8.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 62, in model_selector - return ODEGCN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STGODE/STGODE.py", line 149, in __init__ - num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STGODE/PEMSD3.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 62, in model_selector - return ODEGCN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STGODE/STGODE.py", line 149, in __init__ - num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STGODE/SolarEnergy.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 62, in model_selector - return ODEGCN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STGODE/STGODE.py", line 149, in __init__ - num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STGODE/PEMSD7.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 62, in model_selector - return ODEGCN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STGODE/STGODE.py", line 149, in __init__ - num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STGNCDE/NYCBike-InFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 44, in model_selector - return make_model(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STGNCDE/Make_model.py", line 16, in make_model - num_nodes=args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STGNCDE/PEMSD4.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 44, in model_selector - return make_model(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STGNCDE/Make_model.py", line 16, in make_model - num_nodes=args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STGNCDE/METR-LA.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 44, in model_selector - return make_model(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STGNCDE/Make_model.py", line 16, in make_model - num_nodes=args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STGNCDE/AirQuality.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 44, in model_selector - return make_model(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STGNCDE/Make_model.py", line 16, in make_model - num_nodes=args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STGNCDE/NYCBike-OutFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 44, in model_selector - return make_model(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STGNCDE/Make_model.py", line 16, in make_model - num_nodes=args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STGNCDE/PEMSD8.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 44, in model_selector - return make_model(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STGNCDE/Make_model.py", line 16, in make_model - num_nodes=args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STGNCDE/PEMSD3.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 44, in model_selector - return make_model(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STGNCDE/Make_model.py", line 16, in make_model - num_nodes=args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STGNCDE/SolarEnergy.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 44, in model_selector - return make_model(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STGNCDE/Make_model.py", line 16, in make_model - num_nodes=args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STGNCDE/PEMSD7.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 44, in model_selector - return make_model(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STGNCDE/Make_model.py", line 16, in make_model - num_nodes=args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/AEPSA/NYCBike-InFlow.yaml - -``` -模型参数量: 103513539 -加载 NYCBike-InFlow 数据集中... -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 22, in main - train_loader, val_loader, test_loader, scaler, *extra_data = get_dataloader( - ^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/loader_selector.py", line 20, in get_dataloader - return normal_loader(config, normalizer, single) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 14, in get_dataloader - x, y = _prepare_data_with_windows(data, args, single) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 43, in _prepare_data_with_windows - x = _add_time_features(x, time_features, args["lag"], args["horizon"], single, add_window_x) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 64, in _add_time_features - return np.concatenate([data, time_day, time_week], axis=-1) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -ValueError: all the input array dimensions except for the concatenation axis must match exactly, but along dimension 2, the array at index 0 has size 128 and the array at index 1 has size 1024 -``` - -#### /user/czzhangheng/code/TrafficWheel/config/AEPSA/NYCBike-OutFlow.yaml - -``` -模型参数量: 103513539 -加载 NYCBike-OutFlow 数据集中... -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 22, in main - train_loader, val_loader, test_loader, scaler, *extra_data = get_dataloader( - ^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/loader_selector.py", line 20, in get_dataloader - return normal_loader(config, normalizer, single) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 14, in get_dataloader - x, y = _prepare_data_with_windows(data, args, single) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 43, in _prepare_data_with_windows - x = _add_time_features(x, time_features, args["lag"], args["horizon"], single, add_window_x) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 64, in _add_time_features - return np.concatenate([data, time_day, time_week], axis=-1) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -ValueError: all the input array dimensions except for the concatenation axis must match exactly, but along dimension 2, the array at index 0 has size 128 and the array at index 1 has size 1024 -``` - -#### /user/czzhangheng/code/TrafficWheel/config/ST_SSL/NYCBike-InFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 78, in model_selector - return STSSLModel(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/ST_SSL/ST_SSL.py", line 39, in __init__ - self.num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/ST_SSL/PEMSD4.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 78, in model_selector - return STSSLModel(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/ST_SSL/ST_SSL.py", line 39, in __init__ - self.num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/ST_SSL/METR-LA.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 78, in model_selector - return STSSLModel(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/ST_SSL/ST_SSL.py", line 39, in __init__ - self.num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/ST_SSL/AirQuality.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 78, in model_selector - return STSSLModel(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/ST_SSL/ST_SSL.py", line 39, in __init__ - self.num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/ST_SSL/NYCBike-OutFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 78, in model_selector - return STSSLModel(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/ST_SSL/ST_SSL.py", line 39, in __init__ - self.num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/ST_SSL/PEMSD8.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 78, in model_selector - return STSSLModel(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/ST_SSL/ST_SSL.py", line 39, in __init__ - self.num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/ST_SSL/PEMSD3.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 78, in model_selector - return STSSLModel(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/ST_SSL/ST_SSL.py", line 39, in __init__ - self.num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/ST_SSL/SolarEnergy.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 78, in model_selector - return STSSLModel(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/ST_SSL/ST_SSL.py", line 39, in __init__ - self.num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/ST_SSL/PEMSD7.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 78, in model_selector - return STSSLModel(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/ST_SSL/ST_SSL.py", line 39, in __init__ - self.num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/TCN/AirQuality.yaml - -``` -模型参数量: 36678 -加载 AirQuality 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/TCN/AirQuality/2025-12-01_22-00-37/run.log -2025/12/01 22:00:37 - Experiment log path in: /user/czzhangheng/code/TrafficWheel/experiments/TCN/AirQuality/2025-12-01_22-00-37 -2025/12/01 22:00:37 - Training process started - -Train Epoch 1: 0%| | 0/325 [00:00 - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 46, in main - trainer.train() - File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 266, in train - train_epoch_loss = self.train_epoch(epoch) - ^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 244, in train_epoch - return self._run_epoch(epoch, self.train_loader, "train") - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 179, in _run_epoch - output = self.model(data).to(self.device) - ^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl - return self._call_impl(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl - return forward_call(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/TCN/TCN.py", line 43, in forward - x = self.network(x) - ^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl - return self._call_impl(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl - return forward_call(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/container.py", line 250, in forward - input = module(input) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl - return self._call_impl(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl - return forward_call(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/TCN/TCN.py", line 89, in forward - res = x if self.downsample is None else self.downsample(x) - ^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl - return self._call_impl(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl - return forward_call(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/conv.py", line 554, in forward - return self._conv_forward(input, self.weight, self.bias) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/conv.py", line 549, in _conv_forward - return F.conv2d( - ^^^^^^^^^ -RuntimeError: Given groups=1, weight of size [32, 6, 1, 1], expected input[16, 1, 35, 24] to have 6 channels, but got 1 channels instead -``` - -#### /user/czzhangheng/code/TrafficWheel/config/EXP/NYCBike-InFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 74, in model_selector - return EXP(model_config) - ^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/EXP/EXP32.py", line 114, in __init__ - self.horizon = args["horizon"] # 预测步长 - ~~~~^^^^^^^^^^^ -KeyError: 'horizon' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/EXP/PEMSD4.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 74, in model_selector - return EXP(model_config) - ^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/EXP/EXP32.py", line 114, in __init__ - self.horizon = args["horizon"] # 预测步长 - ~~~~^^^^^^^^^^^ -KeyError: 'horizon' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/EXP/METR-LA.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 74, in model_selector - return EXP(model_config) - ^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/EXP/EXP32.py", line 114, in __init__ - self.horizon = args["horizon"] # 预测步长 - ~~~~^^^^^^^^^^^ -KeyError: 'horizon' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/EXP/AirQuality.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 74, in model_selector - return EXP(model_config) - ^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/EXP/EXP32.py", line 114, in __init__ - self.horizon = args["horizon"] # 预测步长 - ~~~~^^^^^^^^^^^ -KeyError: 'horizon' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/EXP/SD.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 74, in model_selector - return EXP(model_config) - ^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/EXP/EXP32.py", line 114, in __init__ - self.horizon = args["horizon"] # 预测步长 - ~~~~^^^^^^^^^^^ -KeyError: 'horizon' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/EXP/NYCBike-OutFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 74, in model_selector - return EXP(model_config) - ^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/EXP/EXP32.py", line 114, in __init__ - self.horizon = args["horizon"] # 预测步长 - ~~~~^^^^^^^^^^^ -KeyError: 'horizon' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/EXP/PEMSD8.yaml - -``` -模型参数量: 235788 -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 22, in main - train_loader, val_loader, test_loader, scaler, *extra_data = get_dataloader( - ^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/loader_selector.py", line 18, in get_dataloader - return EXP_loader(config, normalizer, single) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/EXPdataloader.py", line 8, in get_dataloader - data = load_st_dataset(args["type"], args["sample"]) # [T, N, F] - ~~~~^^^^^^^^ -KeyError: 'type' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/EXP/PEMSD3.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 74, in model_selector - return EXP(model_config) - ^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/EXP/EXP32.py", line 114, in __init__ - self.horizon = args["horizon"] # 预测步长 - ~~~~^^^^^^^^^^^ -KeyError: 'horizon' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/EXP/SolarEnergy.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 74, in model_selector - return EXP(model_config) - ^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/EXP/EXP32.py", line 114, in __init__ - self.horizon = args["horizon"] # 预测步长 - ~~~~^^^^^^^^^^^ -KeyError: 'horizon' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/EXP/PEMSD7.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 74, in model_selector - return EXP(model_config) - ^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/EXP/EXP32.py", line 114, in __init__ - self.horizon = args["horizon"] # 预测步长 - ~~~~^^^^^^^^^^^ -KeyError: 'horizon' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/DDGCRN/NYCBike-InFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 36, in model_selector - return DDGCRN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/DDGCRN/DDGCRN.py", line 41, in __init__ - args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/DDGCRN/PEMSD4.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 36, in model_selector - return DDGCRN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/DDGCRN/DDGCRN.py", line 41, in __init__ - args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/DDGCRN/AirQuality.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 36, in model_selector - return DDGCRN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/DDGCRN/DDGCRN.py", line 41, in __init__ - args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/DDGCRN/NYCBike-OutFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 36, in model_selector - return DDGCRN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/DDGCRN/DDGCRN.py", line 41, in __init__ - args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/DDGCRN/PEMSD7(L).yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 36, in model_selector - return DDGCRN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/DDGCRN/DDGCRN.py", line 41, in __init__ - args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/DDGCRN/PEMSD3.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 36, in model_selector - return DDGCRN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/DDGCRN/DDGCRN.py", line 41, in __init__ - args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/DDGCRN/SolarEnergy.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 36, in model_selector - return DDGCRN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/DDGCRN/DDGCRN.py", line 41, in __init__ - args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/DDGCRN/Hainan.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 36, in model_selector - return DDGCRN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/DDGCRN/DDGCRN.py", line 41, in __init__ - args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/DDGCRN/PEMSD7.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 36, in model_selector - return DDGCRN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/DDGCRN/DDGCRN.py", line 41, in __init__ - args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/DDGCRN/PEMSD7(M).yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 36, in model_selector - return DDGCRN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/DDGCRN/DDGCRN.py", line 41, in __init__ - args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/DSANET/NYCBike-InFlow.yaml - -``` -模型参数量: 37897240 -加载 NYCBike-InFlow 数据集中... -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 22, in main - train_loader, val_loader, test_loader, scaler, *extra_data = get_dataloader( - ^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/loader_selector.py", line 20, in get_dataloader - return normal_loader(config, normalizer, single) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 14, in get_dataloader - x, y = _prepare_data_with_windows(data, args, single) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 43, in _prepare_data_with_windows - x = _add_time_features(x, time_features, args["lag"], args["horizon"], single, add_window_x) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 64, in _add_time_features - return np.concatenate([data, time_day, time_week], axis=-1) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -ValueError: all the input array dimensions except for the concatenation axis must match exactly, but along dimension 2, the array at index 0 has size 128 and the array at index 1 has size 1024 -``` - -#### /user/czzhangheng/code/TrafficWheel/config/DSANET/AirQuality.yaml - -``` -模型参数量: 37897240 -加载 AirQuality 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/DSANET/AirQuality/2025-12-01_22-05-16/run.log -2025/12/01 22:05:16 - Experiment log path in: /user/czzhangheng/code/TrafficWheel/experiments/DSANET/AirQuality/2025-12-01_22-05-16 -2025/12/01 22:05:16 - Training process started - -Train Epoch 1: 0%| | 0/325 [00:00 - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 46, in main - trainer.train() - File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 266, in train - train_epoch_loss = self.train_epoch(epoch) - ^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 244, in train_epoch - return self._run_epoch(epoch, self.train_loader, "train") - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 180, in _run_epoch - loss = self.loss(output, label) - ^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl - return self._call_impl(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl - return forward_call(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/loss.py", line 128, in forward - return F.l1_loss(input, target, reduction=self.reduction) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/functional.py", line 3810, in l1_loss - expanded_input, expanded_target = torch.broadcast_tensors(input, target) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/functional.py", line 76, in broadcast_tensors - return _VF.broadcast_tensors(tensors) # type: ignore[attr-defined] - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -RuntimeError: The size of tensor a (12) must match the size of tensor b (24) at non-singleton dimension 1 -``` - -#### /user/czzhangheng/code/TrafficWheel/config/DSANET/NYCBike-OutFlow.yaml - -``` -模型参数量: 37897240 -加载 NYCBike-OutFlow 数据集中... -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 22, in main - train_loader, val_loader, test_loader, scaler, *extra_data = get_dataloader( - ^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/loader_selector.py", line 20, in get_dataloader - return normal_loader(config, normalizer, single) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 14, in get_dataloader - x, y = _prepare_data_with_windows(data, args, single) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 43, in _prepare_data_with_windows - x = _add_time_features(x, time_features, args["lag"], args["horizon"], single, add_window_x) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 64, in _add_time_features - return np.concatenate([data, time_day, time_week], axis=-1) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -ValueError: all the input array dimensions except for the concatenation axis must match exactly, but along dimension 2, the array at index 0 has size 128 and the array at index 1 has size 1024 -``` - -#### /user/czzhangheng/code/TrafficWheel/config/DSANET/SolarEnergy.yaml - -``` -模型参数量: 37897240 -加载 SolarEnergy 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/DSANET/SolarEnergy/2025-12-01_22-06-16/run.log -2025/12/01 22:06:16 - Experiment log path in: /user/czzhangheng/code/TrafficWheel/experiments/DSANET/SolarEnergy/2025-12-01_22-06-16 -2025/12/01 22:06:16 - Training process started - -Train Epoch 1: 0%| | 0/1970 [00:00 - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 46, in main - trainer.train() - File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 266, in train - train_epoch_loss = self.train_epoch(epoch) - ^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 244, in train_epoch - return self._run_epoch(epoch, self.train_loader, "train") - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 180, in _run_epoch - loss = self.loss(output, label) - ^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl - return self._call_impl(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl - return forward_call(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/loss.py", line 128, in forward - return F.l1_loss(input, target, reduction=self.reduction) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/functional.py", line 3810, in l1_loss - expanded_input, expanded_target = torch.broadcast_tensors(input, target) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/functional.py", line 76, in broadcast_tensors - return _VF.broadcast_tensors(tensors) # type: ignore[attr-defined] - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -RuntimeError: The size of tensor a (12) must match the size of tensor b (24) at non-singleton dimension 1 -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STFGNN/NYCBike-InFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 58, in model_selector - return STFGNN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STFGNN/STFGNN.py", line 343, in __init__ - adj = torch.tensor(get_adj(args)) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj - match args['num_nodes']: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STFGNN/PEMSD4.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 58, in model_selector - return STFGNN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STFGNN/STFGNN.py", line 343, in __init__ - adj = torch.tensor(get_adj(args)) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj - match args['num_nodes']: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STFGNN/METR-LA.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 58, in model_selector - return STFGNN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STFGNN/STFGNN.py", line 343, in __init__ - adj = torch.tensor(get_adj(args)) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj - match args['num_nodes']: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STFGNN/AirQuality.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 58, in model_selector - return STFGNN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STFGNN/STFGNN.py", line 343, in __init__ - adj = torch.tensor(get_adj(args)) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj - match args['num_nodes']: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STFGNN/NYCBike-OutFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 58, in model_selector - return STFGNN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STFGNN/STFGNN.py", line 343, in __init__ - adj = torch.tensor(get_adj(args)) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj - match args['num_nodes']: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STFGNN/PEMSD8.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 58, in model_selector - return STFGNN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STFGNN/STFGNN.py", line 343, in __init__ - adj = torch.tensor(get_adj(args)) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj - match args['num_nodes']: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STFGNN/PEMSD3.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 58, in model_selector - return STFGNN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STFGNN/STFGNN.py", line 343, in __init__ - adj = torch.tensor(get_adj(args)) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj - match args['num_nodes']: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STFGNN/SolarEnergy.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 58, in model_selector - return STFGNN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STFGNN/STFGNN.py", line 343, in __init__ - adj = torch.tensor(get_adj(args)) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj - match args['num_nodes']: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STFGNN/PEMSD7.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 58, in model_selector - return STFGNN(model_config) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STFGNN/STFGNN.py", line 343, in __init__ - adj = torch.tensor(get_adj(args)) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj - match args['num_nodes']: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/AGCRN/NYCBike-InFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 40, in model_selector - return AGCRN(model_config) - ^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/AGCRN/AGCRN.py", line 52, in __init__ - self.num_node = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/AGCRN/PEMSD4.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 40, in model_selector - return AGCRN(model_config) - ^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/AGCRN/AGCRN.py", line 52, in __init__ - self.num_node = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/AGCRN/METR-LA.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 40, in model_selector - return AGCRN(model_config) - ^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/AGCRN/AGCRN.py", line 52, in __init__ - self.num_node = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/AGCRN/AirQuality.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 40, in model_selector - return AGCRN(model_config) - ^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/AGCRN/AGCRN.py", line 52, in __init__ - self.num_node = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/AGCRN/NYCBike-OutFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 40, in model_selector - return AGCRN(model_config) - ^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/AGCRN/AGCRN.py", line 52, in __init__ - self.num_node = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/AGCRN/PEMSD8.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 40, in model_selector - return AGCRN(model_config) - ^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/AGCRN/AGCRN.py", line 52, in __init__ - self.num_node = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/AGCRN/PEMSD3.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 40, in model_selector - return AGCRN(model_config) - ^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/AGCRN/AGCRN.py", line 52, in __init__ - self.num_node = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/AGCRN/SolarEnergy.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 40, in model_selector - return AGCRN(model_config) - ^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/AGCRN/AGCRN.py", line 53, in __init__ - self.input_dim = args["input_dim"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'input_dim' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/AGCRN/PEMSD7.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 40, in model_selector - return AGCRN(model_config) - ^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/AGCRN/AGCRN.py", line 52, in __init__ - self.num_node = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STGNRDE/NYCBike-InFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 80, in model_selector - return make_nrde_model(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STGNRDE/Make_model.py", line 17, in make_model - num_nodes=args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STGNRDE/PEMSD4.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 80, in model_selector - return make_nrde_model(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STGNRDE/Make_model.py", line 17, in make_model - num_nodes=args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STGNRDE/METR-LA.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 80, in model_selector - return make_nrde_model(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STGNRDE/Make_model.py", line 17, in make_model - num_nodes=args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STGNRDE/AirQuality.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 80, in model_selector - return make_nrde_model(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STGNRDE/Make_model.py", line 17, in make_model - num_nodes=args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STGNRDE/NYCBike-OutFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 80, in model_selector - return make_nrde_model(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STGNRDE/Make_model.py", line 17, in make_model - num_nodes=args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STGNRDE/PEMSD8.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 80, in model_selector - return make_nrde_model(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STGNRDE/Make_model.py", line 17, in make_model - num_nodes=args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STGNRDE/PEMSD3.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 80, in model_selector - return make_nrde_model(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STGNRDE/Make_model.py", line 17, in make_model - num_nodes=args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STGNRDE/SolarEnergy.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 80, in model_selector - return make_nrde_model(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STGNRDE/Make_model.py", line 17, in make_model - num_nodes=args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STGNRDE/PEMSD7.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 80, in model_selector - return make_nrde_model(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STGNRDE/Make_model.py", line 17, in make_model - num_nodes=args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/REPST/PEMS-BAY_paper.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 84, in model_selector - return REPST(model_config) - ^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/REPST/repst.py", line 24, in __init__ - self.word_choice = GumbelSoftmax(configs['word_num']) - ~~~~~~~^^^^^^^^^^^^ -KeyError: 'word_num' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/REPST/NYCBike-InFlow.yaml - -``` -模型参数量: 103481647 -加载 NYCBike-InFlow 数据集中... -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 22, in main - train_loader, val_loader, test_loader, scaler, *extra_data = get_dataloader( - ^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/loader_selector.py", line 20, in get_dataloader - return normal_loader(config, normalizer, single) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 14, in get_dataloader - x, y = _prepare_data_with_windows(data, args, single) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 43, in _prepare_data_with_windows - x = _add_time_features(x, time_features, args["lag"], args["horizon"], single, add_window_x) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 64, in _add_time_features - return np.concatenate([data, time_day, time_week], axis=-1) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -ValueError: all the input array dimensions except for the concatenation axis must match exactly, but along dimension 2, the array at index 0 has size 128 and the array at index 1 has size 1024 -``` - -#### /user/czzhangheng/code/TrafficWheel/config/REPST/NYCBike-OutFlow.yaml - -``` -模型参数量: 103481647 -加载 NYCBike-OutFlow 数据集中... -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 22, in main - train_loader, val_loader, test_loader, scaler, *extra_data = get_dataloader( - ^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/loader_selector.py", line 20, in get_dataloader - return normal_loader(config, normalizer, single) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 14, in get_dataloader - x, y = _prepare_data_with_windows(data, args, single) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 43, in _prepare_data_with_windows - x = _add_time_features(x, time_features, args["lag"], args["horizon"], single, add_window_x) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 64, in _add_time_features - return np.concatenate([data, time_day, time_week], axis=-1) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -ValueError: all the input array dimensions except for the concatenation axis must match exactly, but along dimension 2, the array at index 0 has size 128 and the array at index 1 has size 1024 -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STIDGCN/NYCBike-InFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 68, in model_selector - return STIDGCN(model_config) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STIDGCN/STIDGCN.py", line 337, in __init__ - self.num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STIDGCN/PEMSD4.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 68, in model_selector - return STIDGCN(model_config) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STIDGCN/STIDGCN.py", line 337, in __init__ - self.num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STIDGCN/METR-LA.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 68, in model_selector - return STIDGCN(model_config) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STIDGCN/STIDGCN.py", line 337, in __init__ - self.num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STIDGCN/AirQuality.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 68, in model_selector - return STIDGCN(model_config) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STIDGCN/STIDGCN.py", line 337, in __init__ - self.num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STIDGCN/NYCBike-OutFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 68, in model_selector - return STIDGCN(model_config) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STIDGCN/STIDGCN.py", line 337, in __init__ - self.num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STIDGCN/PEMSD8.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 68, in model_selector - return STIDGCN(model_config) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STIDGCN/STIDGCN.py", line 337, in __init__ - self.num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STIDGCN/PEMSD3.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 68, in model_selector - return STIDGCN(model_config) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STIDGCN/STIDGCN.py", line 337, in __init__ - self.num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STIDGCN/SolarEnergy.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 68, in model_selector - return STIDGCN(model_config) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STIDGCN/STIDGCN.py", line 337, in __init__ - self.num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STIDGCN/PEMSD7.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 68, in model_selector - return STIDGCN(model_config) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STIDGCN/STIDGCN.py", line 337, in __init__ - self.num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/PDG2SEQ/NYCBike-InFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 64, in model_selector - return PDG2Seq(model_config) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/PDG2SEQ/PDG2Seqb.py", line 226, in __init__ - self.num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/PDG2SEQ/PEMSD4.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 64, in model_selector - return PDG2Seq(model_config) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/PDG2SEQ/PDG2Seqb.py", line 226, in __init__ - self.num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/PDG2SEQ/METR-LA.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 64, in model_selector - return PDG2Seq(model_config) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/PDG2SEQ/PDG2Seqb.py", line 226, in __init__ - self.num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/PDG2SEQ/AirQuality.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 64, in model_selector - return PDG2Seq(model_config) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/PDG2SEQ/PDG2Seqb.py", line 226, in __init__ - self.num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/PDG2SEQ/NYCBike-OutFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 64, in model_selector - return PDG2Seq(model_config) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/PDG2SEQ/PDG2Seqb.py", line 226, in __init__ - self.num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/PDG2SEQ/PEMSD8.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 64, in model_selector - return PDG2Seq(model_config) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/PDG2SEQ/PDG2Seqb.py", line 226, in __init__ - self.num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/PDG2SEQ/PEMSD3.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 64, in model_selector - return PDG2Seq(model_config) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/PDG2SEQ/PDG2Seqb.py", line 226, in __init__ - self.num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/PDG2SEQ/SolarEnergy.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 64, in model_selector - return PDG2Seq(model_config) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/PDG2SEQ/PDG2Seqb.py", line 226, in __init__ - self.num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/PDG2SEQ/PEMSD7.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 64, in model_selector - return PDG2Seq(model_config) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/PDG2SEQ/PDG2Seqb.py", line 226, in __init__ - self.num_nodes = args["num_nodes"] - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/NLT/NYCBike-InFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 42, in model_selector - return HierAttnLstm(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/NLT/HierAttnLstm.py", line 10, in __init__ - args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/NLT/PEMSD4.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 42, in model_selector - return HierAttnLstm(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/NLT/HierAttnLstm.py", line 10, in __init__ - args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/NLT/METR-LA.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 42, in model_selector - return HierAttnLstm(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/NLT/HierAttnLstm.py", line 10, in __init__ - args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/NLT/AirQuality.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 42, in model_selector - return HierAttnLstm(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/NLT/HierAttnLstm.py", line 10, in __init__ - args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/NLT/NYCBike-OutFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 42, in model_selector - return HierAttnLstm(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/NLT/HierAttnLstm.py", line 10, in __init__ - args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/NLT/PEMSD8.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 42, in model_selector - return HierAttnLstm(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/NLT/HierAttnLstm.py", line 10, in __init__ - args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/NLT/PEMSD3.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 42, in model_selector - return HierAttnLstm(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/NLT/HierAttnLstm.py", line 10, in __init__ - args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/NLT/SolarEnergy.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 42, in model_selector - return HierAttnLstm(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/NLT/HierAttnLstm.py", line 10, in __init__ - args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/NLT/PEMSD7.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 42, in model_selector - return HierAttnLstm(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/NLT/HierAttnLstm.py", line 10, in __init__ - args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/ARIMA/NYCBike-InFlow.yaml - -``` -模型参数量: 4 -加载 NYCBike-InFlow 数据集中... -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 22, in main - train_loader, val_loader, test_loader, scaler, *extra_data = get_dataloader( - ^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/loader_selector.py", line 20, in get_dataloader - return normal_loader(config, normalizer, single) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 14, in get_dataloader - x, y = _prepare_data_with_windows(data, args, single) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 43, in _prepare_data_with_windows - x = _add_time_features(x, time_features, args["lag"], args["horizon"], single, add_window_x) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 64, in _add_time_features - return np.concatenate([data, time_day, time_week], axis=-1) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -ValueError: all the input array dimensions except for the concatenation axis must match exactly, but along dimension 2, the array at index 0 has size 128 and the array at index 1 has size 1024 -``` - -#### /user/czzhangheng/code/TrafficWheel/config/ARIMA/PEMSD4.yaml - -``` -模型参数量: 4 -加载 PEMSD4 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/ARIMA/PEMSD4/2025-12-01_22-14-54/run.log -2025/12/01 22:14:54 - Experiment log path in: /user/czzhangheng/code/TrafficWheel/experiments/ARIMA/PEMSD4/2025-12-01_22-14-54 -2025/12/01 22:14:54 - Training process started - -Train Epoch 1: 0%| | 0/160 [00:00 - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 46, in main - trainer.train() - File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 266, in train - train_epoch_loss = self.train_epoch(epoch) - ^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 244, in train_epoch - return self._run_epoch(epoch, self.train_loader, "train") - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 179, in _run_epoch - output = self.model(data).to(self.device) - ^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl - return self._call_impl(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl - return forward_call(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/ARIMA/ARIMA.py", line 40, in forward - drift = self.drift[n] if self.drift is not None else None - ~~~~~~~~~~^^^ -IndexError: index 1 is out of bounds for dimension 0 with size 1 -``` - -#### /user/czzhangheng/code/TrafficWheel/config/ARIMA/METR-LA.yaml - -``` -模型参数量: 4 -加载 METR-LA 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/ARIMA/METR-LA/2025-12-01_22-15-07/run.log -2025/12/01 22:15:07 - Experiment log path in: /user/czzhangheng/code/TrafficWheel/experiments/ARIMA/METR-LA/2025-12-01_22-15-07 -2025/12/01 22:15:07 - Training process started - -Train Epoch 1: 0%| | 0/1285 [00:00 - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 46, in main - trainer.train() - File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 266, in train - train_epoch_loss = self.train_epoch(epoch) - ^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 244, in train_epoch - return self._run_epoch(epoch, self.train_loader, "train") - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 179, in _run_epoch - output = self.model(data).to(self.device) - ^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl - return self._call_impl(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl - return forward_call(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/ARIMA/ARIMA.py", line 40, in forward - drift = self.drift[n] if self.drift is not None else None - ~~~~~~~~~~^^^ -IndexError: index 1 is out of bounds for dimension 0 with size 1 -``` - -#### /user/czzhangheng/code/TrafficWheel/config/ARIMA/AirQuality.yaml - -``` -模型参数量: 4 -加载 AirQuality 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/ARIMA/AirQuality/2025-12-01_22-15-15/run.log -2025/12/01 22:15:15 - Experiment log path in: /user/czzhangheng/code/TrafficWheel/experiments/ARIMA/AirQuality/2025-12-01_22-15-15 -2025/12/01 22:15:15 - Training process started - -Train Epoch 1: 0%| | 0/325 [00:00 - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 46, in main - trainer.train() - File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 266, in train - train_epoch_loss = self.train_epoch(epoch) - ^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 244, in train_epoch - return self._run_epoch(epoch, self.train_loader, "train") - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 179, in _run_epoch - output = self.model(data).to(self.device) - ^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl - return self._call_impl(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl - return forward_call(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/ARIMA/ARIMA.py", line 40, in forward - drift = self.drift[n] if self.drift is not None else None - ~~~~~~~~~~^^^ -IndexError: index 1 is out of bounds for dimension 0 with size 1 -``` - -#### /user/czzhangheng/code/TrafficWheel/config/ARIMA/NYCBike-OutFlow.yaml - -``` -模型参数量: 4 -加载 NYCBike-OutFlow 数据集中... -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 22, in main - train_loader, val_loader, test_loader, scaler, *extra_data = get_dataloader( - ^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/loader_selector.py", line 20, in get_dataloader - return normal_loader(config, normalizer, single) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 14, in get_dataloader - x, y = _prepare_data_with_windows(data, args, single) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 43, in _prepare_data_with_windows - x = _add_time_features(x, time_features, args["lag"], args["horizon"], single, add_window_x) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/dataloader/PeMSDdataloader.py", line 64, in _add_time_features - return np.concatenate([data, time_day, time_week], axis=-1) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -ValueError: all the input array dimensions except for the concatenation axis must match exactly, but along dimension 2, the array at index 0 has size 128 and the array at index 1 has size 1024 -``` - -#### /user/czzhangheng/code/TrafficWheel/config/ARIMA/PEMSD8.yaml - -``` -模型参数量: 4 -加载 PEMSD8 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/ARIMA/PEMSD8/2025-12-01_22-15-31/run.log -2025/12/01 22:15:31 - Experiment log path in: /user/czzhangheng/code/TrafficWheel/experiments/ARIMA/PEMSD8/2025-12-01_22-15-31 -2025/12/01 22:15:31 - Training process started - -Train Epoch 1: 0%| | 0/168 [00:00 - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 46, in main - trainer.train() - File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 266, in train - train_epoch_loss = self.train_epoch(epoch) - ^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 244, in train_epoch - return self._run_epoch(epoch, self.train_loader, "train") - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 179, in _run_epoch - output = self.model(data).to(self.device) - ^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl - return self._call_impl(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl - return forward_call(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/ARIMA/ARIMA.py", line 40, in forward - drift = self.drift[n] if self.drift is not None else None - ~~~~~~~~~~^^^ -IndexError: index 1 is out of bounds for dimension 0 with size 1 -``` - -#### /user/czzhangheng/code/TrafficWheel/config/ARIMA/PEMSD7(L).yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 52, in model_selector - return ARIMA(model_config) - ^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/ARIMA/ARIMA.py", line 9, in __init__ - self.p = args["p"] # 自回归阶数 - ~~~~^^^^^ -KeyError: 'p' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/ARIMA/PEMSD3.yaml - -``` -模型参数量: 4 -加载 PEMSD3 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/ARIMA/PEMSD3/2025-12-01_22-15-53/run.log -2025/12/01 22:15:53 - Experiment log path in: /user/czzhangheng/code/TrafficWheel/experiments/ARIMA/PEMSD3/2025-12-01_22-15-53 -2025/12/01 22:15:53 - Training process started - -Train Epoch 1: 0%| | 0/982 [00:00 - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 46, in main - trainer.train() - File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 266, in train - train_epoch_loss = self.train_epoch(epoch) - ^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 244, in train_epoch - return self._run_epoch(epoch, self.train_loader, "train") - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 179, in _run_epoch - output = self.model(data).to(self.device) - ^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl - return self._call_impl(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl - return forward_call(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/ARIMA/ARIMA.py", line 40, in forward - drift = self.drift[n] if self.drift is not None else None - ~~~~~~~~~~^^^ -IndexError: index 1 is out of bounds for dimension 0 with size 1 -``` - -#### /user/czzhangheng/code/TrafficWheel/config/ARIMA/SolarEnergy.yaml - -``` -模型参数量: 4 -加载 SolarEnergy 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/ARIMA/SolarEnergy/2025-12-01_22-16-18/run.log -2025/12/01 22:16:18 - Experiment log path in: /user/czzhangheng/code/TrafficWheel/experiments/ARIMA/SolarEnergy/2025-12-01_22-16-18 -2025/12/01 22:16:18 - Training process started - -Train Epoch 1: 0%| | 0/1970 [00:00 - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 46, in main - trainer.train() - File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 266, in train - train_epoch_loss = self.train_epoch(epoch) - ^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 244, in train_epoch - return self._run_epoch(epoch, self.train_loader, "train") - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 179, in _run_epoch - output = self.model(data).to(self.device) - ^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl - return self._call_impl(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl - return forward_call(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/ARIMA/ARIMA.py", line 40, in forward - drift = self.drift[n] if self.drift is not None else None - ~~~~~~~~~~^^^ -IndexError: index 1 is out of bounds for dimension 0 with size 1 -``` - -#### /user/czzhangheng/code/TrafficWheel/config/ARIMA/Hainan.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 52, in model_selector - return ARIMA(model_config) - ^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/ARIMA/ARIMA.py", line 9, in __init__ - self.p = args["p"] # 自回归阶数 - ~~~~^^^^^ -KeyError: 'p' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/ARIMA/PEMSD7.yaml - -``` -模型参数量: 4 -加载 PEMSD7 数据集中... -Create Log File in: /user/czzhangheng/code/TrafficWheel/experiments/ARIMA/PEMSD7/2025-12-01_22-16-56/run.log -2025/12/01 22:16:56 - Experiment log path in: /user/czzhangheng/code/TrafficWheel/experiments/ARIMA/PEMSD7/2025-12-01_22-16-56 -2025/12/01 22:16:56 - Training process started - -Train Epoch 1: 0%| | 0/1058 [00:00 - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 46, in main - trainer.train() - File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 266, in train - train_epoch_loss = self.train_epoch(epoch) - ^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 244, in train_epoch - return self._run_epoch(epoch, self.train_loader, "train") - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/trainer/Trainer.py", line 179, in _run_epoch - output = self.model(data).to(self.device) - ^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl - return self._call_impl(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/miniconda/envs/TrafficWheel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl - return forward_call(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/ARIMA/ARIMA.py", line 40, in forward - drift = self.drift[n] if self.drift is not None else None - ~~~~~~~~~~^^^ -IndexError: index 1 is out of bounds for dimension 0 with size 1 -``` - -#### /user/czzhangheng/code/TrafficWheel/config/ARIMA/PEMSD7(M).yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 52, in model_selector - return ARIMA(model_config) - ^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/ARIMA/ARIMA.py", line 9, in __init__ - self.p = args["p"] # 自回归阶数 - ~~~~^^^^^ -KeyError: 'p' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STMLP/NYCBike-InFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 66, in model_selector - return STMLP(model_config) - ^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STMLP/STMLP.py", line 188, in __init__ - self.adj_mx = get_adj(args) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj - match args['num_nodes']: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STMLP/PEMSD4.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 66, in model_selector - return STMLP(model_config) - ^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STMLP/STMLP.py", line 188, in __init__ - self.adj_mx = get_adj(args) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj - match args['num_nodes']: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STMLP/METR-LA.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 66, in model_selector - return STMLP(model_config) - ^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STMLP/STMLP.py", line 188, in __init__ - self.adj_mx = get_adj(args) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj - match args['num_nodes']: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STMLP/AirQuality.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 66, in model_selector - return STMLP(model_config) - ^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STMLP/STMLP.py", line 188, in __init__ - self.adj_mx = get_adj(args) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj - match args['num_nodes']: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STMLP/NYCBike-OutFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 66, in model_selector - return STMLP(model_config) - ^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STMLP/STMLP.py", line 188, in __init__ - self.adj_mx = get_adj(args) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj - match args['num_nodes']: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STMLP/PEMSD8.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 66, in model_selector - return STMLP(model_config) - ^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STMLP/STMLP.py", line 188, in __init__ - self.adj_mx = get_adj(args) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj - match args['num_nodes']: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STMLP/PEMSD3.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 66, in model_selector - return STMLP(model_config) - ^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STMLP/STMLP.py", line 188, in __init__ - self.adj_mx = get_adj(args) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj - match args['num_nodes']: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STMLP/SolarEnergy.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 66, in model_selector - return STMLP(model_config) - ^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STMLP/STMLP.py", line 188, in __init__ - self.adj_mx = get_adj(args) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj - match args['num_nodes']: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STMLP/PEMSD7.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 66, in model_selector - return STMLP(model_config) - ^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STMLP/STMLP.py", line 188, in __init__ - self.adj_mx = get_adj(args) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 10, in get_adj - match args['num_nodes']: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/MegaCRN/NYCBike-InFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 76, in model_selector - return MegaCRNModel(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/MegaCRN/MegaCRNModel.py", line 32, in __init__ - num_nodes=args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/MegaCRN/PEMSD4.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 76, in model_selector - return MegaCRNModel(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/MegaCRN/MegaCRNModel.py", line 32, in __init__ - num_nodes=args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/MegaCRN/METR-LA.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 76, in model_selector - return MegaCRNModel(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/MegaCRN/MegaCRNModel.py", line 32, in __init__ - num_nodes=args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/MegaCRN/AirQuality.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 76, in model_selector - return MegaCRNModel(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/MegaCRN/MegaCRNModel.py", line 32, in __init__ - num_nodes=args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/MegaCRN/NYCBike-OutFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 76, in model_selector - return MegaCRNModel(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/MegaCRN/MegaCRNModel.py", line 32, in __init__ - num_nodes=args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/MegaCRN/PEMSD8.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 76, in model_selector - return MegaCRNModel(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/MegaCRN/MegaCRNModel.py", line 32, in __init__ - num_nodes=args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/MegaCRN/PEMSD3.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 76, in model_selector - return MegaCRNModel(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/MegaCRN/MegaCRNModel.py", line 32, in __init__ - num_nodes=args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/MegaCRN/SolarEnergy.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 76, in model_selector - return MegaCRNModel(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/MegaCRN/MegaCRNModel.py", line 32, in __init__ - num_nodes=args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/MegaCRN/PEMSD7.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 76, in model_selector - return MegaCRNModel(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/MegaCRN/MegaCRNModel.py", line 32, in __init__ - num_nodes=args["num_nodes"], - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/GWN/NYCBike-InFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 56, in model_selector - return gwnet(model_config) - ^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/GWN/GraphWaveNet.py", line 91, in __init__ - torch.randn(args["num_nodes"], 10, device=args["device"]) - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/GWN/PEMSD4.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 56, in model_selector - return gwnet(model_config) - ^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/GWN/GraphWaveNet.py", line 91, in __init__ - torch.randn(args["num_nodes"], 10, device=args["device"]) - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/GWN/METR-LA.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 56, in model_selector - return gwnet(model_config) - ^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/GWN/GraphWaveNet.py", line 91, in __init__ - torch.randn(args["num_nodes"], 10, device=args["device"]) - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/GWN/AirQuality.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 56, in model_selector - return gwnet(model_config) - ^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/GWN/GraphWaveNet.py", line 91, in __init__ - torch.randn(args["num_nodes"], 10, device=args["device"]) - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/GWN/NYCBike-OutFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 56, in model_selector - return gwnet(model_config) - ^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/GWN/GraphWaveNet.py", line 91, in __init__ - torch.randn(args["num_nodes"], 10, device=args["device"]) - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/GWN/PEMSD8.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 56, in model_selector - return gwnet(model_config) - ^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/GWN/GraphWaveNet.py", line 91, in __init__ - torch.randn(args["num_nodes"], 10, device=args["device"]) - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/GWN/PEMSD3.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 56, in model_selector - return gwnet(model_config) - ^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/GWN/GraphWaveNet.py", line 91, in __init__ - torch.randn(args["num_nodes"], 10, device=args["device"]) - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/GWN/SolarEnergy.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 56, in model_selector - return gwnet(model_config) - ^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/GWN/GraphWaveNet.py", line 91, in __init__ - torch.randn(args["num_nodes"], 10, device=args["device"]) - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/GWN/PEMSD7.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 56, in model_selector - return gwnet(model_config) - ^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/GWN/GraphWaveNet.py", line 91, in __init__ - torch.randn(args["num_nodes"], 10, device=args["device"]) - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STGCN/NYCBike-InFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 48, in model_selector - return STGCNChebGraphConv(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STGCN/models.py", line 31, in __init__ - gso = get_gso(args) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 33, in get_gso - match args['num_nodes']: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STGCN/PEMSD4.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 48, in model_selector - return STGCNChebGraphConv(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STGCN/models.py", line 31, in __init__ - gso = get_gso(args) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 33, in get_gso - match args['num_nodes']: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STGCN/METR-LA.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 48, in model_selector - return STGCNChebGraphConv(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STGCN/models.py", line 31, in __init__ - gso = get_gso(args) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 33, in get_gso - match args['num_nodes']: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STGCN/AirQuality.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 48, in model_selector - return STGCNChebGraphConv(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STGCN/models.py", line 31, in __init__ - gso = get_gso(args) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 33, in get_gso - match args['num_nodes']: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STGCN/NYCBike-OutFlow.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 48, in model_selector - return STGCNChebGraphConv(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STGCN/models.py", line 31, in __init__ - gso = get_gso(args) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 33, in get_gso - match args['num_nodes']: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STGCN/PEMSD8.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 48, in model_selector - return STGCNChebGraphConv(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STGCN/models.py", line 31, in __init__ - gso = get_gso(args) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 33, in get_gso - match args['num_nodes']: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STGCN/PEMSD3.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 48, in model_selector - return STGCNChebGraphConv(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STGCN/models.py", line 31, in __init__ - gso = get_gso(args) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 33, in get_gso - match args['num_nodes']: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STGCN/SolarEnergy.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 48, in model_selector - return STGCNChebGraphConv(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STGCN/models.py", line 31, in __init__ - gso = get_gso(args) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 33, in get_gso - match args['num_nodes']: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - -#### /user/czzhangheng/code/TrafficWheel/config/STGCN/PEMSD7.yaml - -``` - -Traceback (most recent call last): - File "/user/czzhangheng/code/TrafficWheel/run.py", line 66, in - main() - File "/user/czzhangheng/code/TrafficWheel/run.py", line 19, in main - model = init.init_model(args) - ^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/initializer.py", line 14, in init_model - model = model_selector(args).to(device) - ^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/model_selector.py", line 48, in model_selector - return STGCNChebGraphConv(model_config) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/model/STGCN/models.py", line 31, in __init__ - gso = get_gso(args) - ^^^^^^^^^^^^^ - File "/user/czzhangheng/code/TrafficWheel/utils/get_adj.py", line 33, in get_gso - match args['num_nodes']: - ~~~~^^^^^^^^^^^^^ -KeyError: 'num_nodes' -``` - diff --git a/train.py b/train.py new file mode 100644 index 0000000..9d58921 --- /dev/null +++ b/train.py @@ -0,0 +1,63 @@ +import yaml +import torch + +import utils.initializer as init +from dataloader.loader_selector import get_dataloader +from trainer.trainer_selector import select_trainer + +def run(config): + init.init_seed(config["basic"]["seed"]) + model = init.init_model(config) + train_loader, val_loader, test_loader, scaler, *extra_data = get_dataloader( + config, normalizer=config["data"]["normalizer"], single=False + ) + loss = init.init_loss(config, scaler) + optimizer, lr_scheduler = init.init_optimizer(model, config["train"]) + init.create_logs(config) + trainer = select_trainer( + model, + loss, optimizer, + train_loader, val_loader, test_loader, scaler, + config, + lr_scheduler, extra_data, + ) + + # 开始训练 + match config["basic"]["mode"]: + case "train": + trainer.train() + case "test": + model.load_state_dict( + torch.load( + f"./pre-trained/{config['basic']['model']}/{config['basic']['dataset']}.pth", + map_location=config["basic"]["device"], + weights_only=True, + ) + ) + trainer.test( + model.to(config["basic"]["device"]), + trainer.args, test_loader, scaler, + trainer.logger, + ) + case _: + raise ValueError(f"Unsupported mode: {config['basic']['mode']}") + + +if __name__ == "__main__": + # 指定模型 + model_list = ["HI"] + # 指定数据集 + dataset_list = ["AirQuality", "SolarEnergy", "PEMS-BAY", "METR-LA", "BJTaxi-Inflow", "BJTaxi-Outflow", "NYCBike-Inflow", "NYCBike-Outflow"] + device = "cuda:0" # 指定设备 + seed = 2023 # 随机种子 + for model in model_list: + for dataset in dataset_list: + config_path = f"./config/{model}/{dataset}.yaml" + with open(config_path, "r") as file: + config = yaml.safe_load(file) + config["basic"]["device"] = device + config["basic"]["seed"] = seed + print(f"\nRunning {model} on {dataset} with seed {seed} on {device}") + print(f"config: {config}") + run(config) + diff --git a/trainer/Trainer.py b/trainer/Trainer.py index 2bd7e6e..4bd82a4 100755 --- a/trainer/Trainer.py +++ b/trainer/Trainer.py @@ -8,125 +8,31 @@ from utils.logger import get_logger from utils.loss_function import all_metrics from tqdm import tqdm - -class TrainingStats: - """记录训练过程中的统计信息""" - - def __init__(self, device): - self.device = device - self.reset() - - def reset(self): - """重置所有统计数据""" - self.gpu_mem_usage_list = [] - self.cpu_mem_usage_list = [] - self.train_time_list = [] - self.infer_time_list = [] - self.total_iters = 0 - self.start_time = None - self.end_time = None - - def start_training(self): - """记录训练开始时间""" - self.start_time = time.time() - - def end_training(self): - """记录训练结束时间""" - self.end_time = time.time() - - def record_step_time(self, duration, mode): - """记录单步耗时和总迭代次数""" - if mode == "train": - self.train_time_list.append(duration) - else: - self.infer_time_list.append(duration) - self.total_iters += 1 - - def record_memory_usage(self): - """记录当前 GPU 和 CPU 内存占用""" - process = psutil.Process(os.getpid()) - cpu_mem = process.memory_info().rss / (1024**2) - - if torch.cuda.is_available(): - gpu_mem = torch.cuda.max_memory_allocated(device=self.device) / (1024**2) - torch.cuda.reset_peak_memory_stats(device=self.device) - else: - gpu_mem = 0.0 - - self.cpu_mem_usage_list.append(cpu_mem) - self.gpu_mem_usage_list.append(gpu_mem) - - def _calculate_average(self, values_list): - """安全计算平均值,避免除零错误""" - return sum(values_list) / len(values_list) if values_list else 0 - - def report(self, logger): - """在训练结束时输出汇总统计""" - if not self.start_time or not self.end_time: - logger.warning("TrainingStats: start/end time not recorded properly.") - return - - total_time = self.end_time - self.start_time - avg_gpu_mem = self._calculate_average(self.gpu_mem_usage_list) - avg_cpu_mem = self._calculate_average(self.cpu_mem_usage_list) - avg_train_time = self._calculate_average(self.train_time_list) - avg_infer_time = self._calculate_average(self.infer_time_list) - iters_per_sec = self.total_iters / total_time if total_time > 0 else 0 - - logger.info("===== Training Summary =====") - logger.info(f"Total training time: {total_time:.2f} s") - logger.info(f"Total iterations: {self.total_iters}") - logger.info(f"Average iterations per second: {iters_per_sec:.2f}") - logger.info(f"Average GPU Memory Usage: {avg_gpu_mem:.2f} MB") - logger.info(f"Average CPU Memory Usage: {avg_cpu_mem:.2f} MB") - if avg_train_time: - logger.info(f"Average training step time: {avg_train_time * 1000:.2f} ms") - if avg_infer_time: - logger.info(f"Average inference step time: {avg_infer_time * 1000:.2f} ms") - - class Trainer: """模型训练器,负责整个训练流程的管理""" - def __init__( - self, - model, - loss, - optimizer, - train_loader, - val_loader, - test_loader, - scaler, - args, - lr_scheduler=None, - ): + def __init__(self, model, loss, optimizer, + train_loader, val_loader, test_loader, scaler, + args, lr_scheduler=None,): # 设备和基本参数 + self.config = args self.device = args["basic"]["device"] train_args = args["train"] - # 模型和训练相关组件 self.model = model self.loss = loss self.optimizer = optimizer self.lr_scheduler = lr_scheduler - # 数据加载器 self.train_loader = train_loader self.val_loader = val_loader self.test_loader = test_loader - # 数据处理工具 self.scaler = scaler self.args = train_args - - # 统计信息 - self.train_per_epoch = len(train_loader) - self.val_per_epoch = len(val_loader) if val_loader else 0 - # 初始化路径、日志和统计 self._initialize_paths(train_args) self._initialize_logger(train_args) - self._initialize_stats() def _initialize_paths(self, args): """初始化模型保存路径""" @@ -138,24 +44,14 @@ class Trainer: """初始化日志记录器""" if not os.path.isdir(args["log_dir"]) and not args["debug"]: os.makedirs(args["log_dir"], exist_ok=True) - self.logger = get_logger( - args["log_dir"], name=self.model.__class__.__name__, debug=args["debug"] - ) + self.logger = get_logger(args["log_dir"], name=self.model.__class__.__name__, debug=args["debug"]) self.logger.info(f"Experiment log path in: {args['log_dir']}") - - def _initialize_stats(self): - """初始化统计信息记录器""" - self.stats = TrainingStats(device=self.device) def _run_epoch(self, epoch, dataloader, mode): """运行一个训练/验证/测试epoch""" # 设置模型模式和是否进行优化 - if mode == "train": - self.model.train() - optimizer_step = True - else: - self.model.eval() - optimizer_step = False + if mode == "train": self.model.train(); optimizer_step = True + else: self.model.eval(); optimizer_step = False # 初始化变量 total_loss = 0 @@ -169,73 +65,42 @@ class Trainer: total=len(dataloader), desc=f"{mode.capitalize()} Epoch {epoch}" ) - for _, (data, target) in progress_bar: - # 记录步骤开始时间 - start_time = time.time() - - # 前向传播 + # 转移数据 + data = data.to(self.device) + target = target.to(self.device) label = target[..., : self.args["output_dim"]] - output = self.model(data).to(self.device) - # if output.shape != label.shape: - # import sys - # print(f"[Wrong]: Output shape: {output.shape}, Label shape: {label.shape}") - # sys.exit(1) - # else: - # import sys - # print(f"[Right]: Output shape: {output.shape}, Label shape: {label.shape}") - # sys.exit(0) + # 计算loss和反归一化loss + output = self.model(data) loss = self.loss(output, label) - - # 反归一化 d_output = self.scaler.inverse_transform(output) d_label = self.scaler.inverse_transform(label) - - # 反向传播和优化(仅在训练模式) - if optimizer_step and self.optimizer is not None: - self.optimizer.zero_grad() - loss.backward() - - # 梯度裁剪(如果需要) - if self.args["grad_norm"]: - torch.nn.utils.clip_grad_norm_( - self.model.parameters(), self.args["max_grad_norm"] - ) - self.optimizer.step() - - # 反归一化的loss d_loss = self.loss(d_output, d_label) - - # 记录步骤时间和内存使用 - step_time = time.time() - start_time - self.stats.record_step_time(step_time, mode) - # 累积损失和预测结果 total_loss += d_loss.item() y_pred.append(d_output.detach().cpu()) y_true.append(d_label.detach().cpu()) - + # 反向传播和优化(仅在训练模式) + if optimizer_step and self.optimizer is not None: + self.optimizer.zero_grad() + loss.backward() + # 梯度裁剪(如果需要) + if self.args["grad_norm"]: + torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args["max_grad_norm"]) + self.optimizer.step() # 更新进度条 progress_bar.set_postfix(loss=d_loss.item()) # 合并所有批次的预测结果 y_pred = torch.cat(y_pred, dim=0) y_true = torch.cat(y_true, dim=0) - - # 计算平均损失 + # 计算损失并记录指标 avg_loss = total_loss / len(dataloader) - - # 计算并记录指标 - mae, rmse, mape = all_metrics( - y_pred, y_true, self.args["mae_thresh"], self.args["mape_thresh"] - ) + mae, rmse, mape = all_metrics(y_pred, y_true, self.args["mae_thresh"], self.args["mape_thresh"]) self.logger.info( - f"Epoch #{epoch:02d}: {mode.capitalize():<5} MAE:{mae:5.2f} | RMSE:{rmse:5.2f} | MAPE:{mape:7.4f} | Time: {time.time() - epoch_time:.2f} s" + f"Epoch #{epoch:02d}: {mode.capitalize():<5} " + f"MAE:{mae:5.2f} | RMSE:{rmse:5.2f} | MAPE:{mape:7.4f} | Time: {time.time() - epoch_time:.2f} s" ) - - # 记录内存使用情况 - self.stats.record_memory_usage() - return avg_loss def train_epoch(self, epoch): @@ -248,28 +113,22 @@ class Trainer: return self._run_epoch(epoch, self.test_loader, "test") def train(self): - """执行完整的训练流程""" - # 初始化最佳模型和损失记录 + # 初始化记录 best_model, best_test_model = None, None best_loss, best_test_loss = float("inf"), float("inf") not_improved_count = 0 - # 开始训练 - self.stats.start_training() self.logger.info("Training process started") - # 训练循环 for epoch in range(1, self.args["epochs"] + 1): # 训练、验证和测试一个epoch train_epoch_loss = self.train_epoch(epoch) val_epoch_loss = self.val_epoch(epoch) test_epoch_loss = self.test_epoch(epoch) - # 检查梯度爆炸 if train_epoch_loss > 1e6: self.logger.warning("Gradient explosion detected. Ending...") break - # 更新最佳验证模型 if val_epoch_loss < best_loss: best_loss = val_epoch_loss @@ -278,29 +137,18 @@ class Trainer: self.logger.info("Best validation model saved!") else: not_improved_count += 1 - - # 检查早停条件 + # 早停 if self._should_early_stop(not_improved_count): break - # 更新最佳测试模型 if test_epoch_loss < best_test_loss: best_test_loss = test_epoch_loss best_test_model = copy.deepcopy(self.model.state_dict()) - # 保存最佳模型 if not self.args["debug"]: self._save_best_models(best_model, best_test_model) - - # 结束训练并输出统计信息 - self.stats.end_training() - self.stats.report(self.logger) - # 最终评估 self._finalize_training(best_model, best_test_model) - - # 输出模型参数量 - self._log_model_params() def _should_early_stop(self, not_improved_count): """检查是否满足早停条件""" @@ -331,20 +179,35 @@ class Trainer: def _finalize_training(self, best_model, best_test_model): self.model.load_state_dict(best_model) self.logger.info("Testing on best validation model") - self.test(self.model, self.args, self.test_loader, self.scaler, self.logger) - + self.test(self.model, self.config, self.test_loader, self.scaler, self.logger) self.model.load_state_dict(best_test_model) self.logger.info("Testing on best test model") - self.test(self.model, self.args, self.test_loader, self.scaler, self.logger) + self.test(self.model, self.config, self.test_loader, self.scaler, self.logger) @staticmethod def test(model, args, data_loader, scaler, logger, path=None): """对模型进行评估并输出性能指标""" + # 确定设备信息 + device = None + output_dim = None + # 处理不同的参数格式 + if isinstance(args, dict): + if "basic" in args: + # 完整配置情况 + device = args["basic"]["device"] + output_dim = args["train"]["output_dim"] + else: + # 只有train_args情况,从模型获取设备 + device = next(model.parameters()).device + output_dim = args["output_dim"] + else: + raise ValueError(f"Unsupported args type: {type(args)}") + # 加载模型检查点(如果提供了路径) if path: checkpoint = torch.load(path) model.load_state_dict(checkpoint["state_dict"]) - model.to(args["basic"]["device"]) + model.to(device) # 设置为评估模式 model.eval() @@ -355,27 +218,40 @@ class Trainer: # 不计算梯度的情况下进行预测 with torch.no_grad(): for data, target in data_loader: - label = target[..., : args["output_dim"]] + # 将数据和标签移动到指定设备 + data = data.to(device) + target = target.to(device) + + label = target[..., : output_dim] output = model(data) y_pred.append(output.detach().cpu()) y_true.append(label.detach().cpu()) - d_y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0)) d_y_true = scaler.inverse_transform(torch.cat(y_true, dim=0)) + # 获取metrics参数 + if "basic" in args: + # 完整配置情况 + mae_thresh = args["train"]["mae_thresh"] + mape_thresh = args["train"]["mape_thresh"] + else: + # 只有train_args情况 + mae_thresh = args["mae_thresh"] + mape_thresh = args["mape_thresh"] + # 计算并记录每个时间步的指标 for t in range(d_y_true.shape[1]): mae, rmse, mape = all_metrics( d_y_pred[:, t, ...], d_y_true[:, t, ...], - args["mae_thresh"], - args["mape_thresh"], + mae_thresh, + mape_thresh, ) logger.info(f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}") # 计算并记录平均指标 - mae, rmse, mape = all_metrics(d_y_pred, d_y_true, args["mae_thresh"], args["mape_thresh"]) + mae, rmse, mape = all_metrics(d_y_pred, d_y_true, mae_thresh, mape_thresh) logger.info( f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}") @staticmethod diff --git a/trainer/Trainer_bk.py b/trainer/Trainer_bk.py new file mode 100755 index 0000000..ee6e388 --- /dev/null +++ b/trainer/Trainer_bk.py @@ -0,0 +1,420 @@ +import math +import os +import time +import copy +import psutil +import torch +from utils.logger import get_logger +from utils.loss_function import all_metrics +from tqdm import tqdm + + +# class TrainingStats: +# """记录训练过程中的统计信息""" + +# def __init__(self, device): +# self.device = device +# self.reset() + +# def reset(self): +# """重置所有统计数据""" +# self.gpu_mem_usage_list = [] +# self.cpu_mem_usage_list = [] +# self.train_time_list = [] +# self.infer_time_list = [] +# self.total_iters = 0 +# self.start_time = None +# self.end_time = None + +# def start_training(self): +# """记录训练开始时间""" +# self.start_time = time.time() + +# def end_training(self): +# """记录训练结束时间""" +# self.end_time = time.time() + +# def record_step_time(self, duration, mode): +# """记录单步耗时和总迭代次数""" +# if mode == "train": +# self.train_time_list.append(duration) +# else: +# self.infer_time_list.append(duration) +# self.total_iters += 1 + +# def record_memory_usage(self): +# """记录当前 GPU 和 CPU 内存占用""" +# process = psutil.Process(os.getpid()) +# cpu_mem = process.memory_info().rss / (1024**2) + +# if torch.cuda.is_available(): +# gpu_mem = torch.cuda.max_memory_allocated(device=self.device) / (1024**2) +# torch.cuda.reset_peak_memory_stats(device=self.device) +# else: +# gpu_mem = 0.0 + +# self.cpu_mem_usage_list.append(cpu_mem) +# self.gpu_mem_usage_list.append(gpu_mem) + +# def _calculate_average(self, values_list): +# """安全计算平均值,避免除零错误""" +# return sum(values_list) / len(values_list) if values_list else 0 + +# def report(self, logger): +# """在训练结束时输出汇总统计""" +# if not self.start_time or not self.end_time: +# logger.warning("TrainingStats: start/end time not recorded properly.") +# return + +# total_time = self.end_time - self.start_time +# avg_gpu_mem = self._calculate_average(self.gpu_mem_usage_list) +# avg_cpu_mem = self._calculate_average(self.cpu_mem_usage_list) +# avg_train_time = self._calculate_average(self.train_time_list) +# avg_infer_time = self._calculate_average(self.infer_time_list) +# iters_per_sec = self.total_iters / total_time if total_time > 0 else 0 + +# logger.info("===== Training Summary =====") +# logger.info(f"Total training time: {total_time:.2f} s") +# logger.info(f"Total iterations: {self.total_iters}") +# logger.info(f"Average iterations per second: {iters_per_sec:.2f}") +# logger.info(f"Average GPU Memory Usage: {avg_gpu_mem:.2f} MB") +# logger.info(f"Average CPU Memory Usage: {avg_cpu_mem:.2f} MB") +# if avg_train_time: +# logger.info(f"Average training step time: {avg_train_time * 1000:.2f} ms") +# if avg_infer_time: +# logger.info(f"Average inference step time: {avg_infer_time * 1000:.2f} ms") + + +class Trainer: + """模型训练器,负责整个训练流程的管理""" + + def __init__( + self, + model, + loss, + optimizer, + train_loader, + val_loader, + test_loader, + scaler, + args, + lr_scheduler=None, + ): + # 设备和基本参数 + self.device = args["basic"]["device"] + self.config = args # 保存完整的配置参数 + train_args = args["train"] + + # 模型和训练相关组件 + self.model = model + self.loss = loss + self.optimizer = optimizer + self.lr_scheduler = lr_scheduler + + # 数据加载器 + self.train_loader = train_loader + self.val_loader = val_loader + self.test_loader = test_loader + + # 数据处理工具 + self.scaler = scaler + self.args = train_args + + # 统计信息 + # self.train_per_epoch = len(train_loader) + # self.val_per_epoch = len(val_loader) if val_loader else 0 + + # 初始化路径、日志和统计 + self._initialize_paths(train_args) + self._initialize_logger(train_args) + self._initialize_stats() + + def _initialize_paths(self, args): + """初始化模型保存路径""" + self.best_path = os.path.join(args["log_dir"], "best_model.pth") + self.best_test_path = os.path.join(args["log_dir"], "best_test_model.pth") + self.loss_figure_path = os.path.join(args["log_dir"], "loss.png") + + def _initialize_logger(self, args): + """初始化日志记录器""" + if not os.path.isdir(args["log_dir"]) and not args["debug"]: + os.makedirs(args["log_dir"], exist_ok=True) + self.logger = get_logger( + args["log_dir"], name=self.model.__class__.__name__, debug=args["debug"] + ) + self.logger.info(f"Experiment log path in: {args['log_dir']}") + + # def _initialize_stats(self): + # """初始化统计信息记录器""" + # self.stats = TrainingStats(device=self.device) + + def _run_epoch(self, epoch, dataloader, mode): + """运行一个训练/验证/测试epoch""" + # 设置模型模式和是否进行优化 + if mode == "train": + self.model.train() + optimizer_step = True + else: + self.model.eval() + optimizer_step = False + + # 初始化变量 + total_loss = 0 + epoch_time = time.time() + y_pred, y_true = [], [] + + # 训练/验证循环 + with torch.set_grad_enabled(optimizer_step): + progress_bar = tqdm( + enumerate(dataloader), + total=len(dataloader), + desc=f"{mode.capitalize()} Epoch {epoch}" + ) + + for _, (data, target) in progress_bar: + # 记录步骤开始时间 + start_time = time.time() + + # 将数据和标签移动到指定设备 + data = data.to(self.device) + target = target.to(self.device) + + # 前向传播 + label = target[..., : self.args["output_dim"]] + output = self.model(data) + # if output.shape != label.shape: + # import sys + # print(f"[Wrong]: Output shape: {output.shape}, Label shape: {label.shape}") + # sys.exit(1) + # else: + # import sys + # print(f"[Right]: Output shape: {output.shape}, Label shape: {label.shape}") + # sys.exit(0) + loss = self.loss(output, label) + + # 反归一化 + d_output = self.scaler.inverse_transform(output) + d_label = self.scaler.inverse_transform(label) + + # 反向传播和优化(仅在训练模式) + if optimizer_step and self.optimizer is not None: + self.optimizer.zero_grad() + loss.backward() + + # 梯度裁剪(如果需要) + if self.args["grad_norm"]: + torch.nn.utils.clip_grad_norm_( + self.model.parameters(), self.args["max_grad_norm"] + ) + self.optimizer.step() + + # 反归一化的loss + d_loss = self.loss(d_output, d_label) + + # 记录步骤时间和内存使用 + # step_time = time.time() - start_time + # self.stats.record_step_time(step_time, mode) + + # 累积损失和预测结果 + total_loss += d_loss.item() + y_pred.append(d_output.detach().cpu()) + y_true.append(d_label.detach().cpu()) + + # 更新进度条 + progress_bar.set_postfix(loss=d_loss.item()) + + # 合并所有批次的预测结果 + y_pred = torch.cat(y_pred, dim=0) + y_true = torch.cat(y_true, dim=0) + + # 计算平均损失 + avg_loss = total_loss / len(dataloader) + + # 计算并记录指标 + mae, rmse, mape = all_metrics( + y_pred, y_true, self.args["mae_thresh"], self.args["mape_thresh"] + ) + self.logger.info( + f"Epoch #{epoch:02d}: {mode.capitalize():<5} MAE:{mae:5.2f} | RMSE:{rmse:5.2f} | MAPE:{mape:7.4f} | Time: {time.time() - epoch_time:.2f} s" + ) + + # 记录内存使用情况 + # self.stats.record_memory_usage() + + return avg_loss + + def train_epoch(self, epoch): + return self._run_epoch(epoch, self.train_loader, "train") + + def val_epoch(self, epoch): + return self._run_epoch(epoch, self.val_loader or self.test_loader, "val") + + def test_epoch(self, epoch): + return self._run_epoch(epoch, self.test_loader, "test") + + def train(self): + """执行完整的训练流程""" + # 初始化最佳模型和损失记录 + best_model, best_test_model = None, None + best_loss, best_test_loss = float("inf"), float("inf") + not_improved_count = 0 + + # 开始训练 + # self.stats.start_training() + self.logger.info("Training process started") + + # 训练循环 + for epoch in range(1, self.args["epochs"] + 1): + # 训练、验证和测试一个epoch + train_epoch_loss = self.train_epoch(epoch) + val_epoch_loss = self.val_epoch(epoch) + test_epoch_loss = self.test_epoch(epoch) + + # 检查梯度爆炸 + if train_epoch_loss > 1e6: + self.logger.warning("Gradient explosion detected. Ending...") + break + + # 更新最佳验证模型 + if val_epoch_loss < best_loss: + best_loss = val_epoch_loss + not_improved_count = 0 + best_model = copy.deepcopy(self.model.state_dict()) + self.logger.info("Best validation model saved!") + else: + not_improved_count += 1 + + # 检查早停条件 + if self._should_early_stop(not_improved_count): + break + + # 更新最佳测试模型 + if test_epoch_loss < best_test_loss: + best_test_loss = test_epoch_loss + best_test_model = copy.deepcopy(self.model.state_dict()) + + # 保存最佳模型 + if not self.args["debug"]: + self._save_best_models(best_model, best_test_model) + + # 结束训练并输出统计信息 + # self.stats.end_training() + # self.stats.report(self.logger) + + # 最终评估 + self._finalize_training(best_model, best_test_model) + + # 输出模型参数量 + self._log_model_params() + + def _should_early_stop(self, not_improved_count): + """检查是否满足早停条件""" + if ( + self.args["early_stop"] + and not_improved_count == self.args["early_stop_patience"] + ): + self.logger.info( + f"Validation performance didn't improve for {self.args['early_stop_patience']} epochs. Training stops." + ) + return True + return False + + def _save_best_models(self, best_model, best_test_model): + """保存最佳模型到文件""" + torch.save(best_model, self.best_path) + torch.save(best_test_model, self.best_test_path) + self.logger.info( + f"Best models saved at {self.best_path} and {self.best_test_path}" + ) + + def _log_model_params(self): + """输出模型可训练参数数量""" + total_params = sum( p.numel() for p in self.model.parameters() if p.requires_grad) + self.logger.info(f"Trainable params: {total_params}") + + + def _finalize_training(self, best_model, best_test_model): + self.model.load_state_dict(best_model) + self.logger.info("Testing on best validation model") + self.test(self.model, self.config, self.test_loader, self.scaler, self.logger) + + self.model.load_state_dict(best_test_model) + self.logger.info("Testing on best test model") + self.test(self.model, self.config, self.test_loader, self.scaler, self.logger) + + @staticmethod + def test(model, args, data_loader, scaler, logger, path=None): + """对模型进行评估并输出性能指标""" + # 确定设备信息 + device = None + output_dim = None + + # 处理不同的参数格式 + if isinstance(args, dict): + if "basic" in args: + # 完整配置情况 + device = args["basic"]["device"] + output_dim = args["train"]["output_dim"] + else: + # 只有train_args情况 + # 从模型获取设备 + device = next(model.parameters()).device + output_dim = args["output_dim"] + else: + raise ValueError(f"Unsupported args type: {type(args)}") + + # 加载模型检查点(如果提供了路径) + if path: + checkpoint = torch.load(path) + model.load_state_dict(checkpoint["state_dict"]) + model.to(device) + + # 设置为评估模式 + model.eval() + + # 收集预测和真实标签 + y_pred, y_true = [], [] + + # 不计算梯度的情况下进行预测 + with torch.no_grad(): + for data, target in data_loader: + # 将数据和标签移动到指定设备 + data = data.to(device) + target = target.to(device) + + label = target[..., : output_dim] + output = model(data) + y_pred.append(output.detach().cpu()) + y_true.append(label.detach().cpu()) + + + d_y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0)) + d_y_true = scaler.inverse_transform(torch.cat(y_true, dim=0)) + + # 获取metrics参数 + if "basic" in args: + # 完整配置情况 + mae_thresh = args["train"]["mae_thresh"] + mape_thresh = args["train"]["mape_thresh"] + else: + # 只有train_args情况 + mae_thresh = args["mae_thresh"] + mape_thresh = args["mape_thresh"] + + # 计算并记录每个时间步的指标 + for t in range(d_y_true.shape[1]): + mae, rmse, mape = all_metrics( + d_y_pred[:, t, ...], + d_y_true[:, t, ...], + mae_thresh, + mape_thresh, + ) + logger.info(f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}") + + # 计算并记录平均指标 + mae, rmse, mape = all_metrics(d_y_pred, d_y_true, mae_thresh, mape_thresh) + logger.info( f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}") + + @staticmethod + def _compute_sampling_threshold(global_step, k): + return k / (k + math.exp(global_step / k)) diff --git a/trainer/Trainer_old.py b/trainer/Trainer_old.py deleted file mode 100755 index bd49b29..0000000 --- a/trainer/Trainer_old.py +++ /dev/null @@ -1,229 +0,0 @@ -import math -import os -import time -import copy -from tqdm import tqdm - -import torch -from utils.logger import get_logger -from utils.loss_function import all_metrics -from utils.training_stats import TrainingStats - - -class Trainer: - def __init__( - self, - model, - loss, - optimizer, - train_loader, - val_loader, - test_loader, - scaler, - args, - lr_scheduler=None, - ): - self.model = model - self.loss = loss - self.optimizer = optimizer - self.train_loader = train_loader - self.val_loader = val_loader - self.test_loader = test_loader - self.scaler = scaler - self.args = args - self.lr_scheduler = lr_scheduler - self.train_per_epoch = len(train_loader) - self.val_per_epoch = len(val_loader) if val_loader else 0 - - # Paths for saving models and logs - self.best_path = os.path.join(args["log_dir"], "best_model.pth") - self.best_test_path = os.path.join(args["log_dir"], "best_test_model.pth") - self.loss_figure_path = os.path.join(args["log_dir"], "loss.png") - - # Initialize logger - if not os.path.isdir(args["log_dir"]) and not args["debug"]: - os.makedirs(args["log_dir"], exist_ok=True) - self.logger = get_logger( - args["log_dir"], name=self.model.__class__.__name__, debug=args["debug"] - ) - self.logger.info(f"Experiment log path in: {args['log_dir']}") - # Stats tracker - self.stats = TrainingStats(device=args["device"]) - - def _run_epoch(self, epoch, dataloader, mode): - if mode == "train": - self.model.train() - optimizer_step = True - else: - self.model.eval() - optimizer_step = False - - total_loss = 0 - epoch_time = time.time() - - with torch.set_grad_enabled(optimizer_step): - with tqdm( - total=len(dataloader), desc=f"{mode.capitalize()} Epoch {epoch}" - ) as pbar: - for batch_idx, (data, target) in enumerate(dataloader): - start_time = time.time() - label = target[..., : self.args["output_dim"]] - output = self.model(data).to(self.args["device"]) - - if self.args["real_value"]: - output = self.scaler.inverse_transform(output) - - loss = self.loss(output, label) - if optimizer_step and self.optimizer is not None: - self.optimizer.zero_grad() - loss.backward() - - if self.args["grad_norm"]: - torch.nn.utils.clip_grad_norm_( - self.model.parameters(), self.args["max_grad_norm"] - ) - self.optimizer.step() - - step_time = time.time() - start_time - self.stats.record_step_time(step_time, mode) - total_loss += loss.item() - - if mode == "train" and (batch_idx + 1) % self.args["log_step"] == 0: - self.logger.info( - f"Train Epoch {epoch}: {batch_idx + 1}/{len(dataloader)} Loss: {loss.item():.6f}" - ) - - # 更新 tqdm 的进度 - pbar.update(1) - pbar.set_postfix(loss=loss.item()) - - avg_loss = total_loss / len(dataloader) - self.logger.info( - f"{mode.capitalize()} Epoch {epoch}: average Loss: {avg_loss:.6f}, time: {time.time() - epoch_time:.2f} s" - ) - # 记录内存 - self.stats.record_memory_usage() - return avg_loss - - def train_epoch(self, epoch): - return self._run_epoch(epoch, self.train_loader, "train") - - def val_epoch(self, epoch): - return self._run_epoch(epoch, self.val_loader or self.test_loader, "val") - - def test_epoch(self, epoch): - return self._run_epoch(epoch, self.test_loader, "test") - - def train(self): - best_model, best_test_model = None, None - best_loss, best_test_loss = float("inf"), float("inf") - not_improved_count = 0 - - self.stats.start_training() - self.logger.info("Training process started") - for epoch in range(1, self.args["epochs"] + 1): - train_epoch_loss = self.train_epoch(epoch) - val_epoch_loss = self.val_epoch(epoch) - test_epoch_loss = self.test_epoch(epoch) - - if train_epoch_loss > 1e6: - self.logger.warning("Gradient explosion detected. Ending...") - break - - if val_epoch_loss < best_loss: - best_loss = val_epoch_loss - not_improved_count = 0 - best_model = copy.deepcopy(self.model.state_dict()) - self.logger.info("Best validation model saved!") - else: - not_improved_count += 1 - - if ( - self.args["early_stop"] - and not_improved_count == self.args["early_stop_patience"] - ): - self.logger.info( - f"Validation performance didn't improve for {self.args['early_stop_patience']} epochs. Training stops." - ) - break - - if test_epoch_loss < best_test_loss: - best_test_loss = test_epoch_loss - best_test_model = copy.deepcopy(self.model.state_dict()) - - if not self.args["debug"]: - torch.save(best_model, self.best_path) - torch.save(best_test_model, self.best_test_path) - self.logger.info( - f"Best models saved at {self.best_path} and {self.best_test_path}" - ) - - # 输出统计与参数 - self.stats.end_training() - self.stats.report(self.logger) - try: - total_params = sum( - p.numel() for p in self.model.parameters() if p.requires_grad - ) - self.logger.info(f"Trainable params: {total_params}") - except Exception: - pass - self._finalize_training(best_model, best_test_model) - - def _finalize_training(self, best_model, best_test_model): - self.model.load_state_dict(best_model) - self.logger.info("Testing on best validation model") - self.test(self.model, self.args, self.test_loader, self.scaler, self.logger) - - self.model.load_state_dict(best_test_model) - self.logger.info("Testing on best test model") - self.test(self.model, self.args, self.test_loader, self.scaler, self.logger) - - @staticmethod - def test(model, args, data_loader, scaler, logger, path=None): - if path: - checkpoint = torch.load(path) - model.load_state_dict(checkpoint["state_dict"]) - model.to(args["device"]) - - model.eval() - y_pred, y_true = [], [] - - with torch.no_grad(): - for data, target in data_loader: - label = target[..., : args["output_dim"]] - output = model(data) - y_pred.append(output) - y_true.append(label) - - if args["real_value"]: - y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0)) - else: - y_pred = torch.cat(y_pred, dim=0) - y_true = torch.cat(y_true, dim=0) - - # 你在这里需要把y_pred和y_true保存下来 - # torch.save(y_pred, "./test/PEMS07/y_pred_D.pt") # [3566,12,170,1] - # torch.save(y_true, "./test/PEMS08/y_true.pt") # [3566,12,170,1] - - for t in range(y_true.shape[1]): - mae, rmse, mape = all_metrics( - y_pred[:, t, ...], - y_true[:, t, ...], - args["mae_thresh"], - args["mape_thresh"], - ) - logger.info( - f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}" - ) - - mae, rmse, mape = all_metrics( - y_pred, y_true, args["mae_thresh"], args["mape_thresh"] - ) - logger.info( - f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}" - ) - - @staticmethod - def _compute_sampling_threshold(global_step, k): - return k / (k + math.exp(global_step / k)) diff --git a/utils/initializer.py b/utils/initializer.py index 7bee2be..183bfd3 100755 --- a/utils/initializer.py +++ b/utils/initializer.py @@ -9,9 +9,9 @@ import os import yaml -def init_model(args): - device = args["device"] - model = model_selector(args).to(device) +def init_model(config): + device = config["basic"]["device"] + model = model_selector(config).to(device) for p in model.parameters(): if p.dim() > 1: nn.init.xavier_uniform_(p)