TrafficWheel/config/args_parser.py

41 lines
1.9 KiB
Python
Executable File

import argparse
import yaml
# Values in the YAML 'basic' section that override nested settings.
# Each entry is (section, target_key, basic_key); a section of None means the
# target key lives at the top level of the config. An override is applied only
# when the target key is already present in the config.
_BASIC_OVERRIDES = (
    ('data', 'type', 'dataset'),
    ('model', 'type', 'model'),
    ('model', 'rnn_units', 'rnn'),
    ('model', 'embed_dim', 'emb'),
    ('data', 'sample', 'sample'),
    ('train', 'device', 'device'),
    ('train', 'debug', 'debug'),
    (None, 'cuda', 'cuda'),
    (None, 'mode', 'mode'),
)


def _apply_basic_overrides(config):
    """Copy values from ``config['basic']`` onto the settings they shadow (in place).

    A missing or empty ``basic`` section is treated as "no overrides" rather
    than raising ``KeyError``/``AttributeError``.
    """
    basic = config.get('basic') or {}
    for section, key, basic_key in _BASIC_OVERRIDES:
        target = config if section is None else config.get(section)
        # Only overwrite keys that already exist; skip sections that are absent
        # or not mappings (an empty ``data:`` entry in YAML loads as None).
        if isinstance(target, dict) and key in target:
            target[key] = basic.get(basic_key, target[key])


def parse_args(argv=None):
    """Parse the command line and return the loaded YAML configuration.

    ``--config`` names a YAML file, which is loaded with ``yaml.safe_load``.
    Selected entries of its ``basic`` section then override the matching
    nested settings (see ``_BASIC_OVERRIDES``).

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, matching the previous behavior.

    Returns:
        The configuration as a dict.

    Raises:
        ValueError: If the config path is empty or the file does not contain
            a YAML mapping.
        SystemExit: Raised by argparse when ``--config`` is missing.
    """
    parser = argparse.ArgumentParser(description='Model Training and Testing')
    parser.add_argument('--config', type=str, required=True, help='Path to the configuration file')
    args = parser.parse_args(argv)

    # argparse enforces that --config is present, but an empty string still slips through.
    if not args.config:
        raise ValueError("Configuration file path must be provided using --config")

    with open(args.config, 'r', encoding='utf-8') as file:
        config = yaml.safe_load(file)

    # safe_load returns None for an empty file (and a scalar/list for other
    # documents); everything below indexes config as a dict.
    if not isinstance(config, dict):
        raise ValueError(f"Configuration file {args.config!r} must contain a YAML mapping")

    _apply_basic_overrides(config)
    return config