54 lines
2.0 KiB
Python
Executable File
54 lines
2.0 KiB
Python
Executable File
import argparse
|
|
import yaml
|
|
|
|
|
|
def parse_args(argv: "list[str] | None" = None) -> dict:
    """Parse ``--config``, load that YAML file, and apply ``basic`` overrides.

    The YAML document may contain a top-level ``basic`` section whose entries
    override specific settings elsewhere in the configuration (see
    ``_BASIC_OVERRIDES`` for the exact mapping). An override is applied only
    when the targeted key already exists in the loaded configuration.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is the previous behaviour.

    Returns:
        The loaded configuration dictionary, with overrides applied in place.

    Raises:
        SystemExit: Raised by argparse when ``--config`` is missing.
        ValueError: If ``--config`` is given an empty path, the YAML document
            is not a mapping, or its ``basic`` section is not a mapping.
        OSError: If the configuration file cannot be opened.
    """
    parser = argparse.ArgumentParser(description="Model Training and Testing")
    parser.add_argument(
        "--config", type=str, required=True, help="Path to the configuration file"
    )
    args = parser.parse_args(argv)

    # argparse already enforces presence (required=True); this guards against
    # an explicitly empty value such as `--config ""`.
    if not args.config:
        raise ValueError("Configuration file path must be provided using --config")

    config = _load_config(args.config)
    _apply_basic_overrides(config)
    return config


# Override table: (section, key, basic_key) means
#   config[section][key] = config["basic"].get(basic_key, config[section][key])
# A section of None targets a top-level key of the configuration.
_BASIC_OVERRIDES = (
    ("data", "type", "dataset"),
    ("model", "type", "model"),
    ("model", "rnn_units", "rnn"),
    ("model", "embed_dim", "emb"),
    ("data", "sample", "sample"),
    ("train", "device", "device"),
    ("train", "debug", "debug"),
    (None, "cuda", "cuda"),
    (None, "mode", "mode"),
)


def _load_config(path):
    """Read *path* as YAML and return the top-level mapping it contains."""
    with open(path, "r", encoding="utf-8") as file:
        config = yaml.safe_load(file)
    # safe_load yields None for an empty document and a list/scalar for other
    # top-level nodes; everything downstream expects a dict.
    if not isinstance(config, dict):
        raise ValueError(
            f"Configuration file {path!r} must contain a YAML mapping at the top level"
        )
    return config


def _apply_basic_overrides(config):
    """Overwrite existing settings in *config* with values from its ``basic`` section."""
    basic = config.get("basic")
    if basic is None:
        # No (or an empty) `basic:` section means no overrides were requested.
        return
    if not isinstance(basic, dict):
        raise ValueError("The 'basic' section of the configuration must be a mapping")

    for section, key, basic_key in _BASIC_OVERRIDES:
        target = config if section is None else config.get(section)
        # Only override keys the configuration already defines, and skip
        # sections that are absent or not mappings.
        if isinstance(target, dict) and key in target:
            target[key] = basic.get(basic_key, target[key])