#!/usr/bin/env python3
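"""
Normalize the YAML config files under a directory tree, in place:

  * ensure basic.seed exists (default 2023);
  * fill missing `data` keys from per-dataset defaults, falling back to
    GLOBAL_DEFAULTS;
  * merge the `test` and `log` sections into `train` and delete them;
  * fill train.output_dim and keep train.batch_size in sync with
    data.batch_size;
  * sort the keys of each section and separate sections with blank lines.
"""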
import os
from collections import defaultdict

from ruamel.yaml import YAML
from ruamel.yaml.comments import CommentedMap

# Round-trip YAML handler: preserves comments, quotes, and key order on dump.
yaml = YAML()
yaml.preserve_quotes = True
yaml.indent(mapping=2, sequence=4, offset=2)

# Keys permitted in a config's `data` section
DATA_ALLOWED_KEYS = {
    "lag",
    "horizon",
    "num_nodes",
    "steps_per_day",
    "days_per_week",
    "test_ratio",
    "val_ratio",
    "batch_size",
    "input_dim",
    "column_wise",
    "normalizer",
}

# Global fallback defaults for the `data` section
GLOBAL_DEFAULTS = {
    "lag": 24,
    "horizon": 24,
    "num_nodes": 1,
    "steps_per_day": 24,
    "days_per_week": 7,
    "test_ratio": 0.2,
    "val_ratio": 0.2,
    "batch_size": 16,
    "input_dim": 1,
    "column_wise": False,
    "normalizer": "std",
}

# Global fallback defaults for the `train` section
GLOBAL_TRAIN_DEFAULTS = {
    "output_dim": 1,
}
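
# A minimal sketch of a config after normalization (dataset name and values
# are illustrative):
#
#   basic:
#     dataset: PEMS04
#     seed: 2023
#
#   data:
#     batch_size: 16
#     column_wise: false
#     horizon: 24
#     ...
#
#   train:
#     batch_size: 16
#     debug: false
#     output_dim: 1
#     ...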

def load_yaml(path):
    """Load a YAML file as a round-trip object; return None on any failure."""
    try:
        with open(path, "r", encoding="utf-8") as f:
            return yaml.load(f)
    except Exception:
        return None

def collect_dataset_defaults(base="."):
    """
    Collect per-dataset defaults for the `data` keys, plus a per-dataset
    default for train.output_dim, from the first config that defines each.
    """
    data_defaults = defaultdict(dict)
    train_output_defaults = dict()

    for root, _, files in os.walk(base):
        for name in files:
            if not (name.endswith(".yaml") or name.endswith(".yml")):
                continue
            path = os.path.join(root, name)
            cm = load_yaml(path)
            if not isinstance(cm, CommentedMap):
                continue
            basic = cm.get("basic")
            if not isinstance(basic, dict):
                continue
            dataset = basic.get("dataset")
            if dataset is None:
                continue
            ds = str(dataset)

            # `data` defaults: the first non-None value seen per dataset wins
            data_sec = cm.get("data")
            if isinstance(data_sec, dict):
                for key in DATA_ALLOWED_KEYS:
                    if key not in data_defaults[ds] and key in data_sec and data_sec[key] is not None:
                        data_defaults[ds][key] = data_sec[key]

            # train.output_dim default
            train_sec = cm.get("train")
            if isinstance(train_sec, dict):
                val = train_sec.get("output_dim")
                if val is not None and ds not in train_output_defaults:
                    train_output_defaults[ds] = val

    return data_defaults, train_output_defaults
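
# Shape of the collected defaults (dataset name and values are illustrative):
#   data_defaults         -> {"PEMS04": {"lag": 12, "num_nodes": 307, ...}}
#   train_output_defaults -> {"PEMS04": 1}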

def ensure_basic_seed(cm: CommentedMap, path: str):
    """Make sure `basic` exists and carries a seed (default 2023)."""
    if "basic" not in cm or not isinstance(cm["basic"], dict):
        cm["basic"] = CommentedMap()
    basic = cm["basic"]
    if "seed" not in basic:
        basic["seed"] = 2023
        print(f"[ADD] {path}: basic.seed = 2023")

def fill_data_defaults(cm: CommentedMap, data_defaults: dict, path: str):
    """Fill missing `data` keys from per-dataset defaults, else GLOBAL_DEFAULTS."""
    if "data" not in cm or not isinstance(cm["data"], dict):
        cm["data"] = CommentedMap()
    data_sec = cm["data"]

    basic = cm.get("basic", {})
    dataset = str(basic.get("dataset")) if basic and "dataset" in basic else None

    for key in sorted(DATA_ALLOWED_KEYS):
        if key in data_sec and data_sec[key] is not None:
            continue
        if dataset and dataset in data_defaults and key in data_defaults[dataset]:
            chosen = data_defaults[dataset][key]
            src = f"default_from_dataset[{dataset}]"
        else:
            chosen = GLOBAL_DEFAULTS[key]
            src = "GLOBAL_DEFAULTS"
        data_sec[key] = chosen
        print(f"[FILL] {path}: data.{key} <- {src} ({repr(chosen)})")

def merge_test_log_into_train(cm: CommentedMap, path: str):
    """
    Merge the keys of the `test` and `log` sections into `train`, delete
    those sections, and make sure train.debug exists.
    """
    train_sec = cm.setdefault("train", CommentedMap())

    for section in ["test", "log"]:
        if section in cm and isinstance(cm[section], dict):
            for k, v in cm[section].items():
                if k not in train_sec:
                    train_sec[k] = v
                    print(f"[MERGE] {path}: train.{k} <- {section}.{k} ({repr(v)})")
            del cm[section]
            print(f"[DEL] {path}: deleted section '{section}'")

    # Ensure train.debug exists
    if "debug" not in train_sec:
        train_sec["debug"] = False
        print(f"[ADD] {path}: train.debug = False")

def fill_train_output_dim(cm: CommentedMap, train_output_defaults: dict, path: str):
    """Fill train.output_dim from the dataset default, else the global one."""
    train_sec = cm.setdefault("train", CommentedMap())
    if "output_dim" not in train_sec or train_sec["output_dim"] is None:
        basic = cm.get("basic", {})
        dataset = str(basic.get("dataset")) if basic and "dataset" in basic else None
        if dataset and dataset in train_output_defaults:
            val = train_output_defaults[dataset]
            src = f"default_from_dataset[{dataset}]"
        else:
            val = GLOBAL_TRAIN_DEFAULTS["output_dim"]
            src = "GLOBAL_TRAIN_DEFAULTS"
        train_sec["output_dim"] = val
        print(f"[FILL] {path}: train.output_dim <- {src} ({val})")

def sync_train_batch_size(cm: CommentedMap, path: str):
    """
    If train.batch_size disagrees with data.batch_size, data wins.

    NB: when called via process_all, `train` is guaranteed to exist because
    merge_test_log_into_train runs first, so the assignment below persists.
    """
    data_sec = cm.get("data", {})
    train_sec = cm.get("train", {})
    data_bs = data_sec.get("batch_size")
    train_bs = train_sec.get("batch_size")

    if data_bs is not None and train_bs != data_bs:
        train_sec["batch_size"] = data_bs
        print(f"[SYNC] {path}: train.batch_size corrected to match data.batch_size ({data_bs})")

def sort_subkeys_and_insert_blanklines(cm: CommentedMap):
    """Sort the keys inside every section, then put a blank line before each
    top-level section except the first."""
    for sec in list(cm.keys()):
        if isinstance(cm[sec], dict):
            sorted_cm = CommentedMap()
            for k in sorted(cm[sec].keys()):
                sorted_cm[k] = cm[sec][k]
            cm[sec] = sorted_cm

    for i, k in enumerate(list(cm.keys())):
        before = None if i == 0 else "\n"
        try:
            cm.yaml_set_comment_before_after_key(k, before=before)
        except Exception:
            pass

def process_all(base="."):
    print(">> Collecting dataset defaults ...")
    data_defaults, train_output_defaults = collect_dataset_defaults(base)
    print(">> Collected data defaults per dataset:")
    for ds, kv in data_defaults.items():
        print(f" - {ds}: {kv}")
    print(">> Collected train.output_dim defaults per dataset:")
    for ds, val in train_output_defaults.items():
        print(f" - {ds}: output_dim = {val}")

    for root, _, files in os.walk(base):
        for name in files:
            if not (name.endswith(".yaml") or name.endswith(".yml")):
                continue
            path = os.path.join(root, name)
            cm = load_yaml(path)
            if not isinstance(cm, CommentedMap):
                print(f"[SKIP] {path}: top-level not a mapping, or load failed")
                continue

            ensure_basic_seed(cm, path)
            fill_data_defaults(cm, data_defaults, path)
            merge_test_log_into_train(cm, path)
            fill_train_output_dim(cm, train_output_defaults, path)
            sync_train_batch_size(cm, path)  # newly added step
            sort_subkeys_and_insert_blanklines(cm)

            try:
                with open(path, "w", encoding="utf-8") as f:
                    yaml.dump(cm, f)
                print(f"[OK] Written: {path}")
            except Exception as e:
                print(f"[ERROR] Write failed {path}: {e}")

if __name__ == "__main__":
    process_all(".")
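
# Usage (script filename is illustrative): run from the root of the config
# tree; every .yaml/.yml file underneath is rewritten in place.
#
#   $ python normalize_configs.py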