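# Training script for RePST: parses the run configuration, builds the
# train/validation/test loaders, and trains the model with MSE loss, Adam,
# a cosine-annealing LR schedule, and early stopping on validation loss,
# reporting test MAE/MAPE/RMSE after every epoch.
#
# Example invocation (script name and paths are illustrative):
#   python train.py --root_path <path_to_data> --data_path <dataset_name> --device cuda:0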
from data_provider.data_factory import data_provider
from utils.former_tools import vali, test, EarlyStopping
from models.repst import repst

import argparse
import logging
import os
import random
import time
import warnings

import numpy as np
import torch
import torch.nn as nn

warnings.filterwarnings('ignore')
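# Fix random seeds across Python, NumPy, and PyTorch for reproducibility.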
fix_seed = 2023
random.seed(fix_seed)
torch.manual_seed(fix_seed)
np.random.seed(fix_seed)
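# Run configuration.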
parser = argparse.ArgumentParser(description='RePST')
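# Compute device and checkpoint directory.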
parser.add_argument('--device', type=str, default='cuda:0')
parser.add_argument('--checkpoints', type=str, default='./checkpoints/')
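# Dataset location (the defaults are placeholders; set them per dataset).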
parser.add_argument('--root_path', type=str, default='path_to_data')
parser.add_argument('--data_path', type=str, default='dataset_name')
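# Forecasting setup: read seq_len input steps, predict pred_len steps ahead.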
parser.add_argument('--pred_len', type=int, default=24)
parser.add_argument('--seq_len', type=int, default=24)
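# Optimization hyperparameters.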
parser.add_argument('--decay_fac', type=float, default=0.75)
parser.add_argument('--learning_rate', type=float, default=0.002)
parser.add_argument('--batch_size', type=int, default=16)
parser.add_argument('--num_workers', type=int, default=10)
parser.add_argument('--train_epochs', type=int, default=100)
parser.add_argument('--patience', type=int, default=20)
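# Model hyperparameters (backbone depth, widths, attention heads, dropout).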
parser.add_argument('--gpt_layers', type=int, default=9)
parser.add_argument('--d_model', type=int, default=64)
parser.add_argument('--n_heads', type=int, default=1)
parser.add_argument('--d_ff', type=int, default=128)
parser.add_argument('--dropout', type=float, default=0.2)
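# Input patching and the cosine-annealing period (tmax).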
parser.add_argument('--patch_len', type=int, default=6)
parser.add_argument('--stride', type=int, default=7)
parser.add_argument('--tmax', type=int, default=5)
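# Parse arguments and set up the device and file logging.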
args = parser.parse_args()
device = torch.device(args.device)

# Ensure the log directory exists; basicConfig raises if it does not.
os.makedirs('./log', exist_ok=True)
logging.basicConfig(filename='./log/{}.log'.format(args.data_path), level=logging.INFO)
logging.info(args)
rmses = []
maes = []
mapes = []
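# Build train/validation/test data loaders.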
train_loader, vali_loader, test_loader = data_provider(args)

time_now = time.time()
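# Model, early stopping on validation loss, and optimizer.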
model = repst(args, device).to(device)
early_stopping = EarlyStopping(patience=args.patience, verbose=True)
model_optim = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
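# Training objective: MSE. A symmetric-MAPE (SMAPE) loss is kept commented
# out below as an alternative.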
# class SMAPE(nn.Module):
#     def __init__(self):
#         super(SMAPE, self).__init__()
#
#     def forward(self, pred, true):
#         return torch.mean(200 * torch.abs(pred - true) / (torch.abs(pred) + torch.abs(true) + 1e-8))
#
# criterion = SMAPE()
criterion = nn.MSELoss()
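# Cosine-annealing learning-rate schedule with period tmax.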
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(model_optim, T_max=args.tmax, eta_min=1e-8)
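# Per-run checkpoint directory, keyed by dataset, backbone depth, and learning rate.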
path = os.path.join(args.checkpoints, '{}_{}_{}'.format(args.data_path, args.gpt_layers, args.learning_rate))
if not os.path.exists(path):
    os.makedirs(path)
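# Training loop.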
for epoch in range(args.train_epochs):
    iter_count = 0
    train_loss = []
    epoch_time = time.time()

    train_loader.shuffle()
    model_optim.zero_grad()
    for i, (x, y) in enumerate(train_loader.get_iterator()):
        iter_count += 1

        x = x.to(device)
        y = y.to(device)
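        # Forward pass; keep only the first channel of the prediction and the
        # target when computing the loss.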
        outputs = model(x)
        outputs = outputs[..., 0]
        y = y[..., 0]

        loss = criterion(outputs, y)
        train_loss.append(loss.item())
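        # Report the running average loss every 100 iterations.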
        if i % 100 == 0:
            msg = 'iters: {}, loss: {}, time_cost: {}'.format(
                i + 1, np.average(train_loss[-100:]), time.time() - epoch_time)
            print(msg)
            logging.info(msg)
        loss.backward()
        model_optim.step()
        model_optim.zero_grad()
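    # End of epoch: log timing, compute validation loss, and step the LR schedule.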
logging.info("Epoch: {} cost time: {}".format(epoch , time.time() - epoch_time))
|
|
print("Epoch: {} cost time: {}".format(epoch , time.time() - epoch_time))
|
|
|
|
train_loss = np.average(train_loss)
|
|
vali_loss = vali(model, vali_loader, criterion, args, device)
|
|
scheduler.step()
|
|
|
|
    early_stopping(vali_loss, model, path)
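    # Evaluate on the held-out test set after every epoch.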
    print("------------------------------------")
    logging.info("------------------------------------")
    mae, mape, rmse = test(model, test_loader, args, device)
    log = 'On average over all horizons, Test MAE: {:.4f}, Test MAPE: {:.4f}, Test RMSE: {:.4f}'
    logging.info(log.format(mae, mape, rmse))
    print(log.format(mae, mape, rmse))

    # Stop once validation loss has not improved for `patience` epochs; this
    # assumes the EarlyStopping helper sets the usual `early_stop` flag when
    # patience is exhausted.
    if early_stopping.early_stop:
        print('Early stopping')
        break