#!/usr/bin/env python3
import os
from collections import defaultdict
from ruamel.yaml import YAML
from ruamel.yaml.comments import CommentedMap
yaml = YAML()
yaml.preserve_quotes = True
yaml.indent(mapping=2, sequence=4, offset=2)
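
# Illustrative only: with mapping=2, sequence=4, offset=2, dumped YAML is
# shaped like the hypothetical fragment below -- nested mapping keys indent
# by two spaces, sequence dashes sit two spaces in, items align at four
# (the "lr_milestones" key is made up for the example):
#
#   data:
#     batch_size: 16
#   lr_milestones:
#     - 10
#     - 20
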
# Keys allowed in the "data" section
DATA_ALLOWED_KEYS = {
    "lag",
    "horizon",
    "num_nodes",
    "steps_per_day",
    "days_per_week",
    "test_ratio",
    "val_ratio",
    "batch_size",
    "input_dim",
    "column_wise",
    "normalizer",
}
# Global fallback defaults for the "data" section
GLOBAL_DEFAULTS = {
    "lag": 24,
    "horizon": 24,
    "num_nodes": 1,
    "steps_per_day": 24,
    "days_per_week": 7,
    "test_ratio": 0.2,
    "val_ratio": 0.2,
    "batch_size": 16,
    "input_dim": 1,
    "column_wise": False,
    "normalizer": "std",
}
# Global fallback defaults for the "train" section
GLOBAL_TRAIN_DEFAULTS = {
    "output_dim": 1,
}

def load_yaml(path):
    """Load a YAML file, returning None if it is missing or unparsable."""
    try:
        with open(path, "r", encoding="utf-8") as f:
            return yaml.load(f)
    except Exception:
        return None

def collect_dataset_defaults(base="."):
    """
    Collect per-dataset defaults for the keys of the "data" section, plus the
    per-dataset default for train.output_dim.
    """
    data_defaults = defaultdict(dict)
    train_output_defaults = dict()
    for root, _, files in os.walk(base):
        for name in files:
            if not (name.endswith(".yaml") or name.endswith(".yml")):
                continue
            path = os.path.join(root, name)
            cm = load_yaml(path)
            if not isinstance(cm, CommentedMap):
                continue
            basic = cm.get("basic")
            if not isinstance(basic, dict):
                continue
            dataset = basic.get("dataset")
            if dataset is None:
                continue
            ds = str(dataset)
            # per-dataset "data" defaults: first non-None value seen wins
            data_sec = cm.get("data")
            if isinstance(data_sec, dict):
                for key in DATA_ALLOWED_KEYS:
                    if key not in data_defaults[ds] and key in data_sec and data_sec[key] is not None:
                        data_defaults[ds][key] = data_sec[key]
            # per-dataset train.output_dim default
            train_sec = cm.get("train")
            if isinstance(train_sec, dict):
                val = train_sec.get("output_dim")
                if val is not None and ds not in train_output_defaults:
                    train_output_defaults[ds] = val
    return data_defaults, train_output_defaults
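
# Sketch of the config shape this scan assumes (all field values below are
# hypothetical; only the section and key names come from this script):
#
#   basic:
#     dataset: PEMS04
#     seed: 2023
#   data:
#     lag: 12
#     batch_size: 64
#   train:
#     output_dim: 1
#
# The first non-None value seen for a dataset becomes its default; scan order
# follows os.walk, so it is filesystem-dependent.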

def ensure_basic_seed(cm: CommentedMap, path: str):
    if "basic" not in cm or not isinstance(cm["basic"], dict):
        cm["basic"] = CommentedMap()
    basic = cm["basic"]
    if "seed" not in basic:
        basic["seed"] = 2023
        print(f"[ADD] {path}: basic.seed = 2023")

def fill_data_defaults(cm: CommentedMap, data_defaults: dict, path: str):
    """Fill missing "data" keys from per-dataset defaults, then global ones."""
    if "data" not in cm or not isinstance(cm["data"], dict):
        cm["data"] = CommentedMap()
    data_sec = cm["data"]
    basic = cm.get("basic", {})
    dataset = str(basic.get("dataset")) if basic and "dataset" in basic else None
    for key in sorted(DATA_ALLOWED_KEYS):
        if key in data_sec and data_sec[key] is not None:
            continue
        if dataset and dataset in data_defaults and key in data_defaults[dataset]:
            chosen = data_defaults[dataset][key]
            src = f"default_from_dataset[{dataset}]"
        else:
            chosen = GLOBAL_DEFAULTS[key]
            src = "GLOBAL_DEFAULTS"
        data_sec[key] = chosen
        print(f"[FILL] {path}: data.{key} <- {src} ({repr(chosen)})")

def merge_test_log_into_train(cm: CommentedMap, path: str):
    """
    Merge the keys of the "test" and "log" sections into "train", then delete
    "test" and "log". Also ensure train.debug exists.
    """
    train_sec = cm.setdefault("train", CommentedMap())
    for section in ["test", "log"]:
        if section in cm and isinstance(cm[section], dict):
            for k, v in cm[section].items():
                if k not in train_sec:
                    train_sec[k] = v
                    print(f"[MERGE] {path}: train.{k} <- {section}.{k} ({repr(v)})")
            del cm[section]
            print(f"[DEL] {path}: deleted section '{section}'")
    # ensure train.debug exists
    if "debug" not in train_sec:
        train_sec["debug"] = False
        print(f"[ADD] {path}: train.debug = False")

def fill_train_output_dim(cm: CommentedMap, train_output_defaults: dict, path: str):
    """Fill train.output_dim from the per-dataset default, else the global one."""
    train_sec = cm.setdefault("train", CommentedMap())
    if "output_dim" not in train_sec or train_sec["output_dim"] is None:
        basic = cm.get("basic", {})
        dataset = str(basic.get("dataset")) if basic and "dataset" in basic else None
        if dataset and dataset in train_output_defaults:
            val = train_output_defaults[dataset]
            src = f"default_from_dataset[{dataset}]"
        else:
            val = GLOBAL_TRAIN_DEFAULTS["output_dim"]
            src = "GLOBAL_TRAIN_DEFAULTS"
        train_sec["output_dim"] = val
        print(f"[FILL] {path}: train.output_dim <- {src} ({val})")

def sync_train_batch_size(cm: CommentedMap, path: str):
    """
    If train.batch_size disagrees with data.batch_size, the value in "data"
    wins.
    """
    data_sec = cm.get("data", {})
    # setdefault so the assignment below cannot land in a throwaway dict
    train_sec = cm.setdefault("train", CommentedMap())
    data_bs = data_sec.get("batch_size")
    train_bs = train_sec.get("batch_size")
    if data_bs is not None and train_bs != data_bs:
        train_sec["batch_size"] = data_bs
        print(f"[SYNC] {path}: train.batch_size corrected to match data.batch_size ({data_bs})")

def sort_subkeys_and_insert_blanklines(cm: CommentedMap):
    """Alphabetize subkeys within each section; blank-line-separate sections."""
    for sec in list(cm.keys()):
        if isinstance(cm[sec], dict):
            # rebuild the section with its keys sorted alphabetically
            sorted_cm = CommentedMap()
            for k in sorted(cm[sec].keys()):
                sorted_cm[k] = cm[sec][k]
            cm[sec] = sorted_cm
    keys = list(cm.keys())
    for i, k in enumerate(keys):
        # blank line before every top-level section except the first
        if i == 0:
            try:
                cm.yaml_set_comment_before_after_key(k, before=None)
            except Exception:
                pass
        else:
            try:
                cm.yaml_set_comment_before_after_key(k, before="\n")
            except Exception:
                pass
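
# Illustrative effect of this pass (hypothetical keys): subkeys within each
# top-level section are alphabetized, and a blank line is inserted before
# every top-level section except the first, e.g.
#
#   basic:
#     dataset: PEMS04
#     seed: 2023
#
#   data:
#     batch_size: 16
#     lag: 24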

def process_all(base="."):
    print(">> Collecting dataset defaults ...")
    data_defaults, train_output_defaults = collect_dataset_defaults(base)
    print(">> Collected data defaults per dataset:")
    for ds, kv in data_defaults.items():
        print(f"  - {ds}: {kv}")
    print(">> Collected train.output_dim defaults per dataset:")
    for ds, val in train_output_defaults.items():
        print(f"  - {ds}: output_dim = {val}")
    for root, _, files in os.walk(base):
        for name in files:
            if not (name.endswith(".yaml") or name.endswith(".yml")):
                continue
            path = os.path.join(root, name)
            cm = load_yaml(path)
            if not isinstance(cm, CommentedMap):
                print(f"[SKIP] {path}: top-level not a mapping, or load failed")
                continue
            ensure_basic_seed(cm, path)
            fill_data_defaults(cm, data_defaults, path)
            merge_test_log_into_train(cm, path)
            fill_train_output_dim(cm, train_output_defaults, path)
            sync_train_batch_size(cm, path)  # newly added step
            sort_subkeys_and_insert_blanklines(cm)
            try:
                with open(path, "w", encoding="utf-8") as f:
                    yaml.dump(cm, f)
                print(f"[OK] Written: {path}")
            except Exception as e:
                print(f"[ERROR] Write failed for {path}: {e}")


if __name__ == "__main__":
    process_all(".")