Compare commits

No commits in common. "8e53d25ab1c702dd6a0428aa2bb1f3d9788102cb" and "9b3bb44552a4b5f4e1cf2da81752c19704f80e00" have entirely different histories.

8e53d25ab1 ... 9b3bb44552

README.md (103 lines changed)
@@ -22,23 +22,108 @@ pip install -r requirements.txt
 pip install pyyaml tqdm statsmodels h5py kagglehub torch torchvision torchaudio torchdiffeq fastdtw notebook
 ```

-# Preparing the GPT pretrained weights
-
-This requires overseas network access; if you don't have it, download the files manually and upload them.
-
-The GPT-2 folder should contain two files: `{config.json, pytorch_model.bin}`
+# Quick start (temporarily deprecated)
+
+Follow the commands in baseline.ipynb, or use the command below (make sure the current directory is TrafficWheel):

 ```bash
-mkdir GPT-2
-wget https://huggingface.co/openai-community/gpt2/resolve/main/config.json?download=true -O ./GPT-2/config.json
-wget https://huggingface.co/openai-community/gpt2/resolve/main/pytorch_model.bin?download=true -O ./GPT-2/pytorch_model.bin
+python run.py --model {model_name} --dataset {dataset_name} --mode {train, test} --device {cuda:0}
 ```
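If wget or direct access to huggingface.co is unavailable, the same two files can be fetched through the huggingface_hub client. A minimal sketch, assuming huggingface_hub is installed (it is not listed in requirements.txt):

```python
# Fetch the two GPT-2 files named above into ./GPT-2.
from huggingface_hub import hf_hub_download

for filename in ["config.json", "pytorch_model.bin"]:
    hf_hub_download(repo_id="openai-community/gpt2",
                    filename=filename,
                    local_dir="./GPT-2")
```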

+- model_name: currently supports DSANET, STGCN, DCRNN, GWN (GraphWaveNet), STSGCN, AGCRN, STFGNN, STGODE, STGNCDE, DDGCRN, TWDGCN, STAWnet
+- dataset_name: currently supports PEMSD3, PEMSD4, PEMSD7, PEMSD8
+- mode: train trains a model; test evaluates one and requires the model's .pth checkpoint to be present in the pre-train folder.
+- device: supports 'cpu', 'cuda:0', 'cuda:1', ..., depending on how many GPUs the machine has.
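The same flags can also be driven from Python, for example for scripted sweeps. A sketch (the model/dataset values are just examples from the lists above):

```python
# Launch one training run through the run.py CLI described above.
import subprocess

subprocess.run(
    ["python", "run.py",
     "--model", "STGCN",     # any entry from the model list
     "--dataset", "PEMSD8",  # any entry from the dataset list
     "--mode", "train",
     "--device", "cuda:0"],
    check=True,              # raise if run.py exits non-zero
)
```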

-# Run REPST
-
-On the first run, the program downloads the dataset automatically. Currently only PEMSD8/PEMS-BAY are supported.
+run.py automatically takes care of dataset download and model training/evaluation.
+
+:warning: Performance numbers for the existing models are stored in [Result.xlsx](./Result.xlsx); there is no need to waste resources re-running them.
+
+# Testing a model
+
+After an experiment finishes, the model's artifacts are saved under the `experiments/dataset/<training time>` folder, four files in total:
+
+- best_model.pth holds the checkpoint that performed best on the validation set
+- best_test_model.pth holds the checkpoint that performed best on the test set
+- DATASET.yaml holds the parameters the model was trained with
+- run.log holds the training log
+
+Create a `pre-train/{dataset_name}` folder and copy the entire contents of the `experiments/dataset/<training time>` folder into `pre-train/{dataset_name}`. After that, pass `--mode test` on the command line to run a test.
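A sketch of that copy step in Python (the timestamp folder name is hypothetical):

```python
# Copy a finished run from experiments/ into pre-train/ so --mode test can find it.
import shutil
from pathlib import Path

run_dir = Path("experiments/PEMSD8/2025-01-01T00-00-00")  # hypothetical timestamp folder
dest = Path("pre-train/PEMSD8")
dest.mkdir(parents=True, exist_ok=True)
for item in run_dir.iterdir():          # the four files listed above
    shutil.copy2(item, dest / item.name)
```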

+:warning: Note: delete unneeded files from the experiments folder promptly, or the folder will keep growing.
+
+# Changing the configuration
+
+The config folder holds the configuration files for every model; each dataset is configured separately, in YAML format.
+
+Find the corresponding model's parameters and edit them there.
+
+A configuration file has five sections: [data, model, train, test, log]
+
+- data usually needs no changes; it stores the number of nodes, the prediction window, the history window, and so on.
+- model parameters should be read alongside the code and the paper, which usually give a recommended configuration.
+- train adjusts training details, including batch size, learning rate, batch_norm, and so on.
+
+In general, don't modify the baseline parameters; running with the defaults is the most stable.
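A minimal sketch of reading one of these files with pyyaml (the REPST PEMS-BAY path appears elsewhere in this diff):

```python
import yaml

# Load a per-model, per-dataset config and inspect its sections.
with open("config/REPST/PEMS-BAY.yaml") as f:
    cfg = yaml.safe_load(f)

print(sorted(cfg))                 # the README lists: data, model, train, test, log
print(cfg["train"]["batch_size"])  # training details live under `train`
```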

+# Developing a model
+
+First, create a development branch dev and switch to it:

 ```bash
-python run.py --config ./config/REPST/PEMS-BAY.yaml
+git switch -c dev
 ```

+See the [model migration guide](./transfer_guide.md) for migrating a model into TrafficWheel.
+
+Commit the changes:
+
+```bash
+git add .
+git commit -m "Commit message"
+```
+
+Push to the remote repository (ask me to register an account for you):
+
+```bash
+git push origin dev
+```
+
+Once the model is finished, it needs to be merged into the main branch; submit a pull request [here](https://github.zhang-heng.com/czzhangheng/TrafficWheel/pulls).

+# Known issues
+
+At the moment, the measured metrics of the following models come out higher than originally reported: ARIMA, TCN, DCRNN.
+
+STGCN raises an unexplained warning when loading the graph.
+
+The following models are not implemented yet because no source code is available: HA, VAR, FC-LSTM, GRU-ED.
+
+# Source code and papers
+
+| Paper | Code |
+| ------ | ------ |
+| [HierAttnLSTM](https://arxiv.org/abs/2201.05760v4) | [code](https://github.com/TeRyZh/Network-Level-Travel-Prediction-Hierarchical-Attention-LSTM) |
+| [DSANET](https://dl.acm.org/doi/10.1145/3357384.3358132) | [code](https://github.com/bighuang624/DSANet) |
+| [STGCN](https://arxiv.org/abs/1709.04875) | [code](https://github.com/hazdzz/STGCN) |
+| [DCRNN](https://arxiv.org/abs/1707.01926) | [code](https://github.com/chnsh/DCRNN_PyTorch) |
+| [GraphWaveNet](https://arxiv.org/pdf/1906.00121.pdf) | [code](https://github.com/SGT-LIM/GraphWavenet) |
+| [STSGCN](https://aaai.org/ojs/index.php/AAAI/article/view/5438/5294) | [code](https://github.com/SmallNana/STSGCN_Pytorch) |
+| [AGCRN](https://arxiv.org/pdf/2007.02842) | [code](https://github.com/LeiBAI/AGCRN) |
+| [STFGNN](https://arxiv.org/abs/2012.09641) | [code](https://github.com/lwm412/STFGNN-Pytorch) |
+| [STGODE](https://arxiv.org/abs/2106.12931) | [code](https://github.com/square-coder/STGODE) |
+| [STG-NCDE](https://arxiv.org/abs/2112.03558) | [code](https://github.com/jeongwhanchoi/STG-NCDE) |
+| [DDGCRN](https://www.sciencedirect.com/science/article/abs/pii/S0031320323003710) | [code](https://github.com/wengwenchao123/DDGCRN) |
@@ -10,8 +10,8 @@ data:
   column_wise: false
   days_per_week: 7
   default_graph: true
-  horizon: 24
-  lag: 24
+  horizon: 12
+  lag: 12
   normalizer: std
   num_nodes: 325
   steps_per_day: 288

@@ -23,8 +23,8 @@ data:
   batch_size: 16

 model:
-  pred_len: 24
-  seq_len: 24
+  pred_len: 12
+  seq_len: 12
   patch_len: 6
   stride: 7
   dropout: 0.2

@@ -33,7 +33,6 @@ model:
   gpt_path: ./GPT-2
   d_model: 64
   n_heads: 1
-  input_dim: 1

 train:
   batch_size: 16
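Halving horizon and lag from 24 to 12 shortens both windows to one hour: with steps_per_day: 288, each step is a 5-minute bin. A quick check of the arithmetic:

```python
# steps_per_day: 288 -> each step is 5 minutes, so 12 steps is one hour.
minutes_per_step = 24 * 60 / 288   # 5.0
print(12 * minutes_per_step)       # 60.0 -> the new config predicts one hour ahead
```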
@@ -1,66 +0,0 @@
-basic:
-  dataset: "PEMS-BAY"
-  mode: "train"
-  device: "cuda:0"
-  model: "STID"
-
-data:
-  num_nodes: 325
-  lag: 24
-  horizon: 24
-  val_ratio: 0.2
-  test_ratio: 0.2
-  tod: False
-  normalizer: std
-  column_wise: False
-  default_graph: True
-  add_time_in_day: True
-  add_day_in_week: True
-  steps_per_day: 288
-  days_per_week: 7
-  input_dim: 1
-  output_dim: 1
-  batch_size: 64
-
-model:
-  input_dim: 3
-  output_dim: 1
-  history: 24
-  horizon: 24
-  num_nodes: 325
-  input_len: 24
-  embed_dim: 32
-  output_len: 24
-  num_layer: 3
-  if_node: True
-  node_dim: 32
-  if_T_i_D: True
-  if_D_i_W: True
-  temp_dim_tid: 32
-  temp_dim_diw: 32
-  time_of_day_size: 288
-  day_of_week_size: 7
-  batch_size: 64
-
-
-train:
-  loss_func: mae
-  seed: 1
-  batch_size: 64
-  epochs: 300
-  lr_init: 0.002
-  weight_decay: 0.0001
-  lr_decay: False
-  lr_decay_rate: 0.3
-  lr_decay_step: "1,50,80"
-  early_stop: True
-  early_stop_patience: 15
-  grad_norm: False
-  max_grad_norm: 5
-  real_value: True
-  debug: true
-  output_dim: 1
-  mae_thresh: null
-  mape_thresh: 0.0
-  log_step: 200
-  plot: False
@@ -3,7 +3,7 @@ import torch
 import torch.nn as nn

 from model.DCRNN.dcrnn_cell import DCGRUCell
-from utils.get_adj import get_adj
+from data.get_adj import get_adj


 class Seq2SeqAttrs:
@@ -15,13 +15,12 @@ class ReplicationPad1d(nn.Module):
         return output

 class TokenEmbedding(nn.Module):
-    def __init__(self, c_in, d_model, patch_num, input_dim):
+    def __init__(self, c_in, d_model):
         super(TokenEmbedding, self).__init__()
         padding = 1
         self.tokenConv = nn.Conv1d(in_channels=c_in, out_channels=d_model,
                                    kernel_size=3, padding=padding, padding_mode='circular', bias=False)
-        self.confusion_layer = nn.Linear(patch_num * input_dim, 1)
+        self.confusion_layer = nn.Linear(2, 1)
         # if air_quality
         # self.confusion_layer = nn.Linear(42, 1)

@@ -35,18 +34,19 @@ class TokenEmbedding(nn.Module):
         b, n, m, pn, pl = x.shape  # batch, node, feature, patch_num, patch_len
         # 768,64,25
         x = self.tokenConv(x.reshape(b*n, pl, m*pn))  # batch*node, patch_len, feature*patch_num

         x = self.confusion_layer(x)
         return x.reshape(b, n, -1)


 class PatchEmbedding(nn.Module):
-    def __init__(self, d_model, patch_len, stride, dropout, patch_num, input_dim):
+    def __init__(self, d_model, patch_len, stride, dropout):
         super(PatchEmbedding, self).__init__()
         # Patching
         self.patch_len = patch_len
         self.stride = stride
         self.padding_patch_layer = ReplicationPad1d((0, stride))
-        self.value_embedding = TokenEmbedding(patch_len, d_model, patch_num, input_dim)
+        self.value_embedding = TokenEmbedding(patch_len, d_model)
         self.dropout = nn.Dropout(dropout)

     def forward(self, x):
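The hard-coded `nn.Linear(2, 1)` only works for the new PEMS-BAY settings. A sketch of the arithmetic behind that 2; the patch-count formula is inferred from `ReplicationPad1d((0, stride))` above, not taken from the source:

```python
# Patches of length patch_len, taken every stride steps, from a sequence
# padded on the right by stride (ReplicationPad1d((0, stride))).
def patch_num(seq_len: int, patch_len: int, stride: int) -> int:
    return (seq_len + stride - patch_len) // stride + 1

pn = patch_num(seq_len=12, patch_len=6, stride=7)  # new config values
input_dim = 1  # the input_dim key was removed from the config; 1 is implied
print(pn * input_dim)  # -> 2, the in_features hard-coded in nn.Linear(2, 1)
```

Any config with a different seq_len, stride, or input_dim would break this layer, which is presumably why the `# if air_quality` variant with 42 is left as a comment.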

@@ -13,7 +13,6 @@ class repst(nn.Module):
         self.pred_len = configs['pred_len']
         self.seq_len = configs['seq_len']
         self.patch_len = configs['patch_len']
-        self.input_dim = configs['input_dim']
         self.stride = configs['stride']
         self.dropout = configs['dropout']
         self.gpt_layers = configs['gpt_layers']

@@ -29,7 +28,7 @@ class repst(nn.Module):
         self.head_nf = self.d_ff * self.patch_nums

         # 64,6,7,0.2
-        self.patch_embedding = PatchEmbedding(self.d_model, self.patch_len, self.stride, self.dropout, self.patch_nums, self.input_dim)
+        self.patch_embedding = PatchEmbedding(self.d_model, self.patch_len, self.stride, self.dropout)

         self.gpts = GPT2Model.from_pretrained(self.gpt_path, output_attentions=True, output_hidden_states=True)
         self.gpts.h = self.gpts.h[:self.gpt_layers]
@@ -2,7 +2,7 @@
 import torch
 import torch.nn.functional as F
 import torch.nn as nn
-from utils.get_adj import get_adj
+from data.get_adj import get_adj


 class gcn_operation(nn.Module):
@@ -1,7 +1,7 @@
 import torch.nn as nn

 from model.STGCN import layers
-from utils.get_adj import get_gso
+from data.get_adj import get_gso


 class STGCNChebGraphConv(nn.Module):
@@ -2,7 +2,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn import init
-from utils.get_adj import get_adj
+from data.get_adj import get_adj
 import numbers
@@ -1,7 +1,8 @@
+import torch
 import torch.nn as nn
 from model.ST-SSL.models import STSSL
-from model.ST-SSL.layers
-from utils.get_adj import get_gso
+from model.ST-SSL.layers import STEncoder, MLP
+from data.get_adj import get_gso

 class STSSLModel(nn.Module):
     def __init__(self, args):
@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from utils.get_adj import get_gso
+from data.get_adj import get_gso


 class STSSLModel(nn.Module):
@@ -9,6 +9,4 @@ torchaudio
 torchdiffeq
 fastdtw
 notebook
 torchcde
-einops
-transformers
run.py (9 lines changed)

@@ -1,7 +1,5 @@
 import torch
-from utils.Download_data import check_and_download_data
-data_complete = check_and_download_data()
-assert data_complete is not None, "Dataset download failed, please retry!"

 # import time
 from config.args_parser import parse_args

@@ -60,5 +58,10 @@ def main():
         case _:
             raise ValueError(f"Unsupported mode: {args['basic']['mode']}")


 if __name__ == "__main__":
+    from utils.Download_data import check_and_download_data
+
+    data_complete = check_and_download_data()
+    assert data_complete is not None, "Dataset download failed, please retry!"
     main()
@@ -6,7 +6,6 @@ import psutil
 import torch
 from utils.logger import get_logger
 from utils.loss_function import all_metrics
-from tqdm import tqdm


 class TrainingStats:

@@ -149,8 +148,7 @@ class Trainer:
         y_pred, y_true = [], []

         with torch.set_grad_enabled(optimizer_step):
-            progress_bar = tqdm(enumerate(dataloader), total=len(dataloader), desc=f"{mode.capitalize()} Epoch {epoch}")
-            for batch_idx, (data, target) in progress_bar:
+            for batch_idx, (data, target) in enumerate(dataloader):
                 start_time = time.time()

                 label = target[..., : self.args["output_dim"]]

@@ -177,9 +175,6 @@ class Trainer:
                 y_pred.append(output.detach().cpu())
                 y_true.append(label.detach().cpu())

-                # Update progress bar with the current loss
-                progress_bar.set_postfix(loss=loss.item())

         y_pred = torch.cat(y_pred, dim=0)
         y_true = torch.cat(y_true, dim=0)
@@ -61,6 +61,11 @@ def check_and_download_data():
         missing_adj = True
         missing_main_files = True
     else:
+        # Check for the get_adj.py file in the data root directory
+        if "get_adj.py" not in os.listdir(data_dir):
+            # print(f"get_adj.py is missing from the root directory.")
+            missing_adj = True
+
         # Walk the expected file structure
         for subfolder, expected_files in expected_structure.items():
             subfolder_path = os.path.join(data_dir, subfolder)

@@ -95,6 +100,8 @@ def check_and_download_data():

     rearrange_dir()
+
+
     return True
utils/get_adj.py (218 lines changed)

@@ -1,218 +0,0 @@
-import csv
-import os
-import numpy as np
-from scipy.sparse.linalg import norm
-import scipy.sparse as sp
-import torch
-
-def get_adj(args):
-    dataset_path = './data'
-    match args['num_nodes']:
-        case 358:
-            dataset_name = 'PEMS03'
-            adj_path = os.path.join(dataset_path, dataset_name, 'PEMS03.csv')
-            id = os.path.join(dataset_path, dataset_name, 'PEMS03.txt')
-            A, adj = load_adj(args['num_nodes'], adj_path, id_filename=id)
-        case 307:
-            dataset_name = 'PEMS04'
-            adj_path = os.path.join(dataset_path, dataset_name, 'PEMS04.csv')
-            A, adj = load_adj(args['num_nodes'], adj_path, std=True)
-        case 883:
-            dataset_name = 'PEMS07'
-            adj_path = os.path.join(dataset_path, dataset_name, 'PEMS07.csv')
-            A, adj = load_adj(args['num_nodes'], adj_path)
-        case 170:
-            dataset_name = 'PEMS08'
-            adj_path = os.path.join(dataset_path, dataset_name, 'PEMS08.csv')
-            A, adj = load_adj(args['num_nodes'], adj_path, std=True)
-
-    return adj
-
-def get_gso(args):
-    dataset_path = './data'
-    match args['num_nodes']:
-        case 358:
-            dataset_name = 'PEMS03'
-            adj_path = os.path.join(dataset_path, dataset_name, 'PEMS03.csv')
-            id = os.path.join(dataset_path, dataset_name, 'PEMS03.txt')
-            A, adj = load_adj(args['num_nodes'], adj_path, id_filename=id)
-        case 307:
-            dataset_name = 'PEMS04'
-            adj_path = os.path.join(dataset_path, dataset_name, 'PEMS04.csv')
-            A, adj = load_adj(args['num_nodes'], adj_path, std=True)
-        case 883:
-            dataset_name = 'PEMS07'
-            adj_path = os.path.join(dataset_path, dataset_name, 'PEMS07.csv')
-            A, adj = load_adj(args['num_nodes'], adj_path)
-        case 170:
-            dataset_name = 'PEMS08'
-            adj_path = os.path.join(dataset_path, dataset_name, 'PEMS08.csv')
-            A, adj = load_adj(args['num_nodes'], adj_path, std=True)
-
-    gso = calc_gso(adj, args['gso_type'])
-    if args['graph_conv_type'] == 'cheb_graph_conv':
-        gso = calc_chebynet_gso(gso)
-    gso = gso.toarray()
-    gso = gso.astype(dtype=np.float32)
-    gso = torch.from_numpy(gso).to(args['device'])
-    return gso
-
-def load_adj(num_nodes, adj_path, id_filename=None, std=False):
-    '''
-    Parameters
-    ----------
-    adj_path: str, path of the csv file that contains the edge information
-    num_nodes: int, the number of vertices
-    id_filename: str, optional, path of the file containing node IDs (if not starting from 0)
-    std: bool, if True, normalize the cost values in the CSV file using Gaussian normalization
-
-    Returns
-    ----------
-    A: np.ndarray, adjacency matrix
-    distanceA: np.ndarray, distance matrix (normalized if std=True)
-    '''
-    if 'npy' in adj_path:
-        adj_mx = np.load(adj_path)
-        return adj_mx, None
-
-    else:
-        A = np.zeros((int(num_nodes), int(num_nodes)), dtype=np.float32)
-        distanceA = np.zeros((int(num_nodes), int(num_nodes)), dtype=np.float32)
-
-        # If id_filename is given, the node IDs do not start from 0 and must be remapped
-        if id_filename:
-            with open(id_filename, 'r') as f:
-                id_dict = {int(i): idx for idx, i in enumerate(f.read().strip().split('\n'))}
-
-            with open(adj_path, 'r') as f:
-                f.readline()  # skip the header row
-                reader = csv.reader(f)
-                costs = []  # collect all cost values
-                for row in reader:
-                    if len(row) != 3:
-                        continue
-                    i, j, distance = int(row[0]), int(row[1]), float(row[2])
-                    A[id_dict[i], id_dict[j]] = 1
-                    # Make sure the distance value is positive
-                    distance = max(distance, 1e-6)
-                    costs.append(distance)  # collect the cost value
-                    distanceA[id_dict[i], id_dict[j]] = distance
-
-        else:  # without id_filename, the node IDs start from 0
-            with open(adj_path, 'r') as f:
-                f.readline()  # skip the header row
-                reader = csv.reader(f)
-                costs = []  # collect all cost values
-                for row in reader:
-                    if len(row) != 3:
-                        continue
-                    i, j, distance = int(row[0]), int(row[1]), float(row[2])
-                    A[i, j] = 1
-                    # Make sure the distance value is positive
-                    distance = max(distance, 1e-6)
-                    costs.append(distance)  # collect the cost value
-                    distanceA[i, j] = distance
-
-        # If std=True, apply Gaussian normalization to all cost values from the CSV
-        if std:
-            mean_cost = np.mean(costs)  # mean of the cost values
-            std_cost = np.std(costs)  # standard deviation of the cost values
-            for idx in np.ndindex(distanceA.shape):  # iterate over the matrix
-                if distanceA[idx] > 0:  # only normalize non-zero entries
-                    normalized_value = (distanceA[idx] - mean_cost) / std_cost
-                    # Make sure the normalized value is positive
-                    normalized_value = max(normalized_value, 1e-6)
-                    distanceA[idx] = normalized_value
-
-        # Make sure the matrix has no zero rows
-        row_sums = distanceA.sum(axis=1)
-        zero_rows = np.where(row_sums == 0)[0]
-        for row in zero_rows:
-            distanceA[row, :] = 1e-6  # replace zero rows with a non-zero default value
-
-        return A, distanceA
-
-
-def calc_gso(dir_adj, gso_type):
-    n_vertex = dir_adj.shape[0]
-
-    if not sp.issparse(dir_adj):
-        dir_adj = sp.csc_matrix(dir_adj)
-    elif dir_adj.format != 'csc':
-        dir_adj = dir_adj.tocsc()
-
-    id = sp.identity(n_vertex, format='csc')
-
-    # Symmetrizing an adjacency matrix
-    adj = dir_adj + dir_adj.T.multiply(dir_adj.T > dir_adj) - dir_adj.multiply(dir_adj.T > dir_adj)
-    # adj = 0.5 * (dir_adj + dir_adj.transpose())
-
-    if gso_type in ['sym_renorm_adj', 'rw_renorm_adj', 'sym_renorm_lap', 'rw_renorm_lap']:
-        adj = adj + id
-
-    if gso_type in ['sym_norm_adj', 'sym_renorm_adj', 'sym_norm_lap', 'sym_renorm_lap']:
-        row_sum = adj.sum(axis=1).A1
-        # Check for zero or negative values in row_sum
-        if np.any(row_sum <= 0):
-            raise ValueError(
-                "Row sum contains zero or negative values, which is not allowed for symmetric normalization.")
-
-        row_sum_inv_sqrt = np.power(row_sum, -0.5)
-        row_sum_inv_sqrt[np.isinf(row_sum_inv_sqrt)] = 0.  # Handle inf values
-        deg_inv_sqrt = sp.diags(row_sum_inv_sqrt, format='csc')
-        # A_{sym} = D^{-0.5} * A * D^{-0.5}
-        sym_norm_adj = deg_inv_sqrt.dot(adj).dot(deg_inv_sqrt)
-
-        if gso_type in ['sym_norm_lap', 'sym_renorm_lap']:
-            sym_norm_lap = id - sym_norm_adj
-            gso = sym_norm_lap
-        else:
-            gso = sym_norm_adj
-
-    elif gso_type in ['rw_norm_adj', 'rw_renorm_adj', 'rw_norm_lap', 'rw_renorm_lap']:
-        row_sum = np.sum(adj, axis=1).A1
-        # Check for zero or negative values in row_sum
-        if np.any(row_sum <= 0):
-            raise ValueError(
-                "Row sum contains zero or negative values, which is not allowed for random walk normalization.")
-
-        row_sum_inv = np.power(row_sum, -1)
-        row_sum_inv[np.isinf(row_sum_inv)] = 0.  # Handle inf values
-        deg_inv = sp.diags(row_sum_inv, format='csc')
-        # A_{rw} = D^{-1} * A
-        rw_norm_adj = deg_inv.dot(adj)
-
-        if gso_type in ['rw_norm_lap', 'rw_renorm_lap']:
-            rw_norm_lap = id - rw_norm_adj
-            gso = rw_norm_lap
-        else:
-            gso = rw_norm_adj
-
-    else:
-        raise ValueError(f'{gso_type} is not defined.')
-
-    # Check for nan or inf in the final result
-    if np.isnan(gso.data).any() or np.isinf(gso.data).any():
-        raise ValueError("NaN or Inf detected in the final GSO matrix. Please check the input adjacency matrix.")
-
-    return gso
-
-
-def calc_chebynet_gso(gso):
-    if sp.issparse(gso) == False:
-        gso = sp.csc_matrix(gso)
-    elif gso.format != 'csc':
-        gso = gso.tocsc()
-
-    id = sp.identity(gso.shape[0], format='csc')
-    # If you encounter a NotImplementedError, please update your scipy version to 1.10.1 or later.
-    eigval_max = norm(gso, 2)
-
-    # If the gso is symmetric or random walk normalized Laplacian,
-    # then the maximum eigenvalue is smaller than or equals to 2.
-    if eigval_max >= 2:
-        gso = gso - id
-    else:
-        gso = 2 * gso / eigval_max - id
-
-    return gso
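This file is deleted in the new history; per the import changes above, its helpers now live in data/get_adj.py. A usage sketch (the arg keys follow the code above):

```python
# The adjacency helpers moved from utils/ to data/ in this compare.
from data.get_adj import get_adj, get_gso

args = {
    "num_nodes": 170,                      # selects PEMS08 in the match statement
    "gso_type": "sym_norm_lap",            # one of the gso_type strings handled above
    "graph_conv_type": "cheb_graph_conv",  # triggers calc_chebynet_gso
    "device": "cpu",
}
adj = get_adj(args)   # distance-based adjacency matrix (np.ndarray)
gso = get_gso(args)   # rescaled graph shift operator as a torch tensor on `device`
```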