TrafficWheel/config/args_parser.py

54 lines
2.0 KiB
Python
Executable File

import argparse
import yaml
# Values in the YAML ``basic`` section override nested settings elsewhere in
# the config. Each entry is (section, key, basic_key): when
# ``config[section][key]`` already exists (or ``config[key]`` when section is
# None), it is replaced by ``config["basic"][basic_key]`` if that is present.
# Keys that do not already exist in the target are never created.
_BASIC_OVERRIDES = (
    ("data", "type", "dataset"),
    ("model", "type", "model"),
    ("model", "rnn_units", "rnn"),
    ("model", "embed_dim", "emb"),
    ("data", "sample", "sample"),
    ("train", "device", "device"),
    ("train", "debug", "debug"),
    (None, "cuda", "cuda"),
    (None, "mode", "mode"),
)


def _apply_basic_overrides(config):
    """Apply ``config["basic"]`` shortcuts onto the nested config in place.

    A missing or empty ``basic`` section simply means "no overrides". This
    avoids the ``KeyError`` raised when indexing ``config["basic"]`` directly.
    Sections that are absent or not mappings (e.g. an empty ``data:`` entry,
    which YAML loads as ``None``) are skipped.

    Args:
        config: Parsed configuration mapping; mutated in place.

    Returns:
        The same ``config`` object, for convenience.
    """
    basic = config.get("basic") or {}
    for section_name, key, basic_key in _BASIC_OVERRIDES:
        target = config if section_name is None else config.get(section_name)
        # Only override settings the config already declares.
        if isinstance(target, dict) and key in target:
            target[key] = basic.get(basic_key, target[key])
    return config


def parse_args(argv=None):
    """Parse ``--config`` from the command line and load the YAML config.

    Values under the config's ``basic`` section override matching nested
    entries (dataset/model type, rnn units, embedding dim, sample, device,
    debug, cuda, mode) when those entries are already present.

    Args:
        argv: Argument list to parse; ``None`` (default) uses ``sys.argv[1:]``.

    Returns:
        The configuration as a dict, with ``basic`` overrides applied.

    Raises:
        ValueError: If the config path is empty, or the file does not contain
            a top-level YAML mapping (e.g. it is empty).
    """
    parser = argparse.ArgumentParser(description="Model Training and Testing")
    parser.add_argument(
        "--config", type=str, required=True, help="Path to the configuration file"
    )
    args = parser.parse_args(argv)

    # argparse enforces presence, but `--config ""` still gets through.
    if not args.config:
        raise ValueError("Configuration file path must be provided using --config")

    # Load YAML configuration (YAML is UTF-8 by spec; do not rely on locale).
    with open(args.config, "r", encoding="utf-8") as file:
        config = yaml.safe_load(file)

    # safe_load returns None for an empty file and a scalar/list for
    # non-mapping documents; fail clearly instead of with a TypeError later.
    if not isinstance(config, dict):
        raise ValueError(
            f"Configuration file {args.config!r} must contain a YAML mapping "
            "at the top level"
        )

    return _apply_basic_overrides(config)