#!/usr/bin/env python3
import os
from collections import defaultdict

from ruamel.yaml import YAML
from ruamel.yaml.comments import CommentedMap

yaml = YAML()
yaml.preserve_quotes = True
yaml.indent(mapping=2, sequence=4, offset=2)

# Allowed keys under the `data` section
DATA_ALLOWED_KEYS = {
    "lag",
    "horizon",
    "num_nodes",
    "steps_per_day",
    "days_per_week",
    "test_ratio",
    "val_ratio",
    "batch_size",
    "input_dim",
    "column_wise",
    "normalizer",
}

# Global default values for `data` keys
GLOBAL_DEFAULTS = {
    "lag": 24,
    "horizon": 24,
    "num_nodes": 1,
    "steps_per_day": 24,
    "days_per_week": 7,
    "test_ratio": 0.2,
    "val_ratio": 0.2,
    "batch_size": 16,
    "input_dim": 1,
    "column_wise": False,
    "normalizer": "std",
}

# Global default values for `train` keys
GLOBAL_TRAIN_DEFAULTS = {
    "output_dim": 1,
}


def load_yaml(path):
    try:
        with open(path, "r", encoding="utf-8") as f:
            return yaml.load(f)
    except Exception:
        return None


def collect_dataset_defaults(base="."):
    """Collect per-dataset default values for `data` keys and for `train.output_dim`."""
    data_defaults = defaultdict(dict)
    train_output_defaults = dict()
    for root, _, files in os.walk(base):
        for name in files:
            if not (name.endswith(".yaml") or name.endswith(".yml")):
                continue
            path = os.path.join(root, name)
            cm = load_yaml(path)
            if not isinstance(cm, CommentedMap):
                continue
            basic = cm.get("basic")
            if not isinstance(basic, dict):
                continue
            dataset = basic.get("dataset")
            if dataset is None:
                continue
            ds = str(dataset)

            # Defaults for `data` keys: first non-None value seen wins
            data_sec = cm.get("data")
            if isinstance(data_sec, dict):
                for key in DATA_ALLOWED_KEYS:
                    if key not in data_defaults[ds] and key in data_sec and data_sec[key] is not None:
                        data_defaults[ds][key] = data_sec[key]

            # Default for `train.output_dim`
            train_sec = cm.get("train")
            if isinstance(train_sec, dict):
                val = train_sec.get("output_dim")
                if val is not None and ds not in train_output_defaults:
                    train_output_defaults[ds] = val
    return data_defaults, train_output_defaults


def ensure_basic_seed(cm: CommentedMap, path: str):
    if "basic" not in cm or not isinstance(cm["basic"], dict):
        cm["basic"] = CommentedMap()
    basic = cm["basic"]
    if "seed" not in basic:
        basic["seed"] = 2023
        print(f"[ADD] {path}: basic.seed = 2023")


def fill_data_defaults(cm: CommentedMap, data_defaults: dict, path: str):
    if "data" not in cm or not isinstance(cm["data"], dict):
        cm["data"] = CommentedMap()
    data_sec = cm["data"]
    basic = cm.get("basic", {})
    dataset = str(basic.get("dataset")) if basic and "dataset" in basic else None
    for key in sorted(DATA_ALLOWED_KEYS):
        if key in data_sec and data_sec[key] is not None:
            continue
        # Prefer the per-dataset default, fall back to the global default
        if dataset and dataset in data_defaults and key in data_defaults[dataset]:
            chosen = data_defaults[dataset][key]
            src = f"default_from_dataset[{dataset}]"
        else:
            chosen = GLOBAL_DEFAULTS[key]
            src = "GLOBAL_DEFAULTS"
        data_sec[key] = chosen
        print(f"[FILL] {path}: data.{key} <- {src} ({repr(chosen)})")


def merge_test_log_into_train(cm: CommentedMap, path: str):
    """Merge the keys of `test` and `log` into `train`, then delete `test` and `log`.
    Also make sure `train.debug` exists."""
    train_sec = cm.setdefault("train", CommentedMap())
    for section in ["test", "log"]:
        if section in cm and isinstance(cm[section], dict):
            for k, v in cm[section].items():
                if k not in train_sec:
                    train_sec[k] = v
                    print(f"[MERGE] {path}: train.{k} <- {section}.{k} ({repr(v)})")
            del cm[section]
            print(f"[DEL] {path}: deleted section '{section}'")

    # train.debug
    if "debug" not in train_sec:
        train_sec["debug"] = False
        print(f"[ADD] {path}: train.debug = False")


def fill_train_output_dim(cm: CommentedMap, train_output_defaults: dict, path: str):
    train_sec = cm.setdefault("train", CommentedMap())
    if "output_dim" not in train_sec or train_sec["output_dim"] is None:
        basic = cm.get("basic", {})
        dataset = str(basic.get("dataset")) if basic and "dataset" in basic else None
        if dataset and dataset in train_output_defaults:
            val = train_output_defaults[dataset]
            src = f"default_from_dataset[{dataset}]"
        else:
            val = GLOBAL_TRAIN_DEFAULTS["output_dim"]
            src = "GLOBAL_TRAIN_DEFAULTS"
        train_sec["output_dim"] = val
        print(f"[FILL] {path}: train.output_dim <- {src} ({val})")


def sync_train_batch_size(cm: CommentedMap, path: str):
    """If `train.batch_size` disagrees with `data.batch_size`, `data` wins."""
    data_sec = cm.get("data", {})
    train_sec = cm.get("train", {})
    data_bs = data_sec.get("batch_size")
    train_bs = train_sec.get("batch_size")
    if data_bs is not None and train_bs != data_bs:
        train_sec["batch_size"] = data_bs
        print(f"[SYNC] {path}: train.batch_size corrected to match data.batch_size ({data_bs})")


def sort_subkeys_and_insert_blanklines(cm: CommentedMap):
    # Sort the keys inside every top-level section
    for sec in list(cm.keys()):
        if isinstance(cm[sec], dict):
            sorted_cm = CommentedMap()
            for k in sorted(cm[sec].keys()):
                sorted_cm[k] = cm[sec][k]
            cm[sec] = sorted_cm
    # Insert a blank line before every top-level section except the first
    keys = list(cm.keys())
    for i, k in enumerate(keys):
        if i == 0:
            try:
                cm.yaml_set_comment_before_after_key(k, before=None)
            except Exception:
                pass
        else:
            try:
                cm.yaml_set_comment_before_after_key(k, before="\n")
            except Exception:
                pass


def process_all(base="."):
    print(">> Collecting dataset defaults ...")
    data_defaults, train_output_defaults = collect_dataset_defaults(base)
    print(">> Collected data defaults per dataset:")
    for ds, kv in data_defaults.items():
        print(f" - {ds}: {kv}")
    print(">> Collected train.output_dim defaults per dataset:")
    for ds, val in train_output_defaults.items():
        print(f" - {ds}: output_dim = {val}")

    for root, _, files in os.walk(base):
        for name in files:
            if not (name.endswith(".yaml") or name.endswith(".yml")):
                continue
            path = os.path.join(root, name)
            cm = load_yaml(path)
            if not isinstance(cm, CommentedMap):
                print(f"[SKIP] {path}: top-level not mapping or load failed")
                continue

            ensure_basic_seed(cm, path)
            fill_data_defaults(cm, data_defaults, path)
            merge_test_log_into_train(cm, path)
            fill_train_output_dim(cm, train_output_defaults, path)
            sync_train_batch_size(cm, path)  # <-- newly added step
            sort_subkeys_and_insert_blanklines(cm)

            try:
                with open(path, "w", encoding="utf-8") as f:
                    yaml.dump(cm, f)
                print(f"[OK] Written: {path}")
            except Exception as e:
                print(f"[ERROR] Write failed {path}: {e}")


if __name__ == "__main__":
    process_all(".")