Support BeijingAirQuality. Refactor data handling; pip requirements need updating.
This commit is contained in:
parent 9911caa3d8
commit a46edc79a5
@@ -52,6 +52,14 @@
             "console": "integratedTerminal",
             "args": "--config ./config/REPST/SolarEnergy.yaml"
         },
+        {
+            "name": "BeijingAirQuality",
+            "type": "debugpy",
+            "request": "launch",
+            "program": "run.py",
+            "console": "integratedTerminal",
+            "args": "--config ./config/REPST/BeijingAirQuality.yaml"
+        },
         {
             "name": "AEPSA-PEMSBAY",
             "type": "debugpy",
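The new launch entry only forwards a --config flag to run.py, so the same run can be reproduced outside the debugger. A minimal sketch; the command simply mirrors the launch entry above and should be run from the repository root:

import subprocess

# Run the same command the "BeijingAirQuality" debug entry launches.
subprocess.run(
    ["python", "run.py", "--config", "./config/REPST/BeijingAirQuality.yaml"],
    check=True,
)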
@@ -1,5 +1,11 @@
 {
-    "python-envs.defaultEnvManager": "ms-python.python:system",
-    "python-envs.defaultPackageManager": "ms-python.python:pip",
-    "python-envs.pythonProjects": []
+    "python-envs.defaultEnvManager": "ms-python.python:conda",
+    "python-envs.defaultPackageManager": "ms-python.python:conda",
+    "python-envs.pythonProjects": [
+        {
+            "path": "data/SolarEnergy",
+            "envManager": "ms-python.python:system",
+            "packageManager": "ms-python.python:pip"
+        }
+    ]
 }
@@ -0,0 +1,61 @@
+basic:
+  dataset: "BeijingAirQuality"
+  mode: "train"
+  device: "cuda:1"
+  model: "REPST"
+  seed: 2023
+
+data:
+  add_day_in_week: false
+  add_time_in_day: false
+  column_wise: false
+  days_per_week: 7
+  default_graph: true
+  horizon: 12
+  lag: 12
+  normalizer: std
+  num_nodes: 7
+  steps_per_day: 288
+  test_ratio: 0.2
+  tod: false
+  val_ratio: 0.2
+  sample: 1
+  input_dim: 3
+  batch_size: 16
+
+model:
+  pred_len: 12
+  seq_len: 12
+  patch_len: 6
+  stride: 7
+  dropout: 0.2
+  gpt_layers: 9
+  d_ff: 128
+  gpt_path: ./GPT-2
+  d_model: 64
+  n_heads: 1
+  input_dim: 3
+  output_dim: 3
+  word_num: 1000
+
+train:
+  batch_size: 16
+  early_stop: true
+  early_stop_patience: 15
+  epochs: 100
+  grad_norm: false
+  loss_func: mae
+  lr_decay: true
+  lr_decay_rate: 0.3
+  lr_decay_step: "5,20,40,70"
+  lr_init: 0.003
+  max_grad_norm: 5
+  real_value: true
+  weight_decay: 0
+  debug: false
+  output_dim: 3
+  log_step: 1000
+  plot: false
+  mae_thresh: None
+  mape_thresh: 0.001
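A quick way to sanity-check the new config is to load it the way the --config flag implies. A minimal sketch assuming PyYAML; the repository's own init helpers may parse the file differently:

import yaml

# Load the YAML and check the fields the REPST model reads.
with open("./config/REPST/BeijingAirQuality.yaml") as f:
    cfg = yaml.safe_load(f)

assert cfg["basic"]["dataset"] == "BeijingAirQuality"
m = cfg["model"]
print(m["seq_len"], m["patch_len"], m["stride"], m["output_dim"])
# -> 12 6 7 3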
@@ -7,6 +7,11 @@ def load_st_dataset(config):
     # sample = config["data"]["sample"]
     # output B, N, D
     match dataset:
+        case "BeijingAirQuality":
+            data_path = os.path.join("./data/BeijingAirQuality/data.dat")
+            data = np.memmap(data_path, dtype=np.float32, mode='r')
+            L, N, C = 36000, 7, 3
+            data = data.reshape(L, N, C)
         case "PEMS-BAY":
            data_path = os.path.join("./data/PEMS-BAY/pems-bay.h5")
            with h5py.File(data_path, 'r') as f:
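np.memmap maps the raw float32 file without reading it into memory, and the reshape only succeeds if the file holds exactly L*N*C values. A standalone sketch of the same pattern, with the size check made explicit (path and shape taken from the hunk above):

import os
import numpy as np

data_path = "./data/BeijingAirQuality/data.dat"
L, N, C = 36000, 7, 3  # time steps, stations, channels

# The flat memmap must contain exactly L*N*C float32 values,
# otherwise reshape raises a ValueError.
expected_bytes = L * N * C * np.dtype(np.float32).itemsize
assert os.path.getsize(data_path) == expected_bytes, "data.dat size mismatch"

data = np.memmap(data_path, dtype=np.float32, mode="r").reshape(L, N, C)
print(data.shape)  # (36000, 7, 3)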
@@ -15,13 +15,13 @@ class ReplicationPad1d(nn.Module):
         return output

 class TokenEmbedding(nn.Module):
-    def __init__(self, c_in, d_model, patch_num, input_dim):
+    def __init__(self, c_in, d_model, patch_num, input_dim, output_dim):
         super(TokenEmbedding, self).__init__()
         padding = 1
         self.tokenConv = nn.Conv1d(in_channels=c_in, out_channels=d_model,
                                    kernel_size=3, padding=padding, padding_mode='circular', bias=False)

-        self.confusion_layer = nn.Linear(patch_num * input_dim, 1)
+        self.confusion_layer = nn.Linear(patch_num * input_dim, output_dim)

         for m in self.modules():
             if isinstance(m, nn.Conv1d):
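The functional change is confined to the last linear layer: instead of collapsing the patch_num * input_dim features to a single value, it now keeps output_dim of them. A minimal shape check; dimensions are borrowed from the BeijingAirQuality config, and patch_num=2 follows from the patch arithmetic sketched after the next hunk:

import torch
import torch.nn as nn

patch_num, input_dim, output_dim = 2, 3, 3
confusion = nn.Linear(patch_num * input_dim, output_dim)

x = torch.randn(16 * 7, patch_num * input_dim)  # (batch*nodes, features)
print(confusion(x).shape)  # torch.Size([112, 3]) instead of [112, 1]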
@@ -37,17 +37,16 @@ class TokenEmbedding(nn.Module):


 class PatchEmbedding(nn.Module):
-    def __init__(self, d_model, patch_len, stride, dropout, patch_num, input_dim):
+    def __init__(self, d_model, patch_len, stride, dropout, patch_num, input_dim, output_dim):
         super(PatchEmbedding, self).__init__()
         # Patching
         self.patch_len = patch_len
         self.stride = stride
         self.padding_patch_layer = ReplicationPad1d((0, stride))
-        self.value_embedding = TokenEmbedding(patch_len, d_model, patch_num, input_dim)
+        self.value_embedding = TokenEmbedding(patch_len, d_model, patch_num, input_dim, output_dim)
         self.dropout = nn.Dropout(dropout)

     def forward(self, x):

         n_vars = x.shape[2]
         x = self.padding_patch_layer(x)
         x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride)
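With ReplicationPad1d((0, stride)) followed by unfold(size=patch_len, step=stride), the patch count is (seq_len + stride - patch_len) // stride + 1; for the new config (seq_len=12, patch_len=6, stride=7) that is (12 + 7 - 6) // 7 + 1 = 2. A standalone check, using torch's built-in replicate padding in place of the custom ReplicationPad1d:

import torch

seq_len, patch_len, stride = 12, 6, 7
x = torch.randn(16, 7 * 3, seq_len)  # (batch, nodes*channels, time)

# Replicate the last time step `stride` times, as ReplicationPad1d((0, stride)) does.
x = torch.nn.functional.pad(x, (0, stride), mode="replicate")
patches = x.unfold(dimension=-1, size=patch_len, step=stride)

print(patches.shape[-2])                              # 2 patches
print((seq_len + stride - patch_len) // stride + 1)   # 2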
@@ -19,6 +19,7 @@ class repst(nn.Module):
         self.gpt_layers = configs['gpt_layers']
         self.d_ff = configs['d_ff']
         self.gpt_path = configs['gpt_path']
+        self.output_dim = configs.get('output_dim', 1)

         self.word_choice = GumbelSoftmax(configs['word_num'])

@@ -31,7 +32,7 @@ class repst(nn.Module):
         self.head_nf = self.d_ff * self.patch_nums

         # word embedding
-        self.patch_embedding = PatchEmbedding(self.d_model, self.patch_len, self.stride, self.dropout, self.patch_nums, self.input_dim)
+        self.patch_embedding = PatchEmbedding(self.d_model, self.patch_len, self.stride, self.dropout, self.patch_nums, self.input_dim, self.output_dim)

         # GPT-2 initialization
         self.gpts = GPT2Model.from_pretrained(self.gpt_path, output_attentions=True, output_hidden_states=True)

@@ -41,7 +42,7 @@ class repst(nn.Module):
         self.word_embeddings = self.gpts.get_input_embeddings().weight.to(self.device)
         self.vocab_size = self.word_embeddings.shape[0]
         self.mapping_layer = nn.Linear(self.vocab_size, 1)
-        self.reprogramming_layer = ReprogrammingLayer(self.d_model, self.n_heads, self.d_keys, self.d_llm)
+        self.reprogramming_layer = ReprogrammingLayer(self.d_model * self.output_dim, self.n_heads, self.d_keys, self.d_llm)

         self.out_mlp = nn.Sequential(
             nn.Linear(self.d_llm, 128),

@@ -62,7 +63,7 @@ class repst(nn.Module):
         torch.nn.init.zeros_(module.bias)

     def forward(self, x):
-        x = x[..., :1]
+        x = x[..., :self.output_dim]
         x_enc = rearrange(x, 'b t n c -> b n c t')
         enc_out, n_vars = self.patch_embedding(x_enc)
         self.mapping_layer(self.word_embeddings.permute(1, 0)).permute(1, 0)

@@ -79,25 +80,3 @@ class repst(nn.Module):

         return outputs
-
-if __name__ == '__main__':
-    configs = {
-        'device': 'cuda:0',
-        'pred_len': 24,
-        'seq_len': 24,
-        'patch_len': 6,
-        'stride': 7,
-        'dropout': 0.2,
-        'gpt_layers': 9,
-        'd_ff': 128,
-        'gpt_path': './GPT-2',
-        'd_model': 64,
-        'n_heads': 1,
-        'input_dim': 1
-    }
-    model = repst(configs)
-    x = torch.randn(16, 24, 325, 1)
-    y = model(x)
-
-    print(y.shape)
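The forward pass used to hard-code a single input channel (x[..., :1]); it now keeps the first output_dim channels, so for BeijingAirQuality all three features survive the slice. A minimal illustration of just the slicing and rearrange steps (shapes follow the new config; the model itself is not instantiated):

import torch
from einops import rearrange

output_dim = 3
x = torch.randn(16, 12, 7, 3)  # (batch, time, nodes, channels)

x_old = x[..., :1]            # old behavior: 1 channel kept
x_new = x[..., :output_dim]   # new behavior: all 3 channels kept

print(x_old.shape, x_new.shape)  # (16, 12, 7, 1) (16, 12, 7, 3)
print(rearrange(x_new, 'b t n c -> b n c t').shape)  # (16, 7, 3, 12)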
@@ -12,3 +12,4 @@ notebook
 torchcde
 einops
 transformers
+py7zr
run.py
@@ -15,7 +15,7 @@ def main():
     args = init.init_device(args)
     init.init_seed(args["basic"]["seed"])

     # Load model
     model = init.init_model(args)

     # Load dataset
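run.py receives the --config path and hands the parsed YAML to the init helpers. The exact wiring is not visible in this hunk; the sketch below shows the usual argparse-plus-PyYAML pattern, where only the flag name is taken from launch.json and everything else is an assumption:

import argparse
import yaml

parser = argparse.ArgumentParser()
parser.add_argument("--config", required=True, help="path to a REPST YAML config")
cli = parser.parse_args()

with open(cli.config) as f:
    args = yaml.safe_load(f)  # nested dict: args["basic"], args["data"], ...

print(args["basic"]["seed"])  # e.g. 2023 for BeijingAirQuality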
@@ -1,228 +1,114 @@
-import os
-import requests
-import zipfile
-import shutil
-import kagglehub  # assumes kagglehub is an available library
+import os, json, shutil, requests
+from urllib.parse import urlsplit
+from tqdm import tqdm
+import kagglehub
+import py7zr

-# Dictionary describing the expected file layout
+# ---------- 1. Load the structure JSON ----------
+def load_structure_json(path="utils/dataset.json"):
+    with open(path, "r", encoding="utf-8") as f:
+        return json.load(f)


-def check_and_download_data():
-    """
-    Check the integrity of the data folder and download whatever is missing.
-    """
-    current_working_dir = os.getcwd()  # current working directory
-    data_dir = os.path.join(
-        current_working_dir, "data"
-    )  # assume the data folder lives under the current working directory
-
-    expected_structure = {
-        "PEMS03": [
-            "PEMS03.csv",
-            "PEMS03.npz",
-            "PEMS03.txt",
-            "PEMS03_dtw_distance.npy",
-            "PEMS03_spatial_distance.npy",
-        ],
-        "PEMS04": [
-            "PEMS04.csv",
-            "PEMS04.npz",
-            "PEMS04_dtw_distance.npy",
-            "PEMS04_spatial_distance.npy",
-        ],
-        "PEMS07": [
-            "PEMS07.csv",
-            "PEMS07.npz",
-            "PEMS07_dtw_distance.npy",
-            "PEMS07_spatial_distance.npy",
-        ],
-        "PEMS08": [
-            "PEMS08.csv",
-            "PEMS08.npz",
-            "PEMS08_dtw_distance.npy",
-            "PEMS08_spatial_distance.npy",
-        ],
-        "PEMS-BAY": [
-            "adj_mx_bay.pkl",
-            "pems-bay-meta.h5",
-            "pems-bay.h5"
-        ],
-        "METR-LA": [
-            "METR-LA.h5"
-        ],
-        "SolarEnergy": [
-        ]
-    }
-
-    current_dir = os.getcwd()  # current working directory
-    missing_adj = False
-    missing_main_files = False
-
-    # Check whether the data folder exists
-    if not os.path.exists(data_dir) or not os.path.isdir(data_dir):
-        # print(f"Directory {data_dir} does not exist.")
-        print("Downloading all required data files...")
+# ---------- 2. Integrity check ----------
+def detect_data_integrity(data_dir, expected, check_adj=False):
+    missing_adj, missing_main = False, False
+    if not os.path.isdir(data_dir): return True, True
+    for folder, files in expected.items():
+        folder_path = os.path.join(data_dir, folder)
+        if not os.path.isdir(folder_path):
+            if check_adj:
+                missing_adj = True
+            missing_main = True
+            continue
-    else:
-        # Walk the expected file structure
-        for subfolder, expected_files in expected_structure.items():
-            subfolder_path = os.path.join(data_dir, subfolder)
-
-            # Check that the subfolder exists
-            if not os.path.exists(subfolder_path) or not os.path.isdir(subfolder_path):
-                # print(f"Subfolder {subfolder} does not exist.")
-                missing_main_files = True
-                continue
-
-            # List the files actually present in the subfolder
-            actual_files = os.listdir(subfolder_path)
-
-            # Check for missing files
-            for expected_file in expected_files:
-                if expected_file not in actual_files:
-                    # print(f"Subfolder {subfolder} is missing file {expected_file}.")
-                    if (
-                        "_dtw_distance.npy" in expected_file
-                        or "_spatial_distance.npy" in expected_file
-                    ):
-                        missing_adj = True
-                    else:
-                        missing_main_files = True
+        existing = set(os.listdir(folder_path))
+        for f in files:
+            if f not in existing:
+                if f.endswith(("_dtw_distance.npy", "_spatial_distance.npy")):
+                    missing_adj = True
+                elif not check_adj:
+                    missing_main = True
+    return missing_adj, missing_main

-    # Dispatch downloads according to what is missing
+# ---------- 3. Download a 7z archive and extract it ----------
+def download_and_extract(url, dst_dir, max_retries=3):
+    os.makedirs(dst_dir, exist_ok=True)
+    filename = os.path.basename(urlsplit(url).path) or "download.7z"
+    file_path = os.path.join(dst_dir, filename)
+    for attempt in range(1, max_retries+1):
+        try:
+            # download
+            with requests.get(url, stream=True, timeout=30) as r:
+                r.raise_for_status()
+                total = int(r.headers.get("content-length",0))
+                with open(file_path,"wb") as f, tqdm(total=total, unit="B", unit_scale=True, desc=filename) as bar:
+                    for chunk in r.iter_content(8192):
+                        f.write(chunk)
+                        bar.update(len(chunk))
+            # extract the 7z archive
+            with py7zr.SevenZipFile(file_path, mode='r') as archive:
+                archive.extractall(path=dst_dir)
+            os.remove(file_path)
+            return
+        except Exception as e:
+            if attempt==max_retries: raise RuntimeError("download or extraction failed")
+            print("error, retrying...", e)

+# ---------- 4. Download Kaggle data ----------
+def download_kaggle_data(base_dir, dataset):
+    try:
+        path = kagglehub.dataset_download(dataset)
+        shutil.copytree(path, os.path.join(base_dir,"data"), dirs_exist_ok=True)
+    except Exception as e:
+        print("Kaggle download failed:", dataset, e)

+# ---------- 5. Tidy the directory layout ----------
+def rearrange_dir():
+    data_dir = os.path.join(os.getcwd(), "data")
+    nested = os.path.join(data_dir,"data")
+    if os.path.isdir(nested):
+        for item in os.listdir(nested):
+            src,dst = os.path.join(nested,item), os.path.join(data_dir,item)
+            if os.path.isdir(src):
+                shutil.copytree(src, dst, dirs_exist_ok=True)  # update directories that already exist
+            else:
+                shutil.copy2(src, dst)
+        shutil.rmtree(nested)
+
+    for kw,tgt in [("bay","PEMS-BAY"),("metr","METR-LA")]:
+        dst = os.path.join(data_dir,tgt); os.makedirs(dst,exist_ok=True)
+        for f in os.listdir(data_dir):
+            if kw in f.lower() and f.endswith((".h5",".pkl")):
+                shutil.move(os.path.join(data_dir,f), os.path.join(dst,f))
+
+    solar = os.path.join(data_dir,"solar-energy")
+    if os.path.isdir(solar):
+        dst = os.path.join(data_dir,"SolarEnergy"); os.makedirs(dst,exist_ok=True)
+        csv = os.path.join(solar,"solar_AL.csv")
+        if os.path.isfile(csv): shutil.move(csv, os.path.join(dst,"SolarEnergy.csv"))
+        shutil.rmtree(solar)

+# ---------- 6. Main flow ----------
+def check_and_download_data():
+    cwd = os.getcwd()
+    data_dir = os.path.join(cwd,"data")
+    expected = load_structure_json()
+
+    missing_adj,_ = detect_data_integrity(data_dir, expected, check_adj=True)
+    if missing_adj:
-        download_adj_data(current_dir)
-    if missing_main_files:
-        download_kaggle_data(current_dir, 'elmahy/pems-dataset')
-        download_kaggle_data(current_dir, 'scchuy/pemsbay')
-        download_kaggle_data(current_dir, "annnnguyen/metr-la-dataset")
-        download_kaggle_data(current_dir, "wangshaoqi/solar-energy")
+        download_and_extract("http://code.zhang-heng.com/static/adj.7z", data_dir)
+
+    baq_folder = os.path.join(data_dir,"BeijingAirQuality")
+    if not os.path.isdir(baq_folder):
+        download_and_extract("http://code.zhang-heng.com/static/BeijingAirQuality.7z", data_dir)
+
+    _,missing_main = detect_data_integrity(data_dir, expected, check_adj=False)
+    if missing_main:
+        for ds in ["elmahy/pems-dataset","scchuy/pemsbay","annnnguyen/metr-la-dataset","wangshaoqi/solar-energy"]:
+            download_kaggle_data(cwd, ds)
+
+    rearrange_dir()
+
     return True


-def download_adj_data(current_dir, max_retries=3):
-    """
-    Download and extract adj.zip, showing a progress bar.
-    Retry at most max_retries times on failure.
-    """
-    url = "http://code.zhang-heng.com/static/adj.zip"
-    retries = 0
-
-    while retries <= max_retries:
-        try:
-            print(f"Downloading adjacency-matrix files from {url} ...")
-            response = requests.get(url, stream=True)
-
-            if response.status_code == 200:
-                total_size = int(response.headers.get("content-length", 0))
-                block_size = 1024  # 1KB
-                t = tqdm(total=total_size, unit="B", unit_scale=True, desc="downloading")
-
-                zip_file_path = os.path.join(current_dir, "adj.zip")
-                with open(zip_file_path, "wb") as f:
-                    for data in response.iter_content(block_size):
-                        f.write(data)
-                        t.update(len(data))
-                t.close()
-
-                # print("Download finished, file saved to:", zip_file_path)
-
-                if os.path.exists(zip_file_path):
-                    with zipfile.ZipFile(zip_file_path, "r") as zip_ref:
-                        zip_ref.extractall(current_dir)
-                    # print("Dataset extracted to:", current_dir)
-                    os.remove(zip_file_path)  # delete the zip file
-                else:
-                    print("Downloaded zip file not found; skipping extraction.")
-                break  # download succeeded, exit the loop
-            else:
-                print(f"Download failed with status code {response.status_code}. Please check that the link is valid.")
-        except Exception as e:
-            print(f"Error while downloading or extracting the dataset: {e}")
-            print("If the link is invalid, check the URL or retry later.")
-
-        retries += 1
-        if retries > max_retries:
-            raise Exception(
-                f"Download failed after the maximum number of retries ({max_retries}). Check the link or your network connection."
-            )
-
-
-def download_kaggle_data(current_dir, kaggle_path):
-    """
-    Download a KaggleHub dataset and move it straight into the data folder of the
-    current working directory, overwriting conflicting files if the target exists.
-    """
-    try:
-        print(f"Downloading the {kaggle_path} dataset...")
-        path = kagglehub.dataset_download(kaggle_path)
-        # print("Path to KaggleHub dataset files:", path)
-
-        if os.path.exists(path):
-            destination_path = os.path.join(current_dir, "data")
-            # copytree places the contents directly under data/, overwriting conflicts
-            shutil.copytree(path, destination_path, dirs_exist_ok=True)
-    except Exception as e:
-        print(f"Error while downloading or processing the KaggleHub dataset: {e}")
-
-
-def rearrange_dir():
-    """
-    Merge the files under data/data into the parent directory and delete data/data.
-    """
-    data_dir = os.path.join(os.getcwd(), "data")
-    nested_data_dir = os.path.join(data_dir, "data")
-
-    if os.path.exists(nested_data_dir) and os.path.isdir(nested_data_dir):
-        for item in os.listdir(nested_data_dir):
-            source_path = os.path.join(nested_data_dir, item)
-            destination_path = os.path.join(data_dir, item)
-
-            if os.path.isdir(source_path):
-                shutil.copytree(source_path, destination_path, dirs_exist_ok=True)
-            else:
-                shutil.copy2(source_path, destination_path)
-
-        shutil.rmtree(nested_data_dir)
-
-    # Move files containing "bay" into the PEMS-BAY folder
-    pems_bay_dir = os.path.join(data_dir, "PEMS-BAY")
-    os.makedirs(pems_bay_dir, exist_ok=True)
-
-    for item in os.listdir(data_dir):
-        if "bay" in item.lower() and (item.endswith(".pkl") or item.endswith(".h5")):
-            source_path = os.path.join(data_dir, item)
-            destination_path = os.path.join(pems_bay_dir, item)
-            shutil.move(source_path, destination_path)
-
-    # metr-la
-    metrla_dir = os.path.join(data_dir, "METR-LA")
-    os.makedirs(metrla_dir, exist_ok=True)
-    for item in os.listdir(data_dir):
-        if "metr" in item.lower() and (item.endswith(".pkl") or item.endswith(".h5")):
-            source_path = os.path.join(data_dir, item)
-            destination_path = os.path.join(metrla_dir, item)
-            shutil.move(source_path, destination_path)
-
-    # solar-energy
-    solar_src = os.path.join(data_dir, "solar-energy")
-    solar_sub = os.path.join(solar_src, "solar_AL.txt")
-    solar_csv = os.path.join(solar_src, "solar_AL.csv")
-    solar_dst_dir = os.path.join(data_dir,"SolarEnergy")
-    solar_dst_csv = os.path.join(solar_dst_dir, "SolarEnergy.csv")
-    if os.path.isdir(solar_sub): shutil.rmtree(solar_sub)
-    if os.path.isdir(solar_src): os.rename(solar_src, solar_dst_dir)
-    if os.path.isfile(solar_csv.replace(solar_src, solar_dst_dir)):
-        os.rename(solar_csv.replace(solar_src, solar_dst_dir), solar_dst_csv)
-
-
-# main entry point
-if __name__ == "__main__":
+if __name__=="__main__":
     check_and_download_data()
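Taken together, the refactor moves the hard-coded expected_structure dict into utils/dataset.json and funnels archive downloads through one retrying 7z helper. A minimal usage sketch; the import path utils.download is an assumption, since the module's real location is not shown in this diff:

# Hypothetical import path -- adjust to wherever this module actually lives.
from utils.download import load_structure_json, detect_data_integrity

expected = load_structure_json("utils/dataset.json")

# First pass: only the adjacency/distance .npy files matter.
missing_adj, _ = detect_data_integrity("data", expected, check_adj=True)

# Second pass: everything else (csv/npz/h5 main files).
_, missing_main = detect_data_integrity("data", expected, check_adj=False)

print(missing_adj, missing_main)  # (False, False) once the data folder is complete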
@@ -0,0 +1,39 @@
+{
+    "PEMS03": [
+        "PEMS03.csv",
+        "PEMS03.npz",
+        "PEMS03.txt",
+        "PEMS03_dtw_distance.npy",
+        "PEMS03_spatial_distance.npy"
+    ],
+    "PEMS04": [
+        "PEMS04.csv",
+        "PEMS04.npz",
+        "PEMS04_dtw_distance.npy",
+        "PEMS04_spatial_distance.npy"
+    ],
+    "PEMS07": [
+        "PEMS07.csv",
+        "PEMS07.npz",
+        "PEMS07_dtw_distance.npy",
+        "PEMS07_spatial_distance.npy"
+    ],
+    "PEMS08": [
+        "PEMS08.csv",
+        "PEMS08.npz",
+        "PEMS08_dtw_distance.npy",
+        "PEMS08_spatial_distance.npy"
+    ],
+    "PEMS-BAY": [
+        "adj_mx_bay.pkl",
+        "pems-bay-meta.h5",
+        "pems-bay.h5"
+    ],
+    "METR-LA": [
+        "METR-LA.h5"
+    ],
+    "SolarEnergy": [
+        "SolarEnergy.csv"
+    ],
+    "BeijingAirQuality": ["data.dat", "desc.json"]
+}