# TrafficWheel/config/args_parser.py
#
# 46 lines
# 2.0 KiB
# Python
# Executable File

import argparse
import yaml
def _str2bool(value):
    """Convert a command-line token such as ``'true'`` or ``'0'`` to a bool.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised
            boolean spelling, so argparse reports a normal usage error.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', 'off'):
        return False
    raise argparse.ArgumentTypeError(
        f'expected a boolean value, got {value!r}')


def parse_args(argv=None):
    """Parse command-line options and merge them into the YAML configuration.

    The YAML file is read from ``./config/<model>/<dataset>.yaml`` (relative
    to the current working directory). Command-line values are then written
    into the loaded dictionary, and several values are copied between the
    ``data``, ``model``, ``train``, ``log`` and ``test`` sections so that each
    section carries the keys it needs.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default),
            argparse reads ``sys.argv[1:]``, which is the original behaviour.

    Returns:
        dict: The configuration loaded from YAML with the overrides applied.

    Raises:
        FileNotFoundError: if no YAML file exists for the chosen
            model/dataset pair.
        KeyError: if the YAML file lacks one of the sections or keys that
            are read below.
    """
    parser = argparse.ArgumentParser(description='Model Training and Testing')
    parser.add_argument('--dataset', default='PEMSD8', type=str)
    parser.add_argument('--mode', default='train', type=str)
    parser.add_argument('--device', default='cuda:0', type=str, help='Indices of GPUs')
    # A dedicated converter replaces ``type=eval``: eval would execute
    # arbitrary Python supplied on the command line.
    parser.add_argument('--debug', default=False, type=_str2bool)
    parser.add_argument('--model', default='GWN', type=str)
    # ``type=bool`` treats every non-empty string (including 'False') as
    # True, so the flag could never be switched off from the command line.
    parser.add_argument('--cuda', default=True, type=_str2bool)
    parser.add_argument('--sample', default=1, type=int)
    parser.add_argument('--emb', default=12, type=int)
    parser.add_argument('--rnn', default=64, type=int)
    args = parser.parse_args(argv)

    # Load YAML configuration
    config_file = f'./config/{args.model}/{args.dataset}.yaml'
    # Explicit encoding keeps decoding independent of the platform locale.
    with open(config_file, 'r', encoding='utf-8') as file:
        config = yaml.safe_load(file)

    # Command-line overrides for the data and model sections.
    config['data']['type'] = args.dataset
    config['model']['type'] = args.model
    config['model']['rnn_units'] = args.rnn
    config['model']['embed_dim'] = args.emb
    config['data']['sample'] = args.sample

    # Copy values between sections; each right-hand side must already exist
    # in the YAML file.
    config['data']['input_dim'] = config['model']['input_dim']
    config['data']['output_dim'] = config['model']['output_dim']
    config['data']['batch_size'] = config['train']['batch_size']
    config['model']['num_nodes'] = config['data']['num_nodes']
    config['model']['horizon'] = config['data']['horizon']
    config['model']['default_graph'] = config['data']['default_graph']

    # Training section: runtime flags plus values taken from other sections.
    config['train']['device'] = args.device
    config['train']['debug'] = args.debug
    config['train']['log_step'] = config['log']['log_step']
    config['train']['output_dim'] = config['model']['output_dim']
    config['train']['mae_thresh'] = config['test']['mae_thresh']
    config['train']['mape_thresh'] = config['test']['mape_thresh']

    # Top-level runtime settings.
    config['cuda'] = args.cuda
    config['mode'] = args.mode
    config['device'] = args.device
    config['model']['device'] = config['device']
    return config