41 lines
1.9 KiB
Python
Executable File
41 lines
1.9 KiB
Python
Executable File
import argparse
|
|
import yaml
|
|
|
|
# Each entry maps a key of the config's 'basic' section to the setting it
# overrides, given as a path of nested keys from the config root. A setting
# is only overridden when it already exists in the loaded config.
_BASIC_OVERRIDES = (
    ('dataset', ('data', 'type')),
    ('model', ('model', 'type')),
    ('rnn', ('model', 'rnn_units')),
    ('emb', ('model', 'embed_dim')),
    ('sample', ('data', 'sample')),
    ('device', ('train', 'device')),
    ('debug', ('train', 'debug')),
    ('cuda', ('cuda',)),
    ('mode', ('mode',)),
)


def _load_config(path):
    """Read the YAML file at *path* and return its top-level mapping.

    An empty file yields an empty dict. Errors from opening the file or from
    ``yaml.safe_load`` propagate unchanged.

    Raises:
        ValueError: if the document's top level is not a mapping.
    """
    with open(path, 'r', encoding='utf-8') as file:
        config = yaml.safe_load(file)
    if config is None:
        # An empty YAML document loads as None; treat it as "no settings".
        return {}
    if not isinstance(config, dict):
        raise ValueError(
            f"Configuration file {path!r} must contain a mapping at the top "
            f"level, got {type(config).__name__}"
        )
    return config


def _apply_basic_overrides(config):
    """Copy values from ``config['basic']`` onto the settings they override.

    Mutates and returns *config*. A setting is replaced only if it is already
    present in the config and the corresponding key exists in 'basic'. A
    missing or empty 'basic' section means no overrides are applied.

    Raises:
        ValueError: if 'basic' is present but is not a mapping.
    """
    basic = config.get('basic')
    if basic is None:
        return config
    if not isinstance(basic, dict):
        raise ValueError(
            f"'basic' section must be a mapping, got {type(basic).__name__}"
        )

    for basic_key, path in _BASIC_OVERRIDES:
        if basic_key not in basic:
            continue
        *parents, leaf = path
        section = config
        for name in parents:
            # Skip quietly if an intermediate section is absent or not a dict.
            section = section.get(name) if isinstance(section, dict) else None
        if isinstance(section, dict) and leaf in section:
            section[leaf] = basic[basic_key]
    return config


def parse_args(argv=None):
    """Parse the command line, load the YAML config and apply 'basic' overrides.

    Values in the config's ``basic`` section override these settings when the
    setting is already present: ``data.type``, ``model.type``,
    ``model.rnn_units``, ``model.embed_dim``, ``data.sample``,
    ``train.device``, ``train.debug``, and the top-level ``cuda`` and ``mode``.

    Args:
        argv: Argument list to parse. Defaults to ``None``, meaning
            ``sys.argv[1:]`` (argparse's default behavior).

    Returns:
        The configuration dictionary with overrides applied.

    Raises:
        SystemExit: from argparse when ``--config`` is missing or invalid.
        ValueError: if ``--config`` is an empty string, the file's top level
            is not a mapping, or its 'basic' section is not a mapping.
        OSError: if the configuration file cannot be opened.
    """
    parser = argparse.ArgumentParser(description='Model Training and Testing')
    parser.add_argument('--config', type=str, required=True,
                        help='Path to the configuration file')
    args = parser.parse_args(argv)

    # argparse enforces presence, but an explicit empty string still gets here.
    if not args.config:
        raise ValueError("Configuration file path must be provided using --config")

    config = _load_config(args.config)
    return _apply_basic_overrides(config)
|