mirror of https://github.com/czzhangheng/STDEN.git
Optimize redundant comments
This commit is contained in:
parent
f8a85887a8
commit
e041038ed8
|
|
@ -3,7 +3,7 @@
|
||||||
This is the implementation of Spatio-temporal Differential Equation Network (STDEN) in the following paper:
|
This is the implementation of Spatio-temporal Differential Equation Network (STDEN) in the following paper:
|
||||||
Jiahao Ji, Jingyuan Wang, Zhe Jiang, Jiawei Jiang, and Hu Zhang, Towards Physics-guided Neural Networks for Traffic Flow Prediction, AAAI 2022.
|
Jiahao Ji, Jingyuan Wang, Zhe Jiang, Jiawei Jiang, and Hu Zhang, Towards Physics-guided Neural Networks for Traffic Flow Prediction, AAAI 2022.
|
||||||
|
|
||||||
Thanks [chnsh](https://github.com/chnsh/DCRNN_PyTorch) for the model training framework of this project.
|
The training framework of this project comes from [chnsh](https://github.com/chnsh/DCRNN_PyTorch). Thanks a lot :)
|
||||||
|
|
||||||
## Requirement
|
## Requirement
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,6 @@
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
def masked_mae_loss(y_pred, y_true):
|
def masked_mae_loss(y_pred, y_true):
|
||||||
# print('y_pred: ', y_pred.shape, 'y_true: ', y_true.shape)
|
|
||||||
y_true[y_true < 1e-4] = 0
|
y_true[y_true < 1e-4] = 0
|
||||||
mask = (y_true != 0).float()
|
mask = (y_true != 0).float()
|
||||||
mask /= mask.mean() # assign the sample weights of zeros to nonzero-values
|
mask /= mask.mean() # assign the sample weights of zeros to nonzero-values
|
||||||
|
|
@ -12,23 +11,19 @@ def masked_mae_loss(y_pred, y_true):
|
||||||
return loss.mean()
|
return loss.mean()
|
||||||
|
|
||||||
def masked_mape_loss(y_pred, y_true):
|
def masked_mape_loss(y_pred, y_true):
|
||||||
# print('y_pred: ', y_pred.shape, 'y_true: ', y_true.shape)
|
|
||||||
y_true[y_true < 1e-4] = 0
|
y_true[y_true < 1e-4] = 0
|
||||||
mask = (y_true != 0).float()
|
mask = (y_true != 0).float()
|
||||||
mask /= mask.mean()
|
mask /= mask.mean()
|
||||||
loss = torch.abs((y_pred - y_true) / y_true)
|
loss = torch.abs((y_pred - y_true) / y_true)
|
||||||
loss = loss * mask
|
loss = loss * mask
|
||||||
# trick for nans: https://discuss.pytorch.org/t/how-to-set-nan-in-tensor-to-0/3918/3
|
|
||||||
loss[loss != loss] = 0
|
loss[loss != loss] = 0
|
||||||
return loss.mean()
|
return loss.mean()
|
||||||
|
|
||||||
def masked_rmse_loss(y_pred, y_true):
|
def masked_rmse_loss(y_pred, y_true):
|
||||||
y_true[y_true < 1e-4] = 0
|
y_true[y_true < 1e-4] = 0
|
||||||
# print('y_pred: ', y_pred.shape, 'y_true: ', y_true.shape)
|
|
||||||
mask = (y_true != 0).float()
|
mask = (y_true != 0).float()
|
||||||
mask /= mask.mean()
|
mask /= mask.mean()
|
||||||
loss = torch.pow(y_pred - y_true, 2)
|
loss = torch.pow(y_pred - y_true, 2)
|
||||||
loss = loss * mask
|
loss = loss * mask
|
||||||
# trick for nans: https://discuss.pytorch.org/t/how-to-set-nan-in-tensor-to-0/3918/3
|
|
||||||
loss[loss != loss] = 0
|
loss[loss != loss] = 0
|
||||||
return torch.sqrt(loss.mean())
|
return torch.sqrt(loss.mean())
|
||||||
|
|
|
||||||
|
|
@ -2,16 +2,11 @@ import logging
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
import pickle
|
|
||||||
import scipy.sparse as sp
|
import scipy.sparse as sp
|
||||||
import sys
|
import sys
|
||||||
# import tensorflow as tf
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from scipy.sparse import linalg
|
|
||||||
|
|
||||||
|
|
||||||
class DataLoader(object):
|
class DataLoader(object):
|
||||||
def __init__(self, xs, ys, batch_size, pad_with_last_sample=True, shuffle=False):
|
def __init__(self, xs, ys, batch_size, pad_with_last_sample=True, shuffle=False):
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,7 @@ from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
from lib import utils
|
from lib import utils
|
||||||
from model.stden_model import STDENModel
|
from model.stden_model import STDENModel
|
||||||
from lib.metrics import masked_mae_loss, masked_mape_loss, masked_mse_loss, masked_rmse_loss
|
from lib.metrics import masked_mae_loss, masked_mape_loss, masked_rmse_loss
|
||||||
|
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue