Compare commits: 9b3bb44552 ... 8e53d25ab1 (6 commits)

| SHA1 |
| --- |
| 8e53d25ab1 |
| 2ba061e57a |
| d0af46ea5f |
| f3480fccdc |
| 7da402d5e0 |
| 2800f66dfe |

README.md (103 lines changed)
@@ -22,108 +22,23 @@ pip install -r requirements.txt

```
pip install pyyaml tqdm statsmodels h5py kagglehub torch torchvision torchaudio torchdiffeq fastdtw notebook
```

# Prepare the GPT pre-trained weights

Downloading them requires an overseas network connection; if you don't have one, download the files manually and upload them.

# Quick start (temporarily deprecated)

Run the commands in baseline.ipynb, or use the command below (make sure the current working directory is TrafficWheel):

```bash
python run.py --model {model_name} --dataset {dataset_name} --mode {train, test} --device {cuda:0}
```

The GPT-2 folder should contain two files, `{config.json, pytorch_model.bin}`:

```bash
mkdir GPT-2
wget https://huggingface.co/openai-community/gpt2/resolve/main/config.json?download=true -O ./GPT-2/config.json
wget https://huggingface.co/openai-community/gpt2/resolve/main/pytorch_model.bin?download=true -O ./GPT-2/pytorch_model.bin
```
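As a quick sanity check (a minimal sketch, assuming the `transformers` package from requirements.txt is installed), the downloaded weights can be loaded locally:

```python
# Load GPT-2 from the local folder; this fails if either file is missing or corrupt
from transformers import GPT2Model

gpt = GPT2Model.from_pretrained("./GPT-2")
print(gpt.config.n_layer)  # 12 for the base GPT-2 checkpoint
```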
For the `run.py` command above:

- model_name: currently supported: DSANET, STGCN, DCRNN, GWN (GraphWaveNet), STSGCN, AGCRN, STFGNN, STGODE, STGNCDE, DDGCRN, TWDGCN, STAWnet
- dataset_name: currently supported: PEMSD3, PEMSD4, PEMSD7, PEMSD8
- mode: train trains the model; test evaluates it. Testing requires the model's .pth checkpoint in the pre-train folder.
- device: supports 'cpu', 'cuda:0', 'cuda:1', ..., depending on how many GPUs the machine has

run.py handles dataset download and model training/evaluation automatically; a concrete invocation is sketched below.
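For example (a minimal sketch; the model and dataset are an illustrative pick from the lists above):

```bash
# Train AGCRN on PEMSD4 on the first GPU
python run.py --model AGCRN --dataset PEMSD4 --mode train --device cuda:0
```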
:warning: Performance figures for the existing models are already stored in [Result.xlsx](./Result.xlsx); don't waste resources re-running them.

# Testing a model

After an experiment finishes, the model's checkpoints are saved under the `experiments/dataset/<training time>` folder, four files in total:

- best_model.pth holds the checkpoint with the best validation performance
- best_test_model.pth holds the checkpoint with the best test performance
- DATASET.yaml stores the parameters used to train the model
- run.log records the training log

Create a `pre-train/{dataset_name}` folder and copy everything under `experiments/dataset/<training time>` into it. After that, pass `--mode test` on the command line to run testing, as sketched after the note below.

:warning: Note: delete unneeded files from the experiments folder promptly, or it will keep growing.
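A minimal sketch (the timestamp folder and model name are illustrative):

```bash
# Promote a finished PEMSD4 run to pre-train, then evaluate it
mkdir -p pre-train/PEMSD4
cp experiments/dataset/2024-05-01-12-00/* pre-train/PEMSD4/
python run.py --model AGCRN --dataset PEMSD4 --mode test --device cuda:0
```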
# Changing the configuration

The config folder holds each model's configuration files, one per dataset, in YAML format.

Find the corresponding model's parameters and edit them there.

A configuration file has five sections: [data, model, train, test, log]

- data usually needs no changes; it stores the number of nodes, the prediction window, the history window, and so on
- model parameters should be read alongside the code and the paper; a recommended configuration is usually given
- train adjusts training details, including batch size, learning rate, batch_norm, and so on

In general we advise against changing the baseline parameters; running with the defaults is the most stable. A skeleton of the layout is sketched below.
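A minimal sketch of that layout (keys taken from the PEMS-BAY configs later in this diff; the values are illustrative):

```yaml
data:               # shapes and windows; usually left alone
  num_nodes: 325
  lag: 12           # history window
  horizon: 12       # prediction window
model:              # model hyperparameters; see the paper's recommendations
  input_dim: 1
  output_dim: 1
train:              # training details
  batch_size: 64
  lr_init: 0.002
  epochs: 300
# the test and log sections follow the same key: value style
```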
# Developing a model

First create a development branch dev and switch to it:

```bash
git switch -c dev
```

# Running REPST

On the first run the program downloads the dataset automatically. Only PEMSD8/PEMS-BAY are supported for now.

```bash
python run.py --config ./config/REPST/PEMS-BAY.yaml
```

Follow the [model migration guide](./transfer_guide.md) to port a model into TrafficWheel.

Commit your changes:

```bash
git add .
git commit -m "Commit message"
```

Push to the remote repository (ask me to register an account for you):

```bash
git push origin dev
```

Once the model is finished it needs to be merged into the main branch; submit a pull request [here](https://github.zhang-heng.com/czzhangheng/TrafficWheel/pulls).

# Known issues

Currently, the following models measure higher on the metrics than originally reported: ARIMA, TCN, DCRNN

STGCN raises an unexplained warning when loading the graph

The following models are not yet implemented because no source code is available: HA, VAR, FC-LSTM, GRU-ED

# Source code and papers

| Paper | Code |
| --- | --- |
| [HierAttnLSTM](https://arxiv.org/abs/2201.05760v4) | [Code](https://github.com/TeRyZh/Network-Level-Travel-Prediction-Hierarchical-Attention-LSTM) |
| [DSANET](https://dl.acm.org/doi/10.1145/3357384.3358132) | [Code](https://github.com/bighuang624/DSANet) |
| [STGCN](https://arxiv.org/abs/1709.04875) | [Code](https://github.com/hazdzz/STGCN) |
| [DCRNN](https://arxiv.org/abs/1707.01926) | [Code](https://github.com/chnsh/DCRNN_PyTorch) |
| [GraphWaveNet](https://arxiv.org/pdf/1906.00121.pdf) | [Code](https://github.com/SGT-LIM/GraphWavenet) |
| [STSGCN](https://aaai.org/ojs/index.php/AAAI/article/view/5438/5294) | [Code](https://github.com/SmallNana/STSGCN_Pytorch) |
| [AGCRN](https://arxiv.org/pdf/2007.02842) | [Code](https://github.com/LeiBAI/AGCRN) |
| [STFGNN](https://arxiv.org/abs/2012.09641) | [Code](https://github.com/lwm412/STFGNN-Pytorch) |
| [STGODE](https://arxiv.org/abs/2106.12931) | [Code](https://github.com/square-coder/STGODE) |
| [STG-NCDE](https://arxiv.org/abs/2112.03558) | [Code](https://github.com/jeongwhanchoi/STG-NCDE) |
| [DDGCRN](https://www.sciencedirect.com/science/article/abs/pii/S0031320323003710) | [Code](https://github.com/wengwenchao123/DDGCRN) |
@@ -10,8 +10,8 @@ data:
   column_wise: false
   days_per_week: 7
   default_graph: true
-  horizon: 12
-  lag: 12
+  horizon: 24
+  lag: 24
   normalizer: std
   num_nodes: 325
   steps_per_day: 288

@@ -23,8 +23,8 @@ data:
   batch_size: 16

 model:
-  pred_len: 12
-  seq_len: 12
+  pred_len: 24
+  seq_len: 24
   patch_len: 6
   stride: 7
   dropout: 0.2

@@ -33,6 +33,7 @@ model:
   gpt_path: ./GPT-2
   d_model: 64
   n_heads: 1
+  input_dim: 1

 train:
   batch_size: 16
@@ -0,0 +1,66 @@ (new file: a STID configuration for PEMS-BAY)

basic:
  dataset: "PEMS-BAY"
  mode: "train"
  device: "cuda:0"
  model: "STID"

data:
  num_nodes: 325
  lag: 24
  horizon: 24
  val_ratio: 0.2
  test_ratio: 0.2
  tod: False
  normalizer: std
  column_wise: False
  default_graph: True
  add_time_in_day: True
  add_day_in_week: True
  steps_per_day: 288
  days_per_week: 7
  input_dim: 1
  output_dim: 1
  batch_size: 64

model:
  input_dim: 3
  output_dim: 1
  history: 24
  horizon: 24
  num_nodes: 325
  input_len: 24
  embed_dim: 32
  output_len: 24
  num_layer: 3
  if_node: True
  node_dim: 32
  if_T_i_D: True
  if_D_i_W: True
  temp_dim_tid: 32
  temp_dim_diw: 32
  batch_size: 64

train:
  loss_func: mae
  seed: 1
  batch_size: 64
  epochs: 300
  lr_init: 0.002
  weight_decay: 0.0001
  lr_decay: False
  lr_decay_rate: 0.3
  lr_decay_step: "1,50,80"
  early_stop: True
  early_stop_patience: 15
  grad_norm: False
  max_grad_norm: 5
  real_value: True
  debug: true
  output_dim: 1
  mae_thresh: null
  mape_thresh: 0.0
  log_step: 200
  plot: False
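Assuming the new file lives at `./config/STID/PEMS-BAY.yaml` (the path is not shown in this view), it can presumably be launched the same way as the REPST config above:

```bash
python run.py --config ./config/STID/PEMS-BAY.yaml
```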
@@ -3,7 +3,7 @@ import torch
 import torch.nn as nn

 from model.DCRNN.dcrnn_cell import DCGRUCell
-from data.get_adj import get_adj
+from utils.get_adj import get_adj


 class Seq2SeqAttrs:
@@ -15,12 +15,13 @@ class ReplicationPad1d(nn.Module):
         return output

 class TokenEmbedding(nn.Module):
-    def __init__(self, c_in, d_model):
+    def __init__(self, c_in, d_model, patch_num, input_dim):
         super(TokenEmbedding, self).__init__()
         padding = 1
         self.tokenConv = nn.Conv1d(in_channels=c_in, out_channels=d_model,
                                    kernel_size=3, padding=padding, padding_mode='circular', bias=False)
-        self.confusion_layer = nn.Linear(2, 1)
+        self.confusion_layer = nn.Linear(patch_num * input_dim, 1)
         # if air_quality
         # self.confusion_layer = nn.Linear(42, 1)
@@ -34,19 +35,18 @@ class TokenEmbedding(nn.Module):
         b, n, m, pn, pl = x.shape  # batch, node, feature, patch_num, patch_len
         # 768,64,25
         x = self.tokenConv(x.reshape(b*n, pl, m*pn))  # batch*node, patch_len, feature*patch_num
-
         x = self.confusion_layer(x)
         return x.reshape(b, n, -1)


 class PatchEmbedding(nn.Module):
-    def __init__(self, d_model, patch_len, stride, dropout):
+    def __init__(self, d_model, patch_len, stride, dropout, patch_num, input_dim):
         super(PatchEmbedding, self).__init__()
         # Patching
         self.patch_len = patch_len
         self.stride = stride
         self.padding_patch_layer = ReplicationPad1d((0, stride))
-        self.value_embedding = TokenEmbedding(patch_len, d_model)
+        self.value_embedding = TokenEmbedding(patch_len, d_model, patch_num, input_dim)
         self.dropout = nn.Dropout(dropout)

     def forward(self, x):
@@ -13,6 +13,7 @@ class repst(nn.Module):
         self.pred_len = configs['pred_len']
         self.seq_len = configs['seq_len']
         self.patch_len = configs['patch_len']
+        self.input_dim = configs['input_dim']
         self.stride = configs['stride']
         self.dropout = configs['dropout']
         self.gpt_layers = configs['gpt_layers']

@@ -28,7 +29,7 @@ class repst(nn.Module):
         self.head_nf = self.d_ff * self.patch_nums

         # 64,6,7,0.2
-        self.patch_embedding = PatchEmbedding(self.d_model, self.patch_len, self.stride, self.dropout)
+        self.patch_embedding = PatchEmbedding(self.d_model, self.patch_len, self.stride, self.dropout, self.patch_nums, self.input_dim)

         self.gpts = GPT2Model.from_pretrained(self.gpt_path, output_attentions=True, output_hidden_states=True)
         self.gpts.h = self.gpts.h[:self.gpt_layers]
@@ -2,7 +2,7 @@
 import torch
 import torch.nn.functional as F
 import torch.nn as nn
-from data.get_adj import get_adj
+from utils.get_adj import get_adj


 class gcn_operation(nn.Module):
@@ -1,7 +1,7 @@
 import torch.nn as nn

 from model.STGCN import layers
-from data.get_adj import get_gso
+from utils.get_adj import get_gso


 class STGCNChebGraphConv(nn.Module):
@@ -2,7 +2,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn import init
-from data.get_adj import get_adj
+from utils.get_adj import get_adj
 import numbers
@@ -1,8 +1,7 @@
 import torch
 import torch.nn as nn
 from model.ST-SSL.models import STSSL
 from model.ST-SSL.layers import STEncoder, MLP
-from data.get_adj import get_gso
-from model.ST-SSL.layers
+from utils.get_adj import get_gso

 class STSSLModel(nn.Module):
     def __init__(self, args):
@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from data.get_adj import get_gso
+from utils.get_adj import get_gso


 class STSSLModel(nn.Module):
@@ -10,3 +10,5 @@ torchdiffeq
 fastdtw
 notebook
+torchcde
+einops
+transformers

run.py (9 lines changed)
@@ -1,5 +1,7 @@
 import torch
+from utils.Download_data import check_and_download_data
+data_complete = check_and_download_data()
+assert data_complete is not None, "Dataset download failed, please retry!"

 # import time
 from config.args_parser import parse_args

@@ -58,10 +60,5 @@ def main():
         case _:
             raise ValueError(f"Unsupported mode: {args['basic']['mode']}")


 if __name__ == "__main__":
-    from utils.Download_data import check_and_download_data
-
-    data_complete = check_and_download_data()
-    assert data_complete is not None, "Dataset download failed, please retry!"
     main()
@@ -6,6 +6,7 @@ import psutil
 import torch
 from utils.logger import get_logger
 from utils.loss_function import all_metrics
+from tqdm import tqdm


 class TrainingStats:

@@ -148,7 +149,8 @@ class Trainer:
         y_pred, y_true = [], []

         with torch.set_grad_enabled(optimizer_step):
-            for batch_idx, (data, target) in enumerate(dataloader):
+            progress_bar = tqdm(enumerate(dataloader), total=len(dataloader), desc=f"{mode.capitalize()} Epoch {epoch}")
+            for batch_idx, (data, target) in progress_bar:
                 start_time = time.time()

                 label = target[..., : self.args["output_dim"]]

@@ -175,6 +177,9 @@ class Trainer:
                 y_pred.append(output.detach().cpu())
                 y_true.append(label.detach().cpu())

+                # Update progress bar with the current loss
+                progress_bar.set_postfix(loss=loss.item())
+
         y_pred = torch.cat(y_pred, dim=0)
         y_true = torch.cat(y_true, dim=0)
@@ -61,11 +61,6 @@ def check_and_download_data():
             missing_adj = True
             missing_main_files = True
-        else:
-            # Check for get_adj.py in the data root directory
-            if "get_adj.py" not in os.listdir(data_dir):
-                # print(f"get_adj.py is missing from the root directory.")
-                missing_adj = True

         # Walk the expected file structure
         for subfolder, expected_files in expected_structure.items():
             subfolder_path = os.path.join(data_dir, subfolder)
@@ -100,8 +95,6 @@ def check_and_download_data():

     rearrange_dir()

-
-
     return True
@@ -0,0 +1,218 @@ (new file: the utils.get_adj module referenced by the import changes above)

import csv
import os
import numpy as np
from scipy.sparse.linalg import norm
import scipy.sparse as sp
import torch

def get_adj(args):
    dataset_path = './data'
    match args['num_nodes']:
        case 358:
            dataset_name = 'PEMS03'
            adj_path = os.path.join(dataset_path, dataset_name, 'PEMS03.csv')
            id = os.path.join(dataset_path, dataset_name, 'PEMS03.txt')
            A, adj = load_adj(args['num_nodes'], adj_path, id_filename=id)
        case 307:
            dataset_name = 'PEMS04'
            adj_path = os.path.join(dataset_path, dataset_name, 'PEMS04.csv')
            A, adj = load_adj(args['num_nodes'], adj_path, std=True)
        case 883:
            dataset_name = 'PEMS07'
            adj_path = os.path.join(dataset_path, dataset_name, 'PEMS07.csv')
            A, adj = load_adj(args['num_nodes'], adj_path)
        case 170:
            dataset_name = 'PEMS08'
            adj_path = os.path.join(dataset_path, dataset_name, 'PEMS08.csv')
            A, adj = load_adj(args['num_nodes'], adj_path, std=True)

    return adj

def get_gso(args):
    dataset_path = './data'
    match args['num_nodes']:
        case 358:
            dataset_name = 'PEMS03'
            adj_path = os.path.join(dataset_path, dataset_name, 'PEMS03.csv')
            id = os.path.join(dataset_path, dataset_name, 'PEMS03.txt')
            A, adj = load_adj(args['num_nodes'], adj_path, id_filename=id)
        case 307:
            dataset_name = 'PEMS04'
            adj_path = os.path.join(dataset_path, dataset_name, 'PEMS04.csv')
            A, adj = load_adj(args['num_nodes'], adj_path, std=True)
        case 883:
            dataset_name = 'PEMS07'
            adj_path = os.path.join(dataset_path, dataset_name, 'PEMS07.csv')
            A, adj = load_adj(args['num_nodes'], adj_path)
        case 170:
            dataset_name = 'PEMS08'
            adj_path = os.path.join(dataset_path, dataset_name, 'PEMS08.csv')
            A, adj = load_adj(args['num_nodes'], adj_path, std=True)

    gso = calc_gso(adj, args['gso_type'])
    if args['graph_conv_type'] == 'cheb_graph_conv':
        gso = calc_chebynet_gso(gso)
    gso = gso.toarray()
    gso = gso.astype(dtype=np.float32)
    gso = torch.from_numpy(gso).to(args['device'])
    return gso

def load_adj(num_nodes, adj_path, id_filename=None, std=False):
    '''
    Parameters
    ----------
    adj_path: str, path of the CSV file that contains edge information
    num_nodes: int, the number of vertices
    id_filename: str, optional, path of the file containing node IDs (if not starting from 0)
    std: bool, if True, normalize the cost values in the CSV file using Gaussian normalization

    Returns
    ----------
    A: np.ndarray, adjacency matrix
    distanceA: np.ndarray, distance matrix (normalized if std=True)
    '''
    if 'npy' in adj_path:
        adj_mx = np.load(adj_path)
        return adj_mx, None

    else:
        A = np.zeros((int(num_nodes), int(num_nodes)), dtype=np.float32)
        distanceA = np.zeros((int(num_nodes), int(num_nodes)), dtype=np.float32)

        # If id_filename is given, the node IDs do not start from 0 and must be remapped
        if id_filename:
            with open(id_filename, 'r') as f:
                id_dict = {int(i): idx for idx, i in enumerate(f.read().strip().split('\n'))}

            with open(adj_path, 'r') as f:
                f.readline()  # skip the header row
                reader = csv.reader(f)
                costs = []  # collects all cost values
                for row in reader:
                    if len(row) != 3:
                        continue
                    i, j, distance = int(row[0]), int(row[1]), float(row[2])
                    A[id_dict[i], id_dict[j]] = 1
                    # Make sure the distance is positive
                    distance = max(distance, 1e-6)
                    costs.append(distance)  # collect the cost value
                    distanceA[id_dict[i], id_dict[j]] = distance

        else:  # without id_filename, node IDs start from 0
            with open(adj_path, 'r') as f:
                f.readline()  # skip the header row
                reader = csv.reader(f)
                costs = []  # collects all cost values
                for row in reader:
                    if len(row) != 3:
                        continue
                    i, j, distance = int(row[0]), int(row[1]), float(row[2])
                    A[i, j] = 1
                    # Make sure the distance is positive
                    distance = max(distance, 1e-6)
                    costs.append(distance)  # collect the cost value
                    distanceA[i, j] = distance

        # If std=True, apply Gaussian normalization to all cost values from the CSV
        if std:
            mean_cost = np.mean(costs)  # mean of the cost values
            std_cost = np.std(costs)  # standard deviation of the cost values
            for idx in np.ndindex(distanceA.shape):  # iterate over the matrix
                if distanceA[idx] > 0:  # only standardize non-zero entries
                    normalized_value = (distanceA[idx] - mean_cost) / std_cost
                    # Make sure the standardized value stays positive
                    normalized_value = max(normalized_value, 1e-6)
                    distanceA[idx] = normalized_value

        # Make sure the matrix has no all-zero rows
        row_sums = distanceA.sum(axis=1)
        zero_rows = np.where(row_sums == 0)[0]
        for row in zero_rows:
            distanceA[row, :] = 1e-6  # replace zero rows with a small non-zero default

        return A, distanceA


def calc_gso(dir_adj, gso_type):
    n_vertex = dir_adj.shape[0]

    if not sp.issparse(dir_adj):
        dir_adj = sp.csc_matrix(dir_adj)
    elif dir_adj.format != 'csc':
        dir_adj = dir_adj.tocsc()

    id = sp.identity(n_vertex, format='csc')

    # Symmetrizing an adjacency matrix
    adj = dir_adj + dir_adj.T.multiply(dir_adj.T > dir_adj) - dir_adj.multiply(dir_adj.T > dir_adj)
    # adj = 0.5 * (dir_adj + dir_adj.transpose())

    if gso_type in ['sym_renorm_adj', 'rw_renorm_adj', 'sym_renorm_lap', 'rw_renorm_lap']:
        adj = adj + id

    if gso_type in ['sym_norm_adj', 'sym_renorm_adj', 'sym_norm_lap', 'sym_renorm_lap']:
        row_sum = adj.sum(axis=1).A1
        # Check for zero or negative values in row_sum
        if np.any(row_sum <= 0):
            raise ValueError(
                "Row sum contains zero or negative values, which is not allowed for symmetric normalization.")

        row_sum_inv_sqrt = np.power(row_sum, -0.5)
        row_sum_inv_sqrt[np.isinf(row_sum_inv_sqrt)] = 0.  # Handle inf values
        deg_inv_sqrt = sp.diags(row_sum_inv_sqrt, format='csc')
        # A_{sym} = D^{-0.5} * A * D^{-0.5}
        sym_norm_adj = deg_inv_sqrt.dot(adj).dot(deg_inv_sqrt)

        if gso_type in ['sym_norm_lap', 'sym_renorm_lap']:
            sym_norm_lap = id - sym_norm_adj
            gso = sym_norm_lap
        else:
            gso = sym_norm_adj

    elif gso_type in ['rw_norm_adj', 'rw_renorm_adj', 'rw_norm_lap', 'rw_renorm_lap']:
        row_sum = np.sum(adj, axis=1).A1
        # Check for zero or negative values in row_sum
        if np.any(row_sum <= 0):
            raise ValueError(
                "Row sum contains zero or negative values, which is not allowed for random walk normalization.")

        row_sum_inv = np.power(row_sum, -1)
        row_sum_inv[np.isinf(row_sum_inv)] = 0.  # Handle inf values
        deg_inv = sp.diags(row_sum_inv, format='csc')
        # A_{rw} = D^{-1} * A
        rw_norm_adj = deg_inv.dot(adj)

        if gso_type in ['rw_norm_lap', 'rw_renorm_lap']:
            rw_norm_lap = id - rw_norm_adj
            gso = rw_norm_lap
        else:
            gso = rw_norm_adj

    else:
        raise ValueError(f'{gso_type} is not defined.')

    # Check for nan or inf in the final result
    if np.isnan(gso.data).any() or np.isinf(gso.data).any():
        raise ValueError("NaN or Inf detected in the final GSO matrix. Please check the input adjacency matrix.")

    return gso


def calc_chebynet_gso(gso):
    if sp.issparse(gso) == False:
        gso = sp.csc_matrix(gso)
    elif gso.format != 'csc':
        gso = gso.tocsc()

    id = sp.identity(gso.shape[0], format='csc')
    # If you encounter a NotImplementedError, please update your scipy version to 1.10.1 or later.
    eigval_max = norm(gso, 2)

    # If the gso is symmetric or random walk normalized Laplacian,
    # then the maximum eigenvalue is smaller than or equals to 2.
    if eigval_max >= 2:
        gso = gso - id
    else:
        gso = 2 * gso / eigval_max - id

    return gso
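For illustration, a minimal usage sketch of these helpers (the dict keys mirror what get_adj/get_gso read from the config; values are illustrative, and the PEMS08 CSV must already be downloaded):

```python
# Build the adjacency matrix and graph shift operator for PEMS08
from utils.get_adj import get_adj, get_gso

args = {
    'num_nodes': 170,                      # selects the PEMS08 branch
    'gso_type': 'sym_norm_lap',            # any of the types handled in calc_gso
    'graph_conv_type': 'cheb_graph_conv',  # triggers the Chebyshev rescaling
    'device': 'cpu',
}

adj = get_adj(args)  # np.ndarray distance matrix
gso = get_gso(args)  # torch tensor on args['device'], Chebyshev-rescaled
```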