mirror of https://github.com/czzhangheng/STDEN.git
Optimize redundant comments
This commit is contained in:
parent
f8a85887a8
commit
e041038ed8
|
|
@ -3,7 +3,7 @@
|
|||
This is the implementation of Spatio-temporal Differential Equation Network (STDEN) in the following paper:
|
||||
Jiahao Ji, Jingyuan Wang, Zhe Jiang, Jiawei Jiang, and Hu Zhang, Towards Physics-guided Neural Networks for Traffic Flow Prediction, AAAI 2022.
|
||||
|
||||
Thanks [chnsh](https://github.com/chnsh/DCRNN_PyTorch) for the model training framework of this project.
|
||||
The training framework of this project comes from [chnsh](https://github.com/chnsh/DCRNN_PyTorch). Thanks a lot :)
|
||||
|
||||
## Requirement
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
import torch
|
||||
|
||||
def masked_mae_loss(y_pred, y_true):
|
||||
# print('y_pred: ', y_pred.shape, 'y_true: ', y_true.shape)
|
||||
y_true[y_true < 1e-4] = 0
|
||||
mask = (y_true != 0).float()
|
||||
mask /= mask.mean() # assign the sample weights of zeros to nonzero-values
|
||||
|
|
@ -12,23 +11,19 @@ def masked_mae_loss(y_pred, y_true):
|
|||
return loss.mean()
|
||||
|
||||
def masked_mape_loss(y_pred, y_true):
|
||||
# print('y_pred: ', y_pred.shape, 'y_true: ', y_true.shape)
|
||||
y_true[y_true < 1e-4] = 0
|
||||
mask = (y_true != 0).float()
|
||||
mask /= mask.mean()
|
||||
loss = torch.abs((y_pred - y_true) / y_true)
|
||||
loss = loss * mask
|
||||
# trick for nans: https://discuss.pytorch.org/t/how-to-set-nan-in-tensor-to-0/3918/3
|
||||
loss[loss != loss] = 0
|
||||
return loss.mean()
|
||||
|
||||
def masked_rmse_loss(y_pred, y_true):
|
||||
y_true[y_true < 1e-4] = 0
|
||||
# print('y_pred: ', y_pred.shape, 'y_true: ', y_true.shape)
|
||||
mask = (y_true != 0).float()
|
||||
mask /= mask.mean()
|
||||
loss = torch.pow(y_pred - y_true, 2)
|
||||
loss = loss * mask
|
||||
# trick for nans: https://discuss.pytorch.org/t/how-to-set-nan-in-tensor-to-0/3918/3
|
||||
loss[loss != loss] = 0
|
||||
return torch.sqrt(loss.mean())
|
||||
|
|
|
|||
|
|
@ -2,16 +2,11 @@ import logging
|
|||
import numpy as np
|
||||
import os
|
||||
import time
|
||||
import pickle
|
||||
import scipy.sparse as sp
|
||||
import sys
|
||||
# import tensorflow as tf
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from scipy.sparse import linalg
|
||||
|
||||
|
||||
class DataLoader(object):
|
||||
def __init__(self, xs, ys, batch_size, pad_with_last_sample=True, shuffle=False):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from torch.utils.tensorboard import SummaryWriter
|
|||
|
||||
from lib import utils
|
||||
from model.stden_model import STDENModel
|
||||
from lib.metrics import masked_mae_loss, masked_mape_loss, masked_mse_loss, masked_rmse_loss
|
||||
from lib.metrics import masked_mae_loss, masked_mape_loss, masked_rmse_loss
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue