import argparse import yaml def parse_args(): parser = argparse.ArgumentParser(description='Model Training and Testing') parser.add_argument('--dataset', default='PEMSD8', type=str) parser.add_argument('--mode', default='train', type=str) parser.add_argument('--device', default='cuda:0', type=str, help='Indices of GPUs') parser.add_argument('--debug', default=False, type=eval) parser.add_argument('--model', default='GWN', type=str) parser.add_argument('--cuda', default=True, type=bool) parser.add_argument('--sample', default=1, type=int) parser.add_argument('--emb', default=12, type=int) parser.add_argument('--rnn', default=64, type=int) args = parser.parse_args() # Load YAML configuration config_file = f'./config/{args.model}/{args.dataset}.yaml' with open(config_file, 'r') as file: config = yaml.safe_load(file) config['data']['type'] = args.dataset config['model']['type'] = args.model config['model']['rnn_units'] = args.rnn config['model']['embed_dim'] = args.emb config['data']['sample'] = args.sample config['data']['input_dim'] = config['model']['input_dim'] config['data']['output_dim'] = config['model']['output_dim'] config['data']['batch_size'] = config['train']['batch_size'] config['model']['num_nodes'] = config['data']['num_nodes'] config['model']['horizon'] = config['data']['horizon'] config['model']['default_graph'] = config['data']['default_graph'] config['train']['device'] = args.device config['train']['debug'] = args.debug config['train']['log_step'] = config['log']['log_step'] config['train']['output_dim'] = config['model']['output_dim'] config['train']['mae_thresh'] = config['test']['mae_thresh'] config['train']['mape_thresh'] = config['test']['mape_thresh'] config['cuda'] = args.cuda config['mode'] = args.mode config['device'] = args.device config['model']['device'] = config['device'] return config