AEPSA v0.1
This commit is contained in:
parent
88daa53dd5
commit
9c50c30918
|
|
@ -3,9 +3,11 @@
|
||||||
// 悬停以查看现有属性的描述。
|
// 悬停以查看现有属性的描述。
|
||||||
// 欲了解更多信息,请访问: https://go.microsoft.com/fwlink/?linkid=830387
|
// 欲了解更多信息,请访问: https://go.microsoft.com/fwlink/?linkid=830387
|
||||||
"version": "0.2.0",
|
"version": "0.2.0",
|
||||||
|
|
||||||
"configurations": [
|
"configurations": [
|
||||||
|
// STID 模型组
|
||||||
{
|
{
|
||||||
"name": "STID_PEMS-BAY",
|
"name": "STID: PEMS-BAY",
|
||||||
"type": "debugpy",
|
"type": "debugpy",
|
||||||
"request": "launch",
|
"request": "launch",
|
||||||
"program": "run.py",
|
"program": "run.py",
|
||||||
|
|
@ -13,7 +15,7 @@
|
||||||
"args": "--config ./config/STID/PEMS-BAY.yaml"
|
"args": "--config ./config/STID/PEMS-BAY.yaml"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "STID_PEMSD4",
|
"name": "STID: PEMSD4",
|
||||||
"type": "debugpy",
|
"type": "debugpy",
|
||||||
"request": "launch",
|
"request": "launch",
|
||||||
"program": "run.py",
|
"program": "run.py",
|
||||||
|
|
@ -21,15 +23,7 @@
|
||||||
"args": "--config ./config/STID/PEMSD4.yaml"
|
"args": "--config ./config/STID/PEMSD4.yaml"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "REPST",
|
"name": "STID: BJTaxi-InFlow",
|
||||||
"type": "debugpy",
|
|
||||||
"request": "launch",
|
|
||||||
"program": "run.py",
|
|
||||||
"console": "integratedTerminal",
|
|
||||||
"args": "--config ./config/REPST/PEMSD8.yaml"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "STID-BJTaxi-InFlow",
|
|
||||||
"type": "debugpy",
|
"type": "debugpy",
|
||||||
"request": "launch",
|
"request": "launch",
|
||||||
"program": "run.py",
|
"program": "run.py",
|
||||||
|
|
@ -37,7 +31,7 @@
|
||||||
"args": "--config ./config/STID/BJTaxi_Inflow.yaml"
|
"args": "--config ./config/STID/BJTaxi_Inflow.yaml"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "STID-BJTaxi-OutFlow",
|
"name": "STID: BJTaxi-OutFlow",
|
||||||
"type": "debugpy",
|
"type": "debugpy",
|
||||||
"request": "launch",
|
"request": "launch",
|
||||||
"program": "run.py",
|
"program": "run.py",
|
||||||
|
|
@ -45,7 +39,7 @@
|
||||||
"args": "--config ./config/STID/BJTaxi_Outflow.yaml"
|
"args": "--config ./config/STID/BJTaxi_Outflow.yaml"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "STID-NYCBike-InFlow",
|
"name": "STID: NYCBike-InFlow",
|
||||||
"type": "debugpy",
|
"type": "debugpy",
|
||||||
"request": "launch",
|
"request": "launch",
|
||||||
"program": "run.py",
|
"program": "run.py",
|
||||||
|
|
@ -53,7 +47,7 @@
|
||||||
"args": "--config ./config/STID/NYCBike_Inflow.yaml"
|
"args": "--config ./config/STID/NYCBike_Inflow.yaml"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "STID-NYCBike-OutFlow",
|
"name": "STID: NYCBike-OutFlow",
|
||||||
"type": "debugpy",
|
"type": "debugpy",
|
||||||
"request": "launch",
|
"request": "launch",
|
||||||
"program": "run.py",
|
"program": "run.py",
|
||||||
|
|
@ -61,15 +55,25 @@
|
||||||
"args": "--config ./config/STID/NYCBike_Outflow.yaml"
|
"args": "--config ./config/STID/NYCBike_Outflow.yaml"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "STID-SolarEnergy",
|
"name": "STID: SolarEnergy",
|
||||||
"type": "debugpy",
|
"type": "debugpy",
|
||||||
"request": "launch",
|
"request": "launch",
|
||||||
"program": "run.py",
|
"program": "run.py",
|
||||||
"console": "integratedTerminal",
|
"console": "integratedTerminal",
|
||||||
"args": "--config ./config/STID/SolarEnergy.yaml"
|
"args": "--config ./config/STID/SolarEnergy.yaml"
|
||||||
},
|
},
|
||||||
|
|
||||||
|
// REPST 模型组
|
||||||
{
|
{
|
||||||
"name": "REPST-BJTaxi-InFlow",
|
"name": "REPST: PEMSD8",
|
||||||
|
"type": "debugpy",
|
||||||
|
"request": "launch",
|
||||||
|
"program": "run.py",
|
||||||
|
"console": "integratedTerminal",
|
||||||
|
"args": "--config ./config/REPST/PEMSD8.yaml"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "REPST: BJTaxi-InFlow",
|
||||||
"type": "debugpy",
|
"type": "debugpy",
|
||||||
"request": "launch",
|
"request": "launch",
|
||||||
"program": "run.py",
|
"program": "run.py",
|
||||||
|
|
@ -77,7 +81,7 @@
|
||||||
"args": "--config ./config/REPST/BJTaxi-Inflow.yaml"
|
"args": "--config ./config/REPST/BJTaxi-Inflow.yaml"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "REPST-NYCBike-outflow",
|
"name": "REPST: NYCBike-outflow",
|
||||||
"type": "debugpy",
|
"type": "debugpy",
|
||||||
"request": "launch",
|
"request": "launch",
|
||||||
"program": "run.py",
|
"program": "run.py",
|
||||||
|
|
@ -85,7 +89,7 @@
|
||||||
"args": "--config ./config/REPST/NYCBike-outflow.yaml"
|
"args": "--config ./config/REPST/NYCBike-outflow.yaml"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "REPST-NYCBike-inflow",
|
"name": "REPST: NYCBike-inflow",
|
||||||
"type": "debugpy",
|
"type": "debugpy",
|
||||||
"request": "launch",
|
"request": "launch",
|
||||||
"program": "run.py",
|
"program": "run.py",
|
||||||
|
|
@ -93,7 +97,7 @@
|
||||||
"args": "--config ./config/REPST/NYCBike-inflow.yaml"
|
"args": "--config ./config/REPST/NYCBike-inflow.yaml"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "REPST-PEMSBAY",
|
"name": "REPST: PEMS-BAY",
|
||||||
"type": "debugpy",
|
"type": "debugpy",
|
||||||
"request": "launch",
|
"request": "launch",
|
||||||
"program": "run.py",
|
"program": "run.py",
|
||||||
|
|
@ -101,7 +105,7 @@
|
||||||
"args": "--config ./config/REPST/PEMS-BAY.yaml"
|
"args": "--config ./config/REPST/PEMS-BAY.yaml"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "REPST-METR",
|
"name": "REPST: METR-LA",
|
||||||
"type": "debugpy",
|
"type": "debugpy",
|
||||||
"request": "launch",
|
"request": "launch",
|
||||||
"program": "run.py",
|
"program": "run.py",
|
||||||
|
|
@ -109,7 +113,7 @@
|
||||||
"args": "--config ./config/REPST/METR-LA.yaml"
|
"args": "--config ./config/REPST/METR-LA.yaml"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "REPST-Solar",
|
"name": "REPST: SolarEnergy",
|
||||||
"type": "debugpy",
|
"type": "debugpy",
|
||||||
"request": "launch",
|
"request": "launch",
|
||||||
"program": "run.py",
|
"program": "run.py",
|
||||||
|
|
@ -117,7 +121,7 @@
|
||||||
"args": "--config ./config/REPST/SolarEnergy.yaml"
|
"args": "--config ./config/REPST/SolarEnergy.yaml"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "BeijingAirQuality",
|
"name": "REPST: BeijingAirQuality",
|
||||||
"type": "debugpy",
|
"type": "debugpy",
|
||||||
"request": "launch",
|
"request": "launch",
|
||||||
"program": "run.py",
|
"program": "run.py",
|
||||||
|
|
@ -125,20 +129,78 @@
|
||||||
"args": "--config ./config/REPST/BeijingAirQuality.yaml"
|
"args": "--config ./config/REPST/BeijingAirQuality.yaml"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "AirQuality",
|
"name": "REPST: AirQuality",
|
||||||
"type": "debugpy",
|
"type": "debugpy",
|
||||||
"request": "launch",
|
"request": "launch",
|
||||||
"program": "run.py",
|
"program": "run.py",
|
||||||
"console": "integratedTerminal",
|
"console": "integratedTerminal",
|
||||||
"args": "--config ./config/REPST/AirQuality.yaml"
|
"args": "--config ./config/REPST/AirQuality.yaml"
|
||||||
},
|
},
|
||||||
|
|
||||||
|
// AEPSA 模型组
|
||||||
{
|
{
|
||||||
"name": "AEPSA-PEMSBAY",
|
"name": "AEPSA: PEMS-BAY",
|
||||||
"type": "debugpy",
|
"type": "debugpy",
|
||||||
"request": "launch",
|
"request": "launch",
|
||||||
"program": "run.py",
|
"program": "run.py",
|
||||||
"console": "integratedTerminal",
|
"console": "integratedTerminal",
|
||||||
"args": "--config ./config/AEPSA/PEMS-BAY.yaml"
|
"args": "--config ./config/AEPSA/PEMS-BAY.yaml"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "AEPSA: METR-LA",
|
||||||
|
"type": "debugpy",
|
||||||
|
"request": "launch",
|
||||||
|
"program": "run.py",
|
||||||
|
"console": "integratedTerminal",
|
||||||
|
"args": "--config ./config/AEPSA/METR-LA.yaml"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "AEPSA: AirQuality",
|
||||||
|
"type": "debugpy",
|
||||||
|
"request": "launch",
|
||||||
|
"program": "run.py",
|
||||||
|
"console": "integratedTerminal",
|
||||||
|
"args": "--config ./config/AEPSA/AirQuality.yaml"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "AEPSA: BJTaxi-Inflow",
|
||||||
|
"type": "debugpy",
|
||||||
|
"request": "launch",
|
||||||
|
"program": "run.py",
|
||||||
|
"console": "integratedTerminal",
|
||||||
|
"args": "--config ./config/AEPSA/BJTaxi-Inflow.yaml"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "AEPSA: BJTaxi-outflow",
|
||||||
|
"type": "debugpy",
|
||||||
|
"request": "launch",
|
||||||
|
"program": "run.py",
|
||||||
|
"console": "integratedTerminal",
|
||||||
|
"args": "--config ./config/AEPSA/BJTaxi-outflow.yaml"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "AEPSA: NYCBike-inflow",
|
||||||
|
"type": "debugpy",
|
||||||
|
"request": "launch",
|
||||||
|
"program": "run.py",
|
||||||
|
"console": "integratedTerminal",
|
||||||
|
"args": "--config ./config/AEPSA/NYCBike-inflow.yaml"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "AEPSA: NYCBike-outflow",
|
||||||
|
"type": "debugpy",
|
||||||
|
"request": "launch",
|
||||||
|
"program": "run.py",
|
||||||
|
"console": "integratedTerminal",
|
||||||
|
"args": "--config ./config/AEPSA/NYCBike-outflow.yaml"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "AEPSA: SolarEnergy",
|
||||||
|
"type": "debugpy",
|
||||||
|
"request": "launch",
|
||||||
|
"program": "run.py",
|
||||||
|
"console": "integratedTerminal",
|
||||||
|
"args": "--config ./config/AEPSA/SolarEnergy.yaml"
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
@ -0,0 +1,58 @@
|
||||||
|
basic:
|
||||||
|
dataset: "AirQuality"
|
||||||
|
mode : "train"
|
||||||
|
device : "cuda:0"
|
||||||
|
model: "AEPSA"
|
||||||
|
seed: 2023
|
||||||
|
|
||||||
|
data:
|
||||||
|
add_day_in_week: true
|
||||||
|
add_time_in_day: true
|
||||||
|
column_wise: false
|
||||||
|
days_per_week: 7
|
||||||
|
default_graph: true
|
||||||
|
horizon: 24
|
||||||
|
lag: 24
|
||||||
|
normalizer: std
|
||||||
|
num_nodes: 35
|
||||||
|
steps_per_day: 24
|
||||||
|
test_ratio: 0.2
|
||||||
|
tod: false
|
||||||
|
val_ratio: 0.2
|
||||||
|
sample: 1
|
||||||
|
input_dim: 6
|
||||||
|
batch_size: 16
|
||||||
|
|
||||||
|
model:
|
||||||
|
pred_len: 24
|
||||||
|
seq_len: 24
|
||||||
|
patch_len: 6
|
||||||
|
stride: 7
|
||||||
|
dropout: 0.2
|
||||||
|
gpt_layers: 9
|
||||||
|
d_ff: 128
|
||||||
|
gpt_path: ./GPT-2
|
||||||
|
d_model: 64
|
||||||
|
n_heads: 1
|
||||||
|
input_dim: 6
|
||||||
|
word_num: 1000
|
||||||
|
|
||||||
|
train:
|
||||||
|
batch_size: 16
|
||||||
|
early_stop: true
|
||||||
|
early_stop_patience: 15
|
||||||
|
epochs: 100
|
||||||
|
grad_norm: false
|
||||||
|
loss_func: mae
|
||||||
|
lr_decay: true
|
||||||
|
lr_decay_rate: 0.3
|
||||||
|
lr_decay_step: "5,20,40,70"
|
||||||
|
lr_init: 0.003
|
||||||
|
max_grad_norm: 5
|
||||||
|
weight_decay: 0
|
||||||
|
debug: false
|
||||||
|
output_dim: 6
|
||||||
|
log_step: 100
|
||||||
|
plot: false
|
||||||
|
mae_thresh: None
|
||||||
|
mape_thresh: 0.001
|
||||||
|
|
@ -1,8 +1,9 @@
|
||||||
basic:
|
basic:
|
||||||
dataset: "PEMSD8"
|
dataset: "BJTaxi-Inflow"
|
||||||
mode : "train"
|
mode : "train"
|
||||||
device : "cuda:0"
|
device : "cuda:0"
|
||||||
model: "AEPSA"
|
model: "AEPSA"
|
||||||
|
seed: 2023
|
||||||
|
|
||||||
data:
|
data:
|
||||||
add_day_in_week: true
|
add_day_in_week: true
|
||||||
|
|
@ -13,14 +14,14 @@ data:
|
||||||
horizon: 12
|
horizon: 12
|
||||||
lag: 12
|
lag: 12
|
||||||
normalizer: std
|
normalizer: std
|
||||||
num_nodes: 170
|
num_nodes: 142
|
||||||
steps_per_day: 288
|
steps_per_day: 48
|
||||||
test_ratio: 0.2
|
test_ratio: 0.2
|
||||||
tod: false
|
tod: false
|
||||||
val_ratio: 0.2
|
val_ratio: 0.2
|
||||||
sample: 1
|
sample: 1
|
||||||
input_dim: 1
|
input_dim: 1
|
||||||
batch_size: 64
|
batch_size: 32
|
||||||
|
|
||||||
model:
|
model:
|
||||||
pred_len: 12
|
pred_len: 12
|
||||||
|
|
@ -33,9 +34,11 @@ model:
|
||||||
gpt_path: ./GPT-2
|
gpt_path: ./GPT-2
|
||||||
d_model: 64
|
d_model: 64
|
||||||
n_heads: 1
|
n_heads: 1
|
||||||
|
input_dim: 1
|
||||||
|
word_num: 1000
|
||||||
|
|
||||||
train:
|
train:
|
||||||
batch_size: 64
|
batch_size: 32
|
||||||
early_stop: true
|
early_stop: true
|
||||||
early_stop_patience: 15
|
early_stop_patience: 15
|
||||||
epochs: 100
|
epochs: 100
|
||||||
|
|
@ -46,13 +49,10 @@ train:
|
||||||
lr_decay_step: "5,20,40,70"
|
lr_decay_step: "5,20,40,70"
|
||||||
lr_init: 0.003
|
lr_init: 0.003
|
||||||
max_grad_norm: 5
|
max_grad_norm: 5
|
||||||
real_value: true
|
|
||||||
seed: 12
|
|
||||||
weight_decay: 0
|
weight_decay: 0
|
||||||
debug: false
|
debug: false
|
||||||
output_dim: 1
|
output_dim: 1
|
||||||
log_step: 2000
|
log_step: 100
|
||||||
plot: false
|
plot: false
|
||||||
mae_thresh: None
|
mae_thresh: None
|
||||||
mape_thresh: 0.001
|
mape_thresh: 0.001
|
||||||
|
|
||||||
|
|
@ -1,8 +1,9 @@
|
||||||
basic:
|
basic:
|
||||||
dataset: "PEMSD8"
|
dataset: "BJTaxi-outflow"
|
||||||
mode : "train"
|
mode : "train"
|
||||||
device : "cuda:0"
|
device : "cuda:0"
|
||||||
model: "REPST"
|
model: "AEPSA"
|
||||||
|
seed: 2023
|
||||||
|
|
||||||
data:
|
data:
|
||||||
add_day_in_week: true
|
add_day_in_week: true
|
||||||
|
|
@ -13,14 +14,14 @@ data:
|
||||||
horizon: 12
|
horizon: 12
|
||||||
lag: 12
|
lag: 12
|
||||||
normalizer: std
|
normalizer: std
|
||||||
num_nodes: 170
|
num_nodes: 142
|
||||||
steps_per_day: 288
|
steps_per_day: 48
|
||||||
test_ratio: 0.2
|
test_ratio: 0.2
|
||||||
tod: false
|
tod: false
|
||||||
val_ratio: 0.2
|
val_ratio: 0.2
|
||||||
sample: 1
|
sample: 1
|
||||||
input_dim: 1
|
input_dim: 1
|
||||||
batch_size: 64
|
batch_size: 32
|
||||||
|
|
||||||
model:
|
model:
|
||||||
pred_len: 12
|
pred_len: 12
|
||||||
|
|
@ -33,9 +34,11 @@ model:
|
||||||
gpt_path: ./GPT-2
|
gpt_path: ./GPT-2
|
||||||
d_model: 64
|
d_model: 64
|
||||||
n_heads: 1
|
n_heads: 1
|
||||||
|
input_dim: 1
|
||||||
|
word_num: 1000
|
||||||
|
|
||||||
train:
|
train:
|
||||||
batch_size: 64
|
batch_size: 32
|
||||||
early_stop: true
|
early_stop: true
|
||||||
early_stop_patience: 15
|
early_stop_patience: 15
|
||||||
epochs: 100
|
epochs: 100
|
||||||
|
|
@ -46,13 +49,10 @@ train:
|
||||||
lr_decay_step: "5,20,40,70"
|
lr_decay_step: "5,20,40,70"
|
||||||
lr_init: 0.003
|
lr_init: 0.003
|
||||||
max_grad_norm: 5
|
max_grad_norm: 5
|
||||||
real_value: true
|
|
||||||
seed: 12
|
|
||||||
weight_decay: 0
|
weight_decay: 0
|
||||||
debug: false
|
debug: false
|
||||||
output_dim: 1
|
output_dim: 1
|
||||||
log_step: 2000
|
log_step: 100
|
||||||
plot: false
|
plot: false
|
||||||
mae_thresh: None
|
mae_thresh: None
|
||||||
mape_thresh: 0.001
|
mape_thresh: 0.001
|
||||||
|
|
||||||
|
|
@ -0,0 +1,59 @@
|
||||||
|
basic:
|
||||||
|
dataset: "METR-LA"
|
||||||
|
mode : "train"
|
||||||
|
device : "cuda:0"
|
||||||
|
model: "AEPSA"
|
||||||
|
seed: 2023
|
||||||
|
|
||||||
|
data:
|
||||||
|
add_day_in_week: true
|
||||||
|
add_time_in_day: true
|
||||||
|
column_wise: false
|
||||||
|
days_per_week: 7
|
||||||
|
default_graph: true
|
||||||
|
horizon: 24
|
||||||
|
lag: 24
|
||||||
|
normalizer: std
|
||||||
|
num_nodes: 207
|
||||||
|
steps_per_day: 288
|
||||||
|
test_ratio: 0.2
|
||||||
|
tod: false
|
||||||
|
val_ratio: 0.2
|
||||||
|
sample: 1
|
||||||
|
input_dim: 1
|
||||||
|
batch_size: 16
|
||||||
|
|
||||||
|
model:
|
||||||
|
pred_len: 24
|
||||||
|
seq_len: 24
|
||||||
|
patch_len: 6
|
||||||
|
stride: 7
|
||||||
|
dropout: 0.2
|
||||||
|
gpt_layers: 9
|
||||||
|
d_ff: 128
|
||||||
|
gpt_path: ./GPT-2
|
||||||
|
d_model: 64
|
||||||
|
n_heads: 1
|
||||||
|
input_dim: 1
|
||||||
|
word_num: 1000
|
||||||
|
|
||||||
|
train:
|
||||||
|
batch_size: 16
|
||||||
|
early_stop: true
|
||||||
|
early_stop_patience: 15
|
||||||
|
epochs: 100
|
||||||
|
grad_norm: false
|
||||||
|
loss_func: mae
|
||||||
|
lr_decay: true
|
||||||
|
lr_decay_rate: 0.3
|
||||||
|
lr_decay_step: "5,20,40,70"
|
||||||
|
lr_init: 0.003
|
||||||
|
max_grad_norm: 5
|
||||||
|
real_value: true
|
||||||
|
weight_decay: 0
|
||||||
|
debug: false
|
||||||
|
output_dim: 1
|
||||||
|
log_step: 1000
|
||||||
|
plot: false
|
||||||
|
mae_thresh: None
|
||||||
|
mape_thresh: 0.001
|
||||||
|
|
@ -0,0 +1,58 @@
|
||||||
|
basic:
|
||||||
|
dataset: "NYCBike-inflow"
|
||||||
|
mode : "train"
|
||||||
|
device : "cuda:0"
|
||||||
|
model: "AEPSA"
|
||||||
|
seed: 2023
|
||||||
|
|
||||||
|
data:
|
||||||
|
add_day_in_week: true
|
||||||
|
add_time_in_day: true
|
||||||
|
column_wise: false
|
||||||
|
days_per_week: 7
|
||||||
|
default_graph: true
|
||||||
|
horizon: 12
|
||||||
|
lag: 12
|
||||||
|
normalizer: std
|
||||||
|
num_nodes: 200
|
||||||
|
steps_per_day: 24
|
||||||
|
test_ratio: 0.2
|
||||||
|
tod: false
|
||||||
|
val_ratio: 0.2
|
||||||
|
sample: 1
|
||||||
|
input_dim: 1
|
||||||
|
batch_size: 32
|
||||||
|
|
||||||
|
model:
|
||||||
|
pred_len: 12
|
||||||
|
seq_len: 12
|
||||||
|
patch_len: 6
|
||||||
|
stride: 7
|
||||||
|
dropout: 0.2
|
||||||
|
gpt_layers: 9
|
||||||
|
d_ff: 128
|
||||||
|
gpt_path: ./GPT-2
|
||||||
|
d_model: 64
|
||||||
|
n_heads: 1
|
||||||
|
input_dim: 1
|
||||||
|
word_num: 1000
|
||||||
|
|
||||||
|
train:
|
||||||
|
batch_size: 32
|
||||||
|
early_stop: true
|
||||||
|
early_stop_patience: 15
|
||||||
|
epochs: 100
|
||||||
|
grad_norm: false
|
||||||
|
loss_func: mae
|
||||||
|
lr_decay: true
|
||||||
|
lr_decay_rate: 0.3
|
||||||
|
lr_decay_step: "5,20,40,70"
|
||||||
|
lr_init: 0.003
|
||||||
|
max_grad_norm: 5
|
||||||
|
weight_decay: 0
|
||||||
|
debug: false
|
||||||
|
output_dim: 1
|
||||||
|
log_step: 100
|
||||||
|
plot: false
|
||||||
|
mae_thresh: None
|
||||||
|
mape_thresh: 0.001
|
||||||
|
|
@ -0,0 +1,58 @@
|
||||||
|
basic:
|
||||||
|
dataset: "NYCBike-outflow"
|
||||||
|
mode : "train"
|
||||||
|
device : "cuda:0"
|
||||||
|
model: "AEPSA"
|
||||||
|
seed: 2023
|
||||||
|
|
||||||
|
data:
|
||||||
|
add_day_in_week: true
|
||||||
|
add_time_in_day: true
|
||||||
|
column_wise: false
|
||||||
|
days_per_week: 7
|
||||||
|
default_graph: true
|
||||||
|
horizon: 12
|
||||||
|
lag: 12
|
||||||
|
normalizer: std
|
||||||
|
num_nodes: 200
|
||||||
|
steps_per_day: 24
|
||||||
|
test_ratio: 0.2
|
||||||
|
tod: false
|
||||||
|
val_ratio: 0.2
|
||||||
|
sample: 1
|
||||||
|
input_dim: 1
|
||||||
|
batch_size: 32
|
||||||
|
|
||||||
|
model:
|
||||||
|
pred_len: 12
|
||||||
|
seq_len: 12
|
||||||
|
patch_len: 6
|
||||||
|
stride: 7
|
||||||
|
dropout: 0.2
|
||||||
|
gpt_layers: 9
|
||||||
|
d_ff: 128
|
||||||
|
gpt_path: ./GPT-2
|
||||||
|
d_model: 64
|
||||||
|
n_heads: 1
|
||||||
|
input_dim: 1
|
||||||
|
word_num: 1000
|
||||||
|
|
||||||
|
train:
|
||||||
|
batch_size: 32
|
||||||
|
early_stop: true
|
||||||
|
early_stop_patience: 15
|
||||||
|
epochs: 100
|
||||||
|
grad_norm: false
|
||||||
|
loss_func: mae
|
||||||
|
lr_decay: true
|
||||||
|
lr_decay_rate: 0.3
|
||||||
|
lr_decay_step: "5,20,40,70"
|
||||||
|
lr_init: 0.003
|
||||||
|
max_grad_norm: 5
|
||||||
|
weight_decay: 0
|
||||||
|
debug: false
|
||||||
|
output_dim: 1
|
||||||
|
log_step: 100
|
||||||
|
plot: false
|
||||||
|
mae_thresh: None
|
||||||
|
mape_thresh: 0.001
|
||||||
|
|
@ -3,6 +3,7 @@ basic:
|
||||||
mode : "train"
|
mode : "train"
|
||||||
device : "cuda:0"
|
device : "cuda:0"
|
||||||
model: "AEPSA"
|
model: "AEPSA"
|
||||||
|
seed: 2023
|
||||||
|
|
||||||
data:
|
data:
|
||||||
add_day_in_week: true
|
add_day_in_week: true
|
||||||
|
|
@ -34,6 +35,7 @@ model:
|
||||||
d_model: 64
|
d_model: 64
|
||||||
n_heads: 1
|
n_heads: 1
|
||||||
input_dim: 1
|
input_dim: 1
|
||||||
|
word_num: 1000
|
||||||
|
|
||||||
train:
|
train:
|
||||||
batch_size: 16
|
batch_size: 16
|
||||||
|
|
@ -47,8 +49,6 @@ train:
|
||||||
lr_decay_step: "5,20,40,70"
|
lr_decay_step: "5,20,40,70"
|
||||||
lr_init: 0.003
|
lr_init: 0.003
|
||||||
max_grad_norm: 5
|
max_grad_norm: 5
|
||||||
real_value: true
|
|
||||||
seed: 12
|
|
||||||
weight_decay: 0
|
weight_decay: 0
|
||||||
debug: false
|
debug: false
|
||||||
output_dim: 1
|
output_dim: 1
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,59 @@
|
||||||
|
basic:
|
||||||
|
dataset: "SolarEnergy"
|
||||||
|
mode : "train"
|
||||||
|
device : "cuda:0"
|
||||||
|
model: "AEPSA"
|
||||||
|
seed: 2023
|
||||||
|
|
||||||
|
data:
|
||||||
|
add_day_in_week: true
|
||||||
|
add_time_in_day: true
|
||||||
|
column_wise: false
|
||||||
|
days_per_week: 7
|
||||||
|
default_graph: true
|
||||||
|
horizon: 24
|
||||||
|
lag: 24
|
||||||
|
normalizer: std
|
||||||
|
num_nodes: 137
|
||||||
|
steps_per_day: 24
|
||||||
|
test_ratio: 0.2
|
||||||
|
tod: false
|
||||||
|
val_ratio: 0.2
|
||||||
|
sample: 1
|
||||||
|
input_dim: 1
|
||||||
|
batch_size: 64
|
||||||
|
|
||||||
|
model:
|
||||||
|
pred_len: 24
|
||||||
|
seq_len: 24
|
||||||
|
patch_len: 6
|
||||||
|
stride: 7
|
||||||
|
dropout: 0.2
|
||||||
|
gpt_layers: 9
|
||||||
|
d_ff: 128
|
||||||
|
gpt_path: ./GPT-2
|
||||||
|
d_model: 64
|
||||||
|
n_heads: 1
|
||||||
|
input_dim: 1
|
||||||
|
word_num: 1000
|
||||||
|
num_nodes: 137
|
||||||
|
|
||||||
|
train:
|
||||||
|
batch_size: 64
|
||||||
|
early_stop: true
|
||||||
|
early_stop_patience: 15
|
||||||
|
epochs: 100
|
||||||
|
grad_norm: false
|
||||||
|
loss_func: mae
|
||||||
|
lr_decay: true
|
||||||
|
lr_decay_rate: 0.3
|
||||||
|
lr_decay_step: "5,20,40,70"
|
||||||
|
lr_init: 0.003
|
||||||
|
max_grad_norm: 5
|
||||||
|
weight_decay: 0
|
||||||
|
debug: false
|
||||||
|
output_dim: 1
|
||||||
|
log_step: 100
|
||||||
|
plot: false
|
||||||
|
mae_thresh: None
|
||||||
|
mape_thresh: 0.001
|
||||||
|
|
@ -37,6 +37,7 @@ model:
|
||||||
input_dim: 6
|
input_dim: 6
|
||||||
output_dim: 3
|
output_dim: 3
|
||||||
word_num: 1000
|
word_num: 1000
|
||||||
|
num_nodes: 35
|
||||||
|
|
||||||
train:
|
train:
|
||||||
batch_size: 16
|
batch_size: 16
|
||||||
|
|
|
||||||
|
|
@ -36,6 +36,7 @@ model:
|
||||||
n_heads: 1
|
n_heads: 1
|
||||||
input_dim: 1
|
input_dim: 1
|
||||||
word_num: 1000
|
word_num: 1000
|
||||||
|
num_nodes: 1024
|
||||||
|
|
||||||
train:
|
train:
|
||||||
batch_size: 16
|
batch_size: 16
|
||||||
|
|
|
||||||
|
|
@ -36,6 +36,7 @@ model:
|
||||||
n_heads: 1
|
n_heads: 1
|
||||||
input_dim: 1
|
input_dim: 1
|
||||||
word_num: 1000
|
word_num: 1000
|
||||||
|
num_nodes: 1024
|
||||||
|
|
||||||
train:
|
train:
|
||||||
batch_size: 16
|
batch_size: 16
|
||||||
|
|
|
||||||
|
|
@ -37,6 +37,7 @@ model:
|
||||||
input_dim: 3
|
input_dim: 3
|
||||||
output_dim: 3
|
output_dim: 3
|
||||||
word_num: 1000
|
word_num: 1000
|
||||||
|
num_nodes: 7
|
||||||
|
|
||||||
train:
|
train:
|
||||||
batch_size: 16
|
batch_size: 16
|
||||||
|
|
|
||||||
|
|
@ -36,6 +36,7 @@ model:
|
||||||
n_heads: 1
|
n_heads: 1
|
||||||
input_dim: 1
|
input_dim: 1
|
||||||
word_num: 1000
|
word_num: 1000
|
||||||
|
num_nodes: 207
|
||||||
|
|
||||||
train:
|
train:
|
||||||
batch_size: 16
|
batch_size: 16
|
||||||
|
|
|
||||||
|
|
@ -36,6 +36,7 @@ model:
|
||||||
n_heads: 1
|
n_heads: 1
|
||||||
input_dim: 1
|
input_dim: 1
|
||||||
word_num: 1000
|
word_num: 1000
|
||||||
|
num_nodes: 128
|
||||||
|
|
||||||
train:
|
train:
|
||||||
batch_size: 16
|
batch_size: 16
|
||||||
|
|
|
||||||
|
|
@ -36,6 +36,7 @@ model:
|
||||||
n_heads: 1
|
n_heads: 1
|
||||||
input_dim: 1
|
input_dim: 1
|
||||||
word_num: 1000
|
word_num: 1000
|
||||||
|
num_nodes: 128
|
||||||
|
|
||||||
train:
|
train:
|
||||||
batch_size: 16
|
batch_size: 16
|
||||||
|
|
|
||||||
|
|
@ -36,6 +36,7 @@ model:
|
||||||
n_heads: 1
|
n_heads: 1
|
||||||
input_dim: 1
|
input_dim: 1
|
||||||
word_num: 1000
|
word_num: 1000
|
||||||
|
num_nodes: 325
|
||||||
|
|
||||||
train:
|
train:
|
||||||
batch_size: 16
|
batch_size: 16
|
||||||
|
|
|
||||||
|
|
@ -35,6 +35,7 @@ model:
|
||||||
n_heads: 1
|
n_heads: 1
|
||||||
input_dim: 1
|
input_dim: 1
|
||||||
t_max: 5
|
t_max: 5
|
||||||
|
num_nodes: 325
|
||||||
|
|
||||||
train:
|
train:
|
||||||
batch_size: 16
|
batch_size: 16
|
||||||
|
|
|
||||||
|
|
@ -21,7 +21,7 @@ data:
|
||||||
val_ratio: 0.2
|
val_ratio: 0.2
|
||||||
sample: 1
|
sample: 1
|
||||||
input_dim: 1
|
input_dim: 1
|
||||||
batch_size: 16
|
batch_size: 64
|
||||||
|
|
||||||
model:
|
model:
|
||||||
pred_len: 24
|
pred_len: 24
|
||||||
|
|
@ -36,9 +36,10 @@ model:
|
||||||
n_heads: 1
|
n_heads: 1
|
||||||
input_dim: 1
|
input_dim: 1
|
||||||
word_num: 1000
|
word_num: 1000
|
||||||
|
num_nodes: 137
|
||||||
|
|
||||||
train:
|
train:
|
||||||
batch_size: 16
|
batch_size: 64
|
||||||
early_stop: true
|
early_stop: true
|
||||||
early_stop_patience: 15
|
early_stop_patience: 15
|
||||||
epochs: 100
|
epochs: 100
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,251 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from transformers.models.gpt2.modeling_gpt2 import GPT2Model
|
||||||
|
from einops import rearrange
|
||||||
|
from model.AEPSA.normalizer import GumbelSoftmax
|
||||||
|
from model.AEPSA.reprogramming import PatchEmbedding, ReprogrammingLayer
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
class DynamicGraphEnhancer(nn.Module):
|
||||||
|
"""
|
||||||
|
动态图增强器,基于节点嵌入自动生成图结构
|
||||||
|
参考DDGCRN的设计,使用节点嵌入和特征信息动态计算邻接矩阵
|
||||||
|
"""
|
||||||
|
def __init__(self, num_nodes, in_dim, embed_dim=10):
|
||||||
|
super().__init__()
|
||||||
|
self.num_nodes = num_nodes
|
||||||
|
self.embed_dim = embed_dim
|
||||||
|
|
||||||
|
# 节点嵌入参数
|
||||||
|
self.node_embeddings = nn.Parameter(
|
||||||
|
torch.randn(num_nodes, embed_dim), requires_grad=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# 特征转换层,用于生成动态调整的嵌入
|
||||||
|
self.feature_transform = nn.Sequential(
|
||||||
|
nn.Linear(in_dim, 16),
|
||||||
|
nn.Sigmoid(),
|
||||||
|
nn.Linear(16, 2),
|
||||||
|
nn.Sigmoid(),
|
||||||
|
nn.Linear(2, embed_dim)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 注册单位矩阵作为固定的支持矩阵
|
||||||
|
self.register_buffer("eye", torch.eye(num_nodes))
|
||||||
|
|
||||||
|
def get_laplacian(self, graph, I, normalize=True):
|
||||||
|
"""
|
||||||
|
计算归一化拉普拉斯矩阵
|
||||||
|
"""
|
||||||
|
# 计算度矩阵的逆平方根
|
||||||
|
D_inv = torch.diag_embed(torch.sum(graph, -1) ** (-0.5))
|
||||||
|
D_inv[torch.isinf(D_inv)] = 0.0 # 处理零除问题
|
||||||
|
|
||||||
|
if normalize:
|
||||||
|
return torch.matmul(torch.matmul(D_inv, graph), D_inv)
|
||||||
|
else:
|
||||||
|
return torch.matmul(torch.matmul(D_inv, graph + I), D_inv)
|
||||||
|
|
||||||
|
def forward(self, X):
|
||||||
|
"""
|
||||||
|
X: 输入特征 [B, N, D]
|
||||||
|
返回: 动态生成的归一化拉普拉斯矩阵 [B, N, N]
|
||||||
|
"""
|
||||||
|
batch_size = X.size(0)
|
||||||
|
laplacians = []
|
||||||
|
|
||||||
|
# 获取单位矩阵
|
||||||
|
I = self.eye.to(X.device)
|
||||||
|
|
||||||
|
for b in range(batch_size):
|
||||||
|
# 使用特征转换层生成动态嵌入调整因子
|
||||||
|
filt = self.feature_transform(X[b]) # [N, embed_dim]
|
||||||
|
|
||||||
|
# 计算节点嵌入向量
|
||||||
|
nodevec = torch.tanh(self.node_embeddings * filt)
|
||||||
|
|
||||||
|
# 通过节点嵌入的点积计算邻接矩阵
|
||||||
|
adj = F.relu(torch.matmul(nodevec, nodevec.transpose(0, 1)))
|
||||||
|
|
||||||
|
# 计算归一化拉普拉斯矩阵
|
||||||
|
laplacian = self.get_laplacian(adj, I)
|
||||||
|
laplacians.append(laplacian)
|
||||||
|
|
||||||
|
return torch.stack(laplacians, dim=0)
|
||||||
|
|
||||||
|
class GraphEnhancedEncoder(nn.Module):
|
||||||
|
"""
|
||||||
|
基于Chebyshev多项式和动态拉普拉斯矩阵的图增强编码器
|
||||||
|
用于将动态图结构信息整合到特征编码中
|
||||||
|
"""
|
||||||
|
def __init__(self, K=3, in_dim=64, hidden_dim=32, num_nodes=325, embed_dim=10, device='cpu'):
|
||||||
|
super().__init__()
|
||||||
|
self.K = K # Chebyshev多项式阶数
|
||||||
|
self.in_dim = in_dim
|
||||||
|
self.hidden_dim = hidden_dim
|
||||||
|
self.device = device
|
||||||
|
|
||||||
|
# 动态图增强器
|
||||||
|
self.graph_enhancer = DynamicGraphEnhancer(num_nodes, in_dim, embed_dim)
|
||||||
|
|
||||||
|
# 谱系数 α_k (可学习参数)
|
||||||
|
self.alpha = nn.Parameter(torch.randn(K + 1, 1))
|
||||||
|
|
||||||
|
# 传播权重 W_k (可学习参数)
|
||||||
|
self.W = nn.ParameterList([
|
||||||
|
nn.Parameter(torch.randn(in_dim, hidden_dim)) for _ in range(K + 1)
|
||||||
|
])
|
||||||
|
|
||||||
|
self.to(device)
|
||||||
|
|
||||||
|
def chebyshev_polynomials(self, L_tilde, X):
|
||||||
|
"""递归计算 [T_0(L_tilde)X, ..., T_K(L_tilde)X]"""
|
||||||
|
T_k_list = [X]
|
||||||
|
if self.K >= 1:
|
||||||
|
T_k_list.append(torch.matmul(L_tilde, X))
|
||||||
|
for k in range(2, self.K + 1):
|
||||||
|
T_k_list.append(2 * torch.matmul(L_tilde, T_k_list[-1]) - T_k_list[-2])
|
||||||
|
return T_k_list
|
||||||
|
|
||||||
|
def forward(self, X):
|
||||||
|
"""
|
||||||
|
X: 输入特征 [B, N, D]
|
||||||
|
返回: 增强后的特征 [B, N, hidden_dim*(K+1)]
|
||||||
|
"""
|
||||||
|
batch_size, num_nodes, _ = X.shape
|
||||||
|
enhanced_features = []
|
||||||
|
|
||||||
|
# 动态生成拉普拉斯矩阵
|
||||||
|
laplacians = self.graph_enhancer(X)
|
||||||
|
|
||||||
|
for b in range(batch_size):
|
||||||
|
L = laplacians[b]
|
||||||
|
|
||||||
|
# 特征值缩放
|
||||||
|
try:
|
||||||
|
lambda_max = torch.linalg.eigvalsh(L).max().real
|
||||||
|
# 避免除零问题
|
||||||
|
if lambda_max < 1e-6:
|
||||||
|
lambda_max = 1.0
|
||||||
|
L_tilde = (2.0 / lambda_max) * L - torch.eye(L.size(0), device=L.device)
|
||||||
|
except:
|
||||||
|
# 如果计算特征值失败,使用单位矩阵
|
||||||
|
L_tilde = torch.eye(num_nodes, device=X.device)
|
||||||
|
|
||||||
|
# 计算Chebyshev多项式展开
|
||||||
|
T_k_list = self.chebyshev_polynomials(L_tilde, X[b])
|
||||||
|
H_list = []
|
||||||
|
|
||||||
|
# 应用传播权重
|
||||||
|
for k in range(self.K + 1):
|
||||||
|
H_k = torch.matmul(T_k_list[k], self.W[k])
|
||||||
|
H_list.append(H_k)
|
||||||
|
|
||||||
|
# 拼接所有尺度的特征
|
||||||
|
X_enhanced = torch.cat(H_list, dim=-1) # [N, hidden_dim*(K+1)]
|
||||||
|
enhanced_features.append(X_enhanced)
|
||||||
|
|
||||||
|
return torch.stack(enhanced_features, dim=0)
|
||||||
|
|
||||||
|
class AEPSA(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, configs):
|
||||||
|
super(AEPSA, self).__init__()
|
||||||
|
self.device = configs['device']
|
||||||
|
self.pred_len = configs['pred_len']
|
||||||
|
self.seq_len = configs['seq_len']
|
||||||
|
self.patch_len = configs['patch_len']
|
||||||
|
self.input_dim = configs['input_dim']
|
||||||
|
self.stride = configs['stride']
|
||||||
|
self.dropout = configs['dropout']
|
||||||
|
self.gpt_layers = configs['gpt_layers']
|
||||||
|
self.d_ff = configs['d_ff']
|
||||||
|
self.gpt_path = configs['gpt_path']
|
||||||
|
self.num_nodes = configs.get('num_nodes', 325) # 添加节点数量配置
|
||||||
|
|
||||||
|
self.word_choice = GumbelSoftmax(configs['word_num'])
|
||||||
|
|
||||||
|
self.d_model = configs['d_model']
|
||||||
|
self.n_heads = configs['n_heads']
|
||||||
|
self.d_keys = None
|
||||||
|
self.d_llm = 768
|
||||||
|
|
||||||
|
self.patch_nums = int((self.seq_len - self.patch_len) / self.stride + 2)
|
||||||
|
self.head_nf = self.d_ff * self.patch_nums
|
||||||
|
|
||||||
|
# 词嵌入
|
||||||
|
self.patch_embedding = PatchEmbedding(self.d_model, self.patch_len, self.stride, self.dropout, self.patch_nums, self.input_dim)
|
||||||
|
|
||||||
|
# GPT2初始化
|
||||||
|
self.gpts = GPT2Model.from_pretrained(self.gpt_path, output_attentions=True, output_hidden_states=True)
|
||||||
|
self.gpts.h = self.gpts.h[:self.gpt_layers]
|
||||||
|
self.gpts.apply(self.reset_parameters)
|
||||||
|
|
||||||
|
self.word_embeddings = self.gpts.get_input_embeddings().weight.to(self.device)
|
||||||
|
self.vocab_size = self.word_embeddings.shape[0]
|
||||||
|
self.mapping_layer = nn.Linear(self.vocab_size, 1)
|
||||||
|
self.reprogramming_layer = ReprogrammingLayer(self.d_model, self.n_heads, self.d_keys, self.d_llm)
|
||||||
|
|
||||||
|
# 添加动态图增强编码器
|
||||||
|
self.graph_encoder = GraphEnhancedEncoder(
|
||||||
|
K=configs.get('chebyshev_order', 3),
|
||||||
|
in_dim=self.d_model,
|
||||||
|
hidden_dim=configs.get('graph_hidden_dim', 32),
|
||||||
|
num_nodes=self.num_nodes,
|
||||||
|
embed_dim=configs.get('graph_embed_dim', 10),
|
||||||
|
device=self.device
|
||||||
|
)
|
||||||
|
|
||||||
|
# 特征融合层
|
||||||
|
self.feature_fusion = nn.Linear(
|
||||||
|
self.d_model + configs.get('graph_hidden_dim', 32) * (configs.get('chebyshev_order', 3) + 1),
|
||||||
|
self.d_model
|
||||||
|
)
|
||||||
|
|
||||||
|
self.out_mlp = nn.Sequential(
|
||||||
|
nn.Linear(self.d_llm, 128),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Linear(128, self.pred_len)
|
||||||
|
)
|
||||||
|
|
||||||
|
for i, (name, param) in enumerate(self.gpts.named_parameters()):
|
||||||
|
if 'wpe' in name:
|
||||||
|
param.requires_grad = True
|
||||||
|
else:
|
||||||
|
param.requires_grad = False
|
||||||
|
|
||||||
|
def reset_parameters(self, module):
|
||||||
|
if hasattr(module, 'weight') and module.weight is not None:
|
||||||
|
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
||||||
|
if hasattr(module, 'bias') and module.bias is not None:
|
||||||
|
torch.nn.init.zeros_(module.bias)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
"""
|
||||||
|
x: 输入数据 [B, T, N, C]
|
||||||
|
自动生成图结构,无需外部提供邻接矩阵
|
||||||
|
"""
|
||||||
|
x = x[..., :1]
|
||||||
|
x_enc = rearrange(x, 'b t n c -> b n c t')
|
||||||
|
enc_out, n_vars = self.patch_embedding(x_enc) # (B, N, C)
|
||||||
|
# 应用图增强编码器(自动生成图结构)
|
||||||
|
graph_enhanced = self.graph_encoder(enc_out)
|
||||||
|
# 保持相同的维度
|
||||||
|
|
||||||
|
# 特征融合 - 现在两个张量都是三维的 [B, N, d_model]
|
||||||
|
enc_out = torch.cat([enc_out, graph_enhanced], dim=-1)
|
||||||
|
enc_out = self.feature_fusion(enc_out)
|
||||||
|
|
||||||
|
self.mapping_layer(self.word_embeddings.permute(1, 0)).permute(1, 0)
|
||||||
|
masks = self.word_choice(self.mapping_layer.weight.data.permute(1,0))
|
||||||
|
source_embeddings = self.word_embeddings[masks==1]
|
||||||
|
|
||||||
|
enc_out = self.reprogramming_layer(enc_out, source_embeddings, source_embeddings)
|
||||||
|
enc_out = self.gpts(inputs_embeds=enc_out).last_hidden_state
|
||||||
|
|
||||||
|
dec_out = self.out_mlp(enc_out)
|
||||||
|
outputs = dec_out.unsqueeze(dim=-1)
|
||||||
|
outputs = outputs.repeat(1, 1, 1, n_vars)
|
||||||
|
outputs = outputs.permute(0,2,1,3)
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
@ -1,103 +0,0 @@
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from transformers.models.gpt2.modeling_gpt2 import GPT2Model
|
|
||||||
from einops import rearrange
|
|
||||||
from model.REPST.normalizer import GumbelSoftmax
|
|
||||||
from model.REPST.reprogramming import PatchEmbedding, ReprogrammingLayer
|
|
||||||
|
|
||||||
class repst(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self, configs):
|
|
||||||
super(repst, self).__init__()
|
|
||||||
self.device = configs['device']
|
|
||||||
self.pred_len = configs['pred_len']
|
|
||||||
self.seq_len = configs['seq_len']
|
|
||||||
self.patch_len = configs['patch_len']
|
|
||||||
self.input_dim = configs['input_dim']
|
|
||||||
self.stride = configs['stride']
|
|
||||||
self.dropout = configs['dropout']
|
|
||||||
self.gpt_layers = configs['gpt_layers']
|
|
||||||
self.d_ff = configs['d_ff']
|
|
||||||
self.gpt_path = configs['gpt_path']
|
|
||||||
|
|
||||||
self.word_choice = GumbelSoftmax(configs['word_num'])
|
|
||||||
|
|
||||||
self.d_model = configs['d_model']
|
|
||||||
self.n_heads = configs['n_heads']
|
|
||||||
self.d_keys = None
|
|
||||||
self.d_llm = 768
|
|
||||||
|
|
||||||
self.patch_nums = int((self.seq_len - self.patch_len) / self.stride + 2)
|
|
||||||
self.head_nf = self.d_ff * self.patch_nums
|
|
||||||
|
|
||||||
# 词嵌入
|
|
||||||
self.patch_embedding = PatchEmbedding(self.d_model, self.patch_len, self.stride, self.dropout, self.patch_nums, self.input_dim)
|
|
||||||
|
|
||||||
# GPT2初始化
|
|
||||||
self.gpts = GPT2Model.from_pretrained(self.gpt_path, output_attentions=True, output_hidden_states=True)
|
|
||||||
self.gpts.h = self.gpts.h[:self.gpt_layers]
|
|
||||||
self.gpts.apply(self.reset_parameters)
|
|
||||||
|
|
||||||
self.word_embeddings = self.gpts.get_input_embeddings().weight.to(self.device)
|
|
||||||
self.vocab_size = self.word_embeddings.shape[0]
|
|
||||||
self.mapping_layer = nn.Linear(self.vocab_size, 1)
|
|
||||||
self.reprogramming_layer = ReprogrammingLayer(self.d_model, self.n_heads, self.d_keys, self.d_llm)
|
|
||||||
|
|
||||||
self.out_mlp = nn.Sequential(
|
|
||||||
nn.Linear(self.d_llm, 128),
|
|
||||||
nn.ReLU(),
|
|
||||||
nn.Linear(128, self.pred_len)
|
|
||||||
)
|
|
||||||
|
|
||||||
for i, (name, param) in enumerate(self.gpts.named_parameters()):
|
|
||||||
if 'wpe' in name:
|
|
||||||
param.requires_grad = True
|
|
||||||
else:
|
|
||||||
param.requires_grad = False
|
|
||||||
|
|
||||||
def reset_parameters(self, module):
|
|
||||||
if hasattr(module, 'weight') and module.weight is not None:
|
|
||||||
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
|
||||||
if hasattr(module, 'bias') and module.bias is not None:
|
|
||||||
torch.nn.init.zeros_(module.bias)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = x[..., :1]
|
|
||||||
x_enc = rearrange(x, 'b t n c -> b n c t')
|
|
||||||
enc_out, n_vars = self.patch_embedding(x_enc)
|
|
||||||
self.mapping_layer(self.word_embeddings.permute(1, 0)).permute(1, 0)
|
|
||||||
masks = self.word_choice(self.mapping_layer.weight.data.permute(1,0))
|
|
||||||
source_embeddings = self.word_embeddings[masks==1]
|
|
||||||
|
|
||||||
enc_out = self.reprogramming_layer(enc_out, source_embeddings, source_embeddings)
|
|
||||||
enc_out = self.gpts(inputs_embeds=enc_out).last_hidden_state
|
|
||||||
|
|
||||||
dec_out = self.out_mlp(enc_out)
|
|
||||||
outputs = dec_out.unsqueeze(dim=-1)
|
|
||||||
outputs = outputs.repeat(1, 1, 1, n_vars)
|
|
||||||
outputs = outputs.permute(0,2,1,3)
|
|
||||||
|
|
||||||
return outputs
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
configs = {
|
|
||||||
'device': 'cuda:0',
|
|
||||||
'pred_len': 24,
|
|
||||||
'seq_len': 24,
|
|
||||||
'patch_len': 6,
|
|
||||||
'stride': 7,
|
|
||||||
'dropout': 0.2,
|
|
||||||
'gpt_layers': 9,
|
|
||||||
'd_ff': 128,
|
|
||||||
'gpt_path': './GPT-2',
|
|
||||||
'd_model': 64,
|
|
||||||
'n_heads': 1,
|
|
||||||
'input_dim': 1
|
|
||||||
}
|
|
||||||
model = repst(configs)
|
|
||||||
x = torch.randn(16, 24, 325, 1)
|
|
||||||
y = model(x)
|
|
||||||
|
|
||||||
print(y.shape)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -23,7 +23,7 @@ from model.ST_SSL.ST_SSL import STSSLModel
|
||||||
from model.STGNRDE.Make_model import make_model as make_nrde_model
|
from model.STGNRDE.Make_model import make_model as make_nrde_model
|
||||||
from model.STAWnet.STAWnet import STAWnet
|
from model.STAWnet.STAWnet import STAWnet
|
||||||
from model.REPST.repst import repst as REPST
|
from model.REPST.repst import repst as REPST
|
||||||
from model.AEPSA.repst import repst as AEPSA
|
from model.AEPSA.aepsa import AEPSA as AEPSA
|
||||||
|
|
||||||
|
|
||||||
def model_selector(config):
|
def model_selector(config):
|
||||||
|
|
|
||||||
|
|
@ -160,6 +160,7 @@ def check_and_download_data():
|
||||||
file_path = f"Datasets/TaxiBJ/{file}"
|
file_path = f"Datasets/TaxiBJ/{file}"
|
||||||
download_github_data(file_path, taxi_bj_floder)
|
download_github_data(file_path, taxi_bj_floder)
|
||||||
# 下载后更新缺失列表
|
# 下载后更新缺失列表
|
||||||
|
# download_and_extract("http://code.zhang-heng.com/static/BeijingTaxi.7z", data_dir)
|
||||||
missing_list = detect_data_integrity(data_dir, file_tree)
|
missing_list = detect_data_integrity(data_dir, file_tree)
|
||||||
|
|
||||||
# 检查并下载TaxiBJ数据
|
# 检查并下载TaxiBJ数据
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue