import os
import sys

# Make the project root importable when running this script directly.
file_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
print(file_dir)
sys.path.append(file_dir)

import torch
import numpy as np
import torch.nn as nn
import argparse
import configparser
import time
from os.path import join

from model.BasicTrainer_cde import Trainer
from lib.TrainInits import init_seed, print_model_parameters
from lib.dataloader import get_dataloader_cde
from Make_model import make_model
from torch.utils.tensorboard import SummaryWriter
#*************************************************************************#
Mode = 'train'
DEBUG = 'False'
DATASET = 'PEMSD4'      #PEMSD4 or PEMSD8
MODEL = 'GCDE'

#get configuration
config_file = './{}_{}.conf'.format(DATASET, MODEL)
#print('Read configuration file: %s' % (config_file))
config = configparser.ConfigParser()
config.read(config_file)
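# A minimal sketch of the expected .conf layout, assuming the section/key
# names read below; the values here are placeholders, not the shipped ones:
#
#   [data]
#   val_ratio = 0.2
#   test_ratio = 0.2
#   lag = 12
#   horizon = 12
#   num_nodes = 307
#   tod = False
#   normalizer = std
#   column_wise = False
#   default_graph = True
#
#   [model]
#   type = type1
#   g_type = agc
#   input_dim = 1
#   output_dim = 1
#   embed_dim = 10
#   hid_dim = 64
#   hid_hid_dim = 64
#   num_layers = 2
#   cheb_order = 2
#
#   [train]
#   loss_func = mae
#   seed = 10
#   batch_size = 64
#   epochs = 100
#   lr_init = 0.001
#   weight_decay = 0
#   lr_decay = False
#   lr_decay_rate = 0.3
#   lr_decay_step = 5,20,40
#   early_stop = True
#   early_stop_patience = 15
#   grad_norm = False
#   max_grad_norm = 5
#   real_value = True
#
#   [test]
#   mae_thresh = None
#   mape_thresh = 0.
#
#   [log]
#   log_step = 20
#   plot = False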
from lib.metrics import MAE_torch

def masked_mae_loss(scaler, mask_value):
    """Return an MAE criterion that de-normalizes preds/labels with `scaler`
    and ignores entries whose true value equals `mask_value`."""
    def loss(preds, labels):
        if scaler:
            preds = scaler.inverse_transform(preds)
            labels = scaler.inverse_transform(labels)
        mae = MAE_torch(pred=preds, true=labels, mask_value=mask_value)
        return mae
    return loss
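# A minimal usage sketch: this is how the 'mask_mae' branch below builds the
# criterion, so entries whose ground truth equals mask_value are excluded:
#
#   criterion = masked_mae_loss(scaler, mask_value=0.0)
#   l = criterion(preds, labels)   # MAE over entries where labels != 0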
#parser
parser = argparse.ArgumentParser(description='arguments')
parser.add_argument('--dataset', default=DATASET, type=str)
parser.add_argument('--mode', default=Mode, type=str)
parser.add_argument('--device', default=0, type=int, help='index of the GPU to use')
parser.add_argument('--debug', default=DEBUG, type=eval)
parser.add_argument('--model', default=MODEL, type=str)
# type=eval so that '--cuda False' parses as a real boolean (type=bool would
# treat any non-empty string, including 'False', as True).
parser.add_argument('--cuda', default=True, type=eval)
parser.add_argument('--comment', default='', type=str)
#data
parser.add_argument('--val_ratio', default=config['data']['val_ratio'], type=float)
parser.add_argument('--test_ratio', default=config['data']['test_ratio'], type=float)
parser.add_argument('--lag', default=config['data']['lag'], type=int)
parser.add_argument('--horizon', default=config['data']['horizon'], type=int)
parser.add_argument('--num_nodes', default=config['data']['num_nodes'], type=int)
parser.add_argument('--tod', default=config['data']['tod'], type=eval)
parser.add_argument('--normalizer', default=config['data']['normalizer'], type=str)
parser.add_argument('--column_wise', default=config['data']['column_wise'], type=eval)
parser.add_argument('--default_graph', default=config['data']['default_graph'], type=eval)
#model
parser.add_argument('--model_type', default=config['model']['type'], type=str)
parser.add_argument('--g_type', default=config['model']['g_type'], type=str)
parser.add_argument('--input_dim', default=config['model']['input_dim'], type=int)
parser.add_argument('--output_dim', default=config['model']['output_dim'], type=int)
parser.add_argument('--embed_dim', default=config['model']['embed_dim'], type=int)
parser.add_argument('--hid_dim', default=config['model']['hid_dim'], type=int)
parser.add_argument('--hid_hid_dim', default=config['model']['hid_hid_dim'], type=int)
parser.add_argument('--num_layers', default=config['model']['num_layers'], type=int)
parser.add_argument('--cheb_k', default=config['model']['cheb_order'], type=int)
parser.add_argument('--solver', default='rk4', type=str, help="numerical solver for the CDE integration (e.g. 'rk4')")
#train
parser.add_argument('--loss_func', default=config['train']['loss_func'], type=str)
parser.add_argument('--seed', default=config['train']['seed'], type=int)
parser.add_argument('--batch_size', default=config['train']['batch_size'], type=int)
parser.add_argument('--epochs', default=config['train']['epochs'], type=int)
parser.add_argument('--lr_init', default=config['train']['lr_init'], type=float)
parser.add_argument('--weight_decay', default=config['train']['weight_decay'], type=eval)
parser.add_argument('--lr_decay', default=config['train']['lr_decay'], type=eval)
parser.add_argument('--lr_decay_rate', default=config['train']['lr_decay_rate'], type=float)
parser.add_argument('--lr_decay_step', default=config['train']['lr_decay_step'], type=str)
parser.add_argument('--early_stop', default=config['train']['early_stop'], type=eval)
parser.add_argument('--early_stop_patience', default=config['train']['early_stop_patience'], type=int)
parser.add_argument('--grad_norm', default=config['train']['grad_norm'], type=eval)
parser.add_argument('--max_grad_norm', default=config['train']['max_grad_norm'], type=int)
parser.add_argument('--teacher_forcing', default=False, type=eval)  # type=eval, not type=bool, so '--teacher_forcing False' works
#parser.add_argument('--tf_decay_steps', default=2000, type=int, help='teacher forcing decay steps')
parser.add_argument('--real_value', default=config['train']['real_value'], type=eval, help='use real value for loss calculation')

parser.add_argument('--missing_test', default=False, type=eval)
parser.add_argument('--missing_rate', default=0.1, type=float)
#test
parser.add_argument('--mae_thresh', default=config['test']['mae_thresh'], type=eval)
parser.add_argument('--mape_thresh', default=config['test']['mape_thresh'], type=float)
parser.add_argument('--model_path', default='', type=str)
#log
parser.add_argument('--log_dir', default='../runs', type=str)
parser.add_argument('--log_step', default=config['log']['log_step'], type=int)
parser.add_argument('--plot', default=config['log']['plot'], type=eval)
parser.add_argument('--tensorboard', action='store_true', help='enable TensorBoard logging')
args = parser.parse_args()
init_seed(args.seed)
GPU_NUM = args.device
device = torch.device(f'cuda:{GPU_NUM}' if torch.cuda.is_available() else 'cpu')
if torch.cuda.is_available():
    torch.cuda.set_device(device)  # change allocation of current GPU
# NOTE: the rest of the script moves modules with args.device (a GPU index),
# so a CUDA-capable machine is effectively required.

print(args)
#config log path
save_name = (time.strftime("%m-%d-%Hh%Mm") + args.comment
             + "_" + args.dataset + "_" + args.model + "_" + args.model_type
             + "_" + "embed{" + str(args.embed_dim) + "}"
             + "hid{" + str(args.hid_dim) + "}"
             + "hidhid{" + str(args.hid_hid_dim) + "}"
             + "lyrs{" + str(args.num_layers) + "}"
             + "lr{" + str(args.lr_init) + "}"
             + "wd{" + str(args.weight_decay) + "}")
path = '../runs'

log_dir = join(path, args.dataset, save_name)
args.log_dir = log_dir
if os.path.exists(args.log_dir):
    print('Model save path already exists.')
else:
    os.makedirs(args.log_dir)
if args.tensorboard:
    w: SummaryWriter = SummaryWriter(args.log_dir)
else:
    w = None
#init model
# make_model returns the network together with its CDE vector field(s); the
# number of return values depends on the variant.
if args.model_type == 'type1':
    model, vector_field_f, vector_field_g = make_model(args)
elif args.model_type == 'type1_temporal':
    model, vector_field_f = make_model(args)
elif args.model_type == 'type1_spatial':
    model, vector_field_g = make_model(args)
else:
    raise ValueError("Check args.model_type")
model = model.to(args.device)
# Move the vector fields to the same device; the unused field is set to None.
if args.model_type == 'type1_temporal':
    vector_field_f = vector_field_f.to(args.device)
    vector_field_g = None
elif args.model_type == 'type1_spatial':
    vector_field_f = None
    vector_field_g = vector_field_g.to(args.device)
else:
    vector_field_f = vector_field_f.to(args.device)
    vector_field_g = vector_field_g.to(args.device)
print(model)

# Re-initialize parameters: Xavier-uniform for weight matrices, plain uniform
# for 1-D parameters such as biases.
for p in model.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)
    else:
        nn.init.uniform_(p)
print_model_parameters(model, only_num=False)
#load dataset
train_loader, val_loader, test_loader, scaler, times = get_dataloader_cde(args,
                                                                          normalizer=args.normalizer,
                                                                          tod=args.tod, dow=False,
                                                                          weather=False, single=False)
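# 'times' appears to be the observation time grid the CDE is integrated over
# (an assumption based on its use here: the dataloader returns it and it is
# passed straight through to the Trainer below).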
#init loss function, optimizer
if args.loss_func == 'mask_mae':
    loss = masked_mae_loss(scaler, mask_value=0.0)
elif args.loss_func == 'mae':
    loss = torch.nn.L1Loss().to(args.device)
elif args.loss_func == 'mse':
    loss = torch.nn.MSELoss().to(args.device)
elif args.loss_func == 'huber_loss':
    loss = torch.nn.HuberLoss(delta=1.0).to(args.device)
else:
    raise ValueError("Unsupported loss function: {}".format(args.loss_func))
optimizer = torch.optim.Adam(params=model.parameters(), lr=args.lr_init,
                             weight_decay=args.weight_decay)
#learning rate decay
lr_scheduler = None
if args.lr_decay:
    print('Applying learning rate decay.')
    # lr_decay_step is a comma-separated string of epoch milestones, e.g. '5,20,40'.
    lr_decay_steps = [int(i) for i in args.lr_decay_step.split(',')]
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizer,
                                                        milestones=lr_decay_steps,
                                                        gamma=args.lr_decay_rate)
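# Worked example with illustrative values: lr_init=0.001, lr_decay_rate=0.3,
# lr_decay_step='5,20,40' gives lr=3e-4 after epoch 5, 9e-5 after epoch 20,
# and 2.7e-5 after epoch 40.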
#start training
trainer = Trainer(model, vector_field_f, vector_field_g, loss, optimizer, train_loader, val_loader, test_loader, scaler,
                  args, lr_scheduler, args.device, times,
                  w)
if args.mode == 'train':
    trainer.train()
elif args.mode == 'test':
    model.load_state_dict(torch.load('./pre-trained/{}.pth'.format(args.dataset)))
    print("Loaded saved model")
    trainer.test(model, trainer.args, test_loader, scaler, trainer.logger, times)
else:
    raise ValueError("args.mode must be 'train' or 'test', got: {}".format(args.mode))
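# Example invocations (the script filename is assumed; adjust to the actual file):
#   python Run.py --dataset PEMSD4 --mode train --tensorboard
#   python Run.py --dataset PEMSD4 --mode test   # loads ./pre-trained/PEMSD4.pth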