import os
import sys

# make the repository root importable so lib/, model/ and Make_model resolve
file_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
print(file_dir)
sys.path.append(file_dir)

import torch
import numpy as np
import torch.nn as nn
import argparse
import configparser
import time

from model.BasicTrainer_cde import Trainer
from lib.TrainInits import init_seed, print_model_parameters
from lib.dataloader import get_dataloader_cde
from os.path import join
from Make_model import make_model
from torch.utils.tensorboard import SummaryWriter

# *************************************************************************#
Mode = "train"
DEBUG = "False"
DATASET = "PEMSD4"  # PEMSD4 or PEMSD8
MODEL = "GCDE"

# get configuration
config_file = "./{}_{}.conf".format(DATASET, MODEL)
# print('Read configuration file: %s' % (config_file))
config = configparser.ConfigParser()
config.read(config_file)

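# NOTE (argparse behavior): string defaults -- including every value read from the
# .conf file and the DEBUG flag above -- are passed through the argument's `type`
# callable by parse_args(), which is why `type=eval` below turns strings such as
# "False" into real booleans.
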
from lib.metrics import MAE_torch


def masked_mae_loss(scaler, mask_value):
    """Return an MAE loss that un-normalizes preds/labels with `scaler` first and
    masks out entries matching `mask_value` (e.g. zeros from missing readings)."""

    def loss(preds, labels):
        if scaler:
            preds = scaler.inverse_transform(preds)
            labels = scaler.inverse_transform(labels)
        mae = MAE_torch(pred=preds, true=labels, mask_value=mask_value)
        return mae

    return loss


# parser
args = argparse.ArgumentParser(description="arguments")
args.add_argument("--dataset", default=DATASET, type=str)
args.add_argument("--mode", default=Mode, type=str)
args.add_argument("--device", default=0, type=int, help="index of the GPU to use")
args.add_argument("--debug", default=DEBUG, type=eval)
args.add_argument("--model", default=MODEL, type=str)
# caution: with type=bool, any non-empty CLI string (even "False") parses as True;
# only the default below is reliable
args.add_argument("--cuda", default=True, type=bool)
args.add_argument("--comment", default="", type=str)

# data
args.add_argument("--val_ratio", default=config["data"]["val_ratio"], type=float)
args.add_argument("--test_ratio", default=config["data"]["test_ratio"], type=float)
args.add_argument("--lag", default=config["data"]["lag"], type=int)
args.add_argument("--horizon", default=config["data"]["horizon"], type=int)
args.add_argument("--num_nodes", default=config["data"]["num_nodes"], type=int)
args.add_argument("--tod", default=config["data"]["tod"], type=eval)
args.add_argument("--normalizer", default=config["data"]["normalizer"], type=str)
args.add_argument("--column_wise", default=config["data"]["column_wise"], type=eval)
args.add_argument("--default_graph", default=config["data"]["default_graph"], type=eval)
# model
args.add_argument("--model_type", default=config["model"]["type"], type=str)
args.add_argument("--g_type", default=config["model"]["g_type"], type=str)
args.add_argument("--input_dim", default=config["model"]["input_dim"], type=int)
args.add_argument("--output_dim", default=config["model"]["output_dim"], type=int)
args.add_argument("--embed_dim", default=config["model"]["embed_dim"], type=int)
args.add_argument("--hid_dim", default=config["model"]["hid_dim"], type=int)
args.add_argument("--hid_hid_dim", default=config["model"]["hid_hid_dim"], type=int)
args.add_argument("--num_layers", default=config["model"]["num_layers"], type=int)
args.add_argument("--cheb_k", default=config["model"]["cheb_order"], type=int)
args.add_argument("--solver", default="rk4", type=str)

# train
args.add_argument("--loss_func", default=config["train"]["loss_func"], type=str)
args.add_argument("--seed", default=config["train"]["seed"], type=int)
args.add_argument("--batch_size", default=config["train"]["batch_size"], type=int)
args.add_argument("--epochs", default=config["train"]["epochs"], type=int)
args.add_argument("--lr_init", default=config["train"]["lr_init"], type=float)
args.add_argument("--weight_decay", default=config["train"]["weight_decay"], type=eval)
args.add_argument("--lr_decay", default=config["train"]["lr_decay"], type=eval)
args.add_argument(
    "--lr_decay_rate", default=config["train"]["lr_decay_rate"], type=float
)
args.add_argument("--lr_decay_step", default=config["train"]["lr_decay_step"], type=str)
args.add_argument("--early_stop", default=config["train"]["early_stop"], type=eval)
args.add_argument(
    "--early_stop_patience", default=config["train"]["early_stop_patience"], type=int
)
args.add_argument("--grad_norm", default=config["train"]["grad_norm"], type=eval)
args.add_argument("--max_grad_norm", default=config["train"]["max_grad_norm"], type=int)
args.add_argument("--teacher_forcing", default=False, type=bool)
# args.add_argument('--tf_decay_steps', default=2000, type=int, help='teacher forcing decay steps')
args.add_argument(
    "--real_value",
    default=config["train"]["real_value"],
    type=eval,
    help="use real value for loss calculation",
)

args.add_argument("--missing_test", default=False, type=bool)
args.add_argument("--missing_rate", default=0.1, type=float)

# test
args.add_argument("--mae_thresh", default=config["test"]["mae_thresh"], type=eval)
args.add_argument("--mape_thresh", default=config["test"]["mape_thresh"], type=float)
args.add_argument("--model_path", default="", type=str)
# log
args.add_argument("--log_dir", default="../runs", type=str)
args.add_argument("--log_step", default=config["log"]["log_step"], type=int)
args.add_argument("--plot", default=config["log"]["plot"], type=eval)
args.add_argument("--tensorboard", action="store_true", help="tensorboard")

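# Example invocation (the script's file name depends on your checkout):
#   python <this_script>.py --dataset PEMSD4 --mode train --tensorboard
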
args = args.parse_args()
init_seed(args.seed)

GPU_NUM = args.device
device = torch.device(f"cuda:{GPU_NUM}" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
    torch.cuda.set_device(device)  # change allocation of the current GPU

print(args)

# config log path
save_name = (
    time.strftime("%m-%d-%Hh%Mm")
    + f"{args.comment}_{args.dataset}_{args.model}_{args.model_type}_"
    + f"embed{{{args.embed_dim}}}hid{{{args.hid_dim}}}hidhid{{{args.hid_hid_dim}}}"
    + f"lyrs{{{args.num_layers}}}lr{{{args.lr_init}}}wd{{{args.weight_decay}}}"
)
path = "../runs"

log_dir = join(path, args.dataset, save_name)
args.log_dir = log_dir
if os.path.exists(args.log_dir):
    print("log directory already exists")
else:
    os.makedirs(args.log_dir)

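# run outputs (tensorboard events and any saved models) go under
# ../runs/<dataset>/<timestamp + hyperparameter tag>/
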
if args.tensorboard:
    w: SummaryWriter = SummaryWriter(args.log_dir)
else:
    w = None

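# model_type selects the NCDE variant: "type1" builds both a temporal vector field f
# and a spatial (graph) vector field g, while the *_temporal / *_spatial variants
# appear to be ablations that keep only one of the two fields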
# init model
if args.model_type == "type1":
    model, vector_field_f, vector_field_g = make_model(args)
elif args.model_type == "type1_temporal":
    model, vector_field_f = make_model(args)
elif args.model_type == "type1_spatial":
    model, vector_field_g = make_model(args)
else:
    raise ValueError("Check args.model_type")

model = model.to(device)

if args.model_type == "type1_temporal":
|
|
vector_field_f = vector_field_f.to(args.device)
|
|
vector_field_g = None
|
|
elif args.model_type == "type1_spatial":
|
|
vector_field_f = None
|
|
vector_field_g = vector_field_g.to(args.device)
|
|
else:
|
|
vector_field_f = vector_field_f.to(args.device)
|
|
vector_field_g = vector_field_g.to(args.device)
|
|
|
|
print(model)

# parameter init: Xavier-uniform for weight matrices, uniform for 1-D parameters
for p in model.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)
    else:
        nn.init.uniform_(p)
print_model_parameters(model, only_num=False)

# load dataset
train_loader, val_loader, test_loader, scaler, times = get_dataloader_cde(
    args,
    normalizer=args.normalizer,
    tod=args.tod,
    dow=False,
    weather=False,
    single=False,
)

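# `times` is the observation time grid returned by get_dataloader_cde; it is handed
# to the Trainer, which presumably feeds it to the CDE solver when integrating the
# controlled path
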
# init loss function, optimizer
if args.loss_func == "mask_mae":
    loss = masked_mae_loss(scaler, mask_value=0.0)
elif args.loss_func == "mae":
    loss = torch.nn.L1Loss().to(device)
elif args.loss_func == "mse":
    loss = torch.nn.MSELoss().to(device)
elif args.loss_func == "huber_loss":
    loss = torch.nn.HuberLoss(delta=1.0).to(device)
else:
    raise ValueError("Unsupported loss_func: {}".format(args.loss_func))

optimizer = torch.optim.Adam(
    params=model.parameters(), lr=args.lr_init, weight_decay=args.weight_decay
)

# learning rate decay
lr_scheduler = None
if args.lr_decay:
    print("Applying learning rate decay.")
    lr_decay_steps = [int(i) for i in args.lr_decay_step.split(",")]
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer=optimizer, milestones=lr_decay_steps, gamma=args.lr_decay_rate
    )

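# e.g. lr_decay_step = "5,20,40" (illustrative values) multiplies the learning rate
# by lr_decay_rate at those scheduler steps -- epochs, if stepped once per epoch
# (standard MultiStepLR behavior)
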
# start training
trainer = Trainer(
    model,
    vector_field_f,
    vector_field_g,
    loss,
    optimizer,
    train_loader,
    val_loader,
    test_loader,
    scaler,
    args,
    lr_scheduler,
    args.device,
    times,
    w,
)
if args.mode == "train":
|
|
trainer.train()
|
|
elif args.mode == "test":
|
|
model.load_state_dict(torch.load("./pre-trained/{}.pth".format(args.dataset)))
|
|
print("Load saved model")
|
|
trainer.test(model, trainer.args, test_loader, scaler, trainer.logger, times)
|
|
else:
|
|
raise ValueError
|