REPST #3

Merged
czzhangheng merged 42 commits from REPST into main 2025-12-20 16:03:22 +08:00
23 changed files with 517 additions and 120 deletions
Showing only changes of commit 56b09ea8ac

View File

@@ -2,7 +2,7 @@ basic:
   dataset: METR-LA
   device: cuda:0
   mode: train
-  model: AEPSA_v2
+  model: ASTRA_v2
   seed: 2023
 data:
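For reference, a minimal sketch of loading one of these renamed configs; the repo's own read_config may differ, and PyYAML plus the exact path are assumptions here, following the ./config/{model}/{dataset}.yaml convention used in run.py:

import yaml  # assumption: read_config may wrap PyYAML or something else

# hypothetical path following the run.py convention
with open("./config/ASTRA_v2/METR-LA.yaml") as f:
    config = yaml.safe_load(f)
print(config["basic"]["model"])  # -> ASTRA_v2 after this PR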

View File

@@ -2,7 +2,7 @@ basic:
   dataset: PEMS-BAY
   device: cuda:0
   mode: train
-  model: AEPSA_v3
+  model: ASTRA_v2
   seed: 2023
 data:

View File

@@ -2,7 +2,7 @@ basic:
   dataset: SolarEnergy
   device: cuda:0
   mode: train
-  model: AEPSA_v2
+  model: ASTRA_v2
   seed: 2023
 data:

View File

@@ -0,0 +1,54 @@
basic:
  dataset: AirQuality
  device: cuda:0
  mode: train
  model: ASTRA_v3
  seed: 2023
data:
  batch_size: 16
  column_wise: false
  days_per_week: 7
  horizon: 24
  input_dim: 6
  lag: 24
  normalizer: std
  num_nodes: 35
  steps_per_day: 24
  test_ratio: 0.2
  val_ratio: 0.2
model:
  d_ff: 128
  d_model: 64
  dropout: 0.2
  gpt_layers: 9
  gpt_path: ./GPT-2
  input_dim: 6
  n_heads: 1
  num_nodes: 35
  patch_len: 6
  pred_len: 24
  seq_len: 24
  stride: 7
  word_num: 1000
train:
  batch_size: 16
  debug: false
  early_stop: true
  early_stop_patience: 15
  epochs: 100
  grad_norm: false
  log_step: 100
  loss_func: mae
  lr_decay: true
  lr_decay_rate: 0.3
  lr_decay_step: 5,20,40,70
  lr_init: 0.003
  mae_thresh: None
  mape_thresh: 0.001
  max_grad_norm: 5
  output_dim: 6
  plot: false
  weight_decay: 0
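One derived quantity worth noting: with seq_len 24, patch_len 6, and stride 7, a PatchTST-style sliding window (an assumption about ASTRA's patching, not confirmed by this diff) yields three patches per node:

# sliding-window patch count, assuming floor((seq_len - patch_len) / stride) + 1
seq_len, patch_len, stride = 24, 6, 7
num_patches = (seq_len - patch_len) // stride + 1
print(num_patches)  # 3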

View File

@@ -0,0 +1,54 @@
basic:
  dataset: BJTaxi-InFlow
  device: cuda:0
  mode: train
  model: ASTRA_v3
  seed: 2023
data:
  batch_size: 32
  column_wise: false
  days_per_week: 7
  horizon: 24
  input_dim: 1
  lag: 24
  normalizer: std
  num_nodes: 1024
  steps_per_day: 48
  test_ratio: 0.2
  val_ratio: 0.2
model:
  d_ff: 128
  d_model: 64
  dropout: 0.2
  gpt_layers: 9
  gpt_path: ./GPT-2
  input_dim: 1
  n_heads: 1
  num_nodes: 1024
  patch_len: 6
  pred_len: 24
  seq_len: 24
  stride: 7
  word_num: 1000
train:
  batch_size: 32
  debug: false
  early_stop: true
  early_stop_patience: 15
  epochs: 100
  grad_norm: false
  log_step: 100
  loss_func: mae
  lr_decay: true
  lr_decay_rate: 0.3
  lr_decay_step: 5,20,40,70
  lr_init: 0.003
  mae_thresh: None
  mape_thresh: 0.001
  max_grad_norm: 5
  output_dim: 1
  plot: false
  weight_decay: 0

View File

@@ -0,0 +1,54 @@
basic:
  dataset: BJTaxi-OutFlow
  device: cuda:0
  mode: train
  model: ASTRA_v3
  seed: 2023
data:
  batch_size: 32
  column_wise: false
  days_per_week: 7
  horizon: 24
  input_dim: 1
  lag: 24
  normalizer: std
  num_nodes: 1024
  steps_per_day: 48
  test_ratio: 0.2
  val_ratio: 0.2
model:
  d_ff: 128
  d_model: 64
  dropout: 0.2
  gpt_layers: 9
  gpt_path: ./GPT-2
  input_dim: 1
  n_heads: 1
  num_nodes: 1024
  patch_len: 6
  pred_len: 24
  seq_len: 24
  stride: 7
  word_num: 1000
train:
  batch_size: 32
  debug: false
  early_stop: true
  early_stop_patience: 15
  epochs: 100
  grad_norm: false
  log_step: 100
  loss_func: mae
  lr_decay: true
  lr_decay_rate: 0.3
  lr_decay_step: 5,20,40,70
  lr_init: 0.003
  mae_thresh: None
  mape_thresh: 0.001
  max_grad_norm: 5
  output_dim: 1
  plot: false
  weight_decay: 0

View File

@@ -2,7 +2,7 @@ basic:
   dataset: METR-LA
   device: cuda:0
   mode: train
-  model: AEPSA_v3
+  model: ASTRA_v3
   seed: 2023
 data:
@@ -19,11 +19,9 @@ data:
   val_ratio: 0.2
 model:
-  chebyshev_order: 3
   d_ff: 128
   d_model: 64
   dropout: 0.2
-  graph_hidden_dim: 32
   gpt_layers: 9
   gpt_path: ./GPT-2
   input_dim: 1

View File

@@ -0,0 +1,54 @@
basic:
  dataset: NYCBike-InFlow
  device: cuda:0
  mode: train
  model: ASTRA_v3
  seed: 2023
data:
  batch_size: 32
  column_wise: false
  days_per_week: 7
  horizon: 24
  input_dim: 1
  lag: 24
  normalizer: std
  num_nodes: 128
  steps_per_day: 48
  test_ratio: 0.2
  val_ratio: 0.2
model:
  d_ff: 128
  d_model: 64
  dropout: 0.2
  gpt_layers: 9
  gpt_path: ./GPT-2
  input_dim: 1
  n_heads: 1
  num_nodes: 128
  patch_len: 6
  pred_len: 24
  seq_len: 24
  stride: 7
  word_num: 1000
train:
  batch_size: 32
  debug: false
  early_stop: true
  early_stop_patience: 15
  epochs: 100
  grad_norm: false
  log_step: 100
  loss_func: mae
  lr_decay: true
  lr_decay_rate: 0.3
  lr_decay_step: 5,20,40,70
  lr_init: 0.003
  mae_thresh: None
  mape_thresh: 0.001
  max_grad_norm: 5
  output_dim: 1
  plot: false
  weight_decay: 0

View File

@@ -0,0 +1,54 @@
basic:
  dataset: NYCBike-OutFlow
  device: cuda:0
  mode: train
  model: ASTRA_v3
  seed: 2023
data:
  batch_size: 32
  column_wise: false
  days_per_week: 7
  horizon: 24
  input_dim: 1
  lag: 24
  normalizer: std
  num_nodes: 128
  steps_per_day: 48
  test_ratio: 0.2
  val_ratio: 0.2
model:
  d_ff: 128
  d_model: 64
  dropout: 0.2
  gpt_layers: 9
  gpt_path: ./GPT-2
  input_dim: 1
  n_heads: 1
  num_nodes: 128
  patch_len: 6
  pred_len: 24
  seq_len: 24
  stride: 7
  word_num: 1000
train:
  batch_size: 32
  debug: false
  early_stop: true
  early_stop_patience: 15
  epochs: 100
  grad_norm: false
  log_step: 100
  loss_func: mae
  lr_decay: true
  lr_decay_rate: 0.3
  lr_decay_step: 5,20,40,70
  lr_init: 0.003
  mae_thresh: None
  mape_thresh: 0.001
  max_grad_norm: 5
  output_dim: 1
  plot: false
  weight_decay: 0

View File

@@ -2,7 +2,7 @@ basic:
   dataset: PEMS-BAY
   device: cuda:0
   mode: train
-  model: AEPSA_v2
+  model: ASTRA_v3
   seed: 2023
 data:
@@ -19,11 +19,9 @@ data:
   val_ratio: 0.2
 model:
-  chebyshev_order: 3
   d_ff: 128
   d_model: 64
   dropout: 0.2
-  graph_hidden_dim: 32
   gpt_layers: 9
   gpt_path: ./GPT-2
   input_dim: 1

View File

@@ -0,0 +1,54 @@
basic:
  dataset: SolarEnergy
  device: cuda:0
  mode: train
  model: ASTRA_v3
  seed: 2023
data:
  batch_size: 64
  column_wise: false
  days_per_week: 7
  horizon: 24
  input_dim: 1
  lag: 24
  normalizer: std
  num_nodes: 137
  steps_per_day: 24
  test_ratio: 0.2
  val_ratio: 0.2
model:
  d_ff: 128
  d_model: 64
  dropout: 0.2
  gpt_layers: 9
  gpt_path: ./GPT-2
  input_dim: 1
  n_heads: 1
  num_nodes: 137
  patch_len: 6
  pred_len: 24
  seq_len: 24
  stride: 7
  word_num: 1000
train:
  batch_size: 64
  debug: false
  early_stop: true
  early_stop_patience: 15
  epochs: 100
  grad_norm: false
  log_step: 100
  loss_func: mae
  lr_decay: true
  lr_decay_rate: 0.3
  lr_decay_step: 5,20,40,70
  lr_init: 0.003
  mae_thresh: None
  mape_thresh: 0.001
  max_grad_norm: 5
  output_dim: 1
  plot: false
  weight_decay: 0

View File

@@ -13,8 +13,8 @@ data:
   input_dim: 6
   lag: 24
   normalizer: std
-  num_nodes: 12
-  steps_per_day: 288
+  num_nodes: 35
+  steps_per_day: 24
   test_ratio: 0.2
   val_ratio: 0.2

View File

@@ -184,7 +184,7 @@ class ASTRA(nn.Module):
     def forward(self, x):
         # data preprocessing
-        x = x[..., :1]  # [B,T,N,1]
+        x = x[..., :self.input_dim]
         x_enc = rearrange(x, 'b t n c -> b n c t')  # [B,N,1,T]
         # graph encoding
@@ -203,7 +203,9 @@ class ASTRA(nn.Module):
         dec_out = self.out_mlp(X_enc)  # [B,N,pred_len]
         # dimension adjustment
-        outputs = dec_out.unsqueeze(dim=-1)  # [B,N,pred_len,1]
-        outputs = outputs.permute(0, 2, 1, 3)  # [B,pred_len,N,1]
+        dec_out = self.out_mlp(enc_out)
+        outputs = dec_out.unsqueeze(dim=-1)
+        outputs = outputs.repeat(1, 1, 1, self.input_dim)
+        outputs = outputs.permute(0, 2, 1, 3)
         return outputs
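Read alongside the hunk above, a minimal shape walk-through of the new output adjustment: the single-channel prediction is repeated across input_dim so the output matches a multi-channel label. Sizes follow the AirQuality config; the variable names are illustrative only:

import torch

# assumed shape of dec_out: [B, N, pred_len]
B, N, pred_len, input_dim = 2, 35, 24, 6
dec_out = torch.randn(B, N, pred_len)
outputs = dec_out.unsqueeze(dim=-1)           # [B, N, pred_len, 1]
outputs = outputs.repeat(1, 1, 1, input_dim)  # [B, N, pred_len, input_dim]
outputs = outputs.permute(0, 2, 1, 3)         # [B, pred_len, N, input_dim]
assert outputs.shape == (B, pred_len, N, input_dim)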

View File

@@ -60,25 +60,17 @@ def run(config):
         case _:
             raise ValueError(f"Unsupported mode: {config['basic']['mode']}")

-def main(debug=False):
-    # specify models
-    model_list = ["iTransformer"]
-    # specify datasets
-    # dataset_list = ["AirQuality", "SolarEnergy", "PEMS-BAY", "METR-LA", "BJTaxi-Inflow", "BJTaxi-Outflow", "NYCBike-Inflow", "NYCBike-Outflow"]
-    # dataset_list = ["AirQuality"]
-    dataset_list = ["AirQuality", "SolarEnergy", "METR-LA", "NYCBike-Inflow", "NYCBike-Outflow"]
+def main(model, data, debug=False):
     # my debug switch; set str(False) when not testing
     # os.environ["TRY"] = str(False)
-    os.environ["TRY"] = str(False)
+    os.environ["TRY"] = str(debug)
     for model in model_list:
-        for dataset in dataset_list:
+        for dataset in data:
             config_path = f"./config/{model}/{dataset}.yaml"
             # shared config entries can be adjusted in this function; mind device and epochs
             config = read_config(config_path)
             print(f"\nRunning {model} on {dataset}")
-            # print(f"config: {config}")
             if os.environ.get("TRY") == "True":
                 try:
                     run(config)
@@ -97,4 +89,9 @@ def main(debug=False):

 if __name__ == "__main__":
     # for debugging
-    main(debug = True)
+    # model_list = ["iTransformer", "PatchTST", "HI"]
+    model_list = ["ASTRA_v2", "GWN", "REPST", "STAEFormer", "MTGNN"]
+    # dataset_list = ["AirQuality", "SolarEnergy", "PEMS-BAY", "METR-LA", "BJTaxi-InFlow", "BJTaxi-OutFlow", "NYCBike-InFlow", "NYCBike-OutFlow"]
+    # dataset_list = ["AirQuality"]
+    dataset_list = ["AirQuality", "SolarEnergy", "METR-LA", "NYCBike-InFlow", "NYCBike-OutFlow"]
+    main(model_list, dataset_list, debug = True)
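The TRY switch above round-trips a bool through the environment as a string; a self-contained sketch of that behavior:

import os

# str(True) stores the literal string "True", which is what the
# os.environ.get("TRY") == "True" checks in this PR compare against
os.environ["TRY"] = str(True)
assert os.environ.get("TRY") == "True"
os.environ["TRY"] = str(False)
assert os.environ.get("TRY") == "False"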

View File

@@ -73,8 +73,12 @@ class Trainer:
                 out = self.ts.inverse(out, b, t, n, c)
                 if os.environ.get("TRY") == "True":
-                    print(out.shape, label.shape)
-                    assert out.shape == label.shape
+                    if out.shape == label.shape:
+                        print("shape true")
+                        assert False
+                    else:
+                        print("shape false")
+                        assert False
                 loss = self.loss(out, label)
                 d_out = self.scaler.inverse_transform(out)

View File

@@ -1,106 +1,180 @@
-import os
-import time
-import copy
-import torch
+import os, time, copy, torch
+from tqdm import tqdm
 from utils.logger import get_logger
 from utils.loss_function import all_metrics
-from tqdm import tqdm

 class Trainer:
-    def __init__(self, model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args, lr_scheduler=None):
-        self.config, self.device, self.args = args, args["basic"]["device"], args["train"]
-        self.model, self.loss, self.optimizer, self.lr_scheduler = model, loss, optimizer, lr_scheduler
-        self.train_loader, self.val_loader, self.test_loader, self.scaler = train_loader, val_loader, test_loader, scaler
-        log_dir = self.args["log_dir"]
-        self.best_path, self.best_test_path = [os.path.join(log_dir, f"best_{suffix}_model.pth") for suffix in ["", "test"]]
-        if not self.args["debug"]: os.makedirs(log_dir, exist_ok=True)
-        self.logger = get_logger(log_dir, name=self.model.__class__.__name__, debug=self.args["debug"])
-        self.logger.info(f"Experiment log path in: {log_dir}")
+    def __init__(self, model, loss, optimizer,
+                 train_loader, val_loader, test_loader,
+                 scaler, args, lr_scheduler=None):
+        self.config = args
+        self.device = args["basic"]["device"]
+        self.args = args["train"]
+        self.model = model.to(self.device)
+        self.loss = loss
+        self.optimizer = optimizer
+        self.lr_scheduler = lr_scheduler
+        self.train_loader = train_loader
+        self.val_loader = val_loader or test_loader
+        self.test_loader = test_loader
+        self.scaler = scaler
+        self._init_paths()
+        self._init_logger()

-    def train(self):
-        best_model = best_test_model = None
-        best_loss = best_test_loss = float("inf")
-        not_improved_count = 0
-        self.logger.info("Training process started")
-        for epoch in range(1, self.args["epochs"] + 1):
-            train_loss = self._run_epoch(epoch, self.train_loader, "train")
-            val_loss = self._run_epoch(epoch, self.val_loader or self.test_loader, "val")
-            test_loss = self._run_epoch(epoch, self.test_loader, "test")
-            if train_loss > 1e6:
-                self.logger.warning("Gradient explosion detected. Ending...")
-                break
-            if val_loss < best_loss:
-                best_loss, not_improved_count, best_model = val_loss, 0, copy.deepcopy(self.model.state_dict())
-                self.logger.info("Best validation model saved!")
-            elif self.args["early_stop"] and (not_improved_count := not_improved_count + 1) == self.args["early_stop_patience"]:
-                self.logger.info(f"Validation performance didn't improve for {self.args['early_stop_patience']} epochs. Training stops.")
-                break
-            if test_loss < best_test_loss:
-                best_test_loss, best_test_model = test_loss, copy.deepcopy(self.model.state_dict())
-        torch.save(best_model, self.best_path)
-        torch.save(best_test_model, self.best_test_path)
-        self.logger.info(f"Best models saved at {self.best_path} and {self.best_test_path}")
-        for model_name, state_dict in [("best validation", best_model), ("best test", best_test_model)]:
-            self.model.load_state_dict(state_dict)
-            self.logger.info(f"Testing on {model_name} model")
-            self._run_epoch(None, self.test_loader, "test", log_horizon=True)
+    # ---------------- init ----------------
+    def _init_paths(self):
+        d = self.args["log_dir"]
+        self.best_path = os.path.join(d, "best_model.pth")
+        self.best_test_path = os.path.join(d, "best_test_model.pth")

+    def _init_logger(self):
+        if not self.args["debug"]:
+            os.makedirs(self.args["log_dir"], exist_ok=True)
+        self.logger = get_logger(
+            self.args["log_dir"],
+            name=self.model.__class__.__name__,
+            debug=self.args["debug"],
+        )

-    def _run_epoch(self, epoch, dataloader, mode, log_horizon=False):
-        self.model.train() if mode == "train" else self.model.eval()
-        optimizer_step = mode == "train"
-        total_loss, epoch_time = 0, time.time()
+    # ---------------- epoch ----------------
+    def _run_epoch(self, epoch, loader, mode):
+        is_train = mode == "train"
+        self.model.train() if is_train else self.model.eval()
+        total_loss, start = 0.0, time.time()
         y_pred, y_true = [], []
-        with torch.set_grad_enabled(optimizer_step):
-            for data, target in tqdm(dataloader, total=len(dataloader), desc=f"{mode.capitalize()} Epoch {epoch}" if epoch else mode):
+        with torch.set_grad_enabled(is_train):
+            for data, target in tqdm(loader, desc=f"{mode} {epoch}", total=len(loader)):
                 data, target = data.to(self.device), target.to(self.device)
                 label = target[..., :self.args["output_dim"]]
-                output = self.model(data)
-                loss = self.loss(output, label)
-                d_output, d_label = self.scaler.inverse_transform(output), self.scaler.inverse_transform(label)
-                d_loss = self.loss(d_output, d_label)
+                out = self.model(data)
+                if os.environ.get("TRY") == "True":
+                    if out.shape == label.shape:
+                        print(f"shape true, out: {out.shape}, label: {label.shape}")
+                        assert False
+                    else:
+                        print(f"shape false, out: {out.shape}, label: {label.shape}")
+                        assert False
+                loss = self.loss(out, label)
+                d_out = self.scaler.inverse_transform(out)
+                d_lbl = self.scaler.inverse_transform(label)
+                d_loss = self.loss(d_out, d_lbl)
                 total_loss += d_loss.item()
-                y_pred.append(d_output.detach().cpu())
-                y_true.append(d_label.detach().cpu())
-                if optimizer_step and self.optimizer:
+                y_pred.append(d_out.detach().cpu())
+                y_true.append(d_lbl.detach().cpu())
+                if is_train and self.optimizer:
                     self.optimizer.zero_grad()
                     loss.backward()
-                    if self.args["grad_norm"]: torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args["max_grad_norm"])
+                    if self.args["grad_norm"]:
+                        torch.nn.utils.clip_grad_norm_(
+                            self.model.parameters(),
+                            self.args["max_grad_norm"]
+                        )
                     self.optimizer.step()
-        y_pred, y_true = torch.cat(y_pred, dim=0), torch.cat(y_true, dim=0)
-        if log_horizon:
-            for t in range(y_true.shape[1]):
-                mae, rmse, mape = all_metrics(y_pred[:, t, ...], y_true[:, t, ...], self.args["mae_thresh"], self.args["mape_thresh"])
-                self.logger.info(f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}")
-        avg_mae, avg_rmse, avg_mape = all_metrics(y_pred, y_true, self.args["mae_thresh"], self.args["mape_thresh"])
-        if epoch and mode:
-            self.logger.info(f"Epoch #{epoch:02d}: {mode.capitalize():<5} MAE:{avg_mae:5.2f} | RMSE:{avg_rmse:5.2f} | MAPE:{avg_mape:7.4f} | Time: {time.time()-epoch_time:.2f} s")
-        elif mode:
-            self.logger.info(f"{mode.capitalize():<5} MAE:{avg_mae:.4f} | RMSE:{avg_rmse:.4f} | MAPE:{avg_mape:.4f}")
-        return total_loss / len(dataloader)
+        y_pred = torch.cat(y_pred)
+        y_true = torch.cat(y_true)
+        mae, rmse, mape = all_metrics(
+            y_pred, y_true,
+            self.args["mae_thresh"],
+            self.args["mape_thresh"]
+        )
+        self.logger.info(
+            f"Epoch #{epoch:02d} {mode:<5} "
+            f"MAE:{mae:5.2f} RMSE:{rmse:5.2f} "
+            f"MAPE:{mape:7.4f} Time:{time.time()-start:.2f}s"
+        )
+        return total_loss / len(loader)

-    def test(self, path=None):
-        if path:
-            self.model.load_state_dict(torch.load(path)["state_dict"])
-            self.model.to(self.device)
-        self._run_epoch(None, self.test_loader, "test", log_horizon=True)
+    # ---------------- train ----------------
+    def train(self):
+        best, best_test = float("inf"), float("inf")
+        best_w, best_test_w = None, None
+        patience = 0
+        self.logger.info("Training started")
+        for epoch in range(1, self.args["epochs"] + 1):
+            losses = {
+                "train": self._run_epoch(epoch, self.train_loader, "train"),
+                "val": self._run_epoch(epoch, self.val_loader, "val"),
+                "test": self._run_epoch(epoch, self.test_loader, "test"),
+            }
+            if losses["train"] > 1e6:
+                self.logger.warning("Gradient explosion detected")
+                break
+            if losses["val"] < best:
+                best, patience = losses["val"], 0
+                best_w = copy.deepcopy(self.model.state_dict())
+                self.logger.info("Best validation model saved")
+            else:
+                patience += 1
+                if self.args["early_stop"] and patience == self.args["early_stop_patience"]:
+                    self.logger.info("Early stopping triggered")
+                    break
+            if losses["test"] < best_test:
+                best_test = losses["test"]
+                best_test_w = copy.deepcopy(self.model.state_dict())
+        if not self.args["debug"]:
+            torch.save(best_w, self.best_path)
+            torch.save(best_test_w, self.best_test_path)
+        self._final_test(best_w, best_test_w)
+
+    # ---------------- final test ----------------
+    def _final_test(self, best_w, best_test_w):
+        for name, w in [("best val", best_w), ("best test", best_test_w)]:
+            self.model.load_state_dict(w)
+            self.logger.info(f"Testing on {name} model")
+            self.evaluate()
+
+    # ---------------- evaluate ----------------
+    def evaluate(self):
+        self.model.eval()
+        y_pred, y_true = [], []
+        with torch.no_grad():
+            for data, target in self.test_loader:
+                data, target = data.to(self.device), target.to(self.device)
+                label = target[..., :self.args["output_dim"]]
+                out = self.model(data)
+                y_pred.append(out.cpu())
+                y_true.append(label.cpu())
+        d_pred = self.scaler.inverse_transform(torch.cat(y_pred))
+        d_true = self.scaler.inverse_transform(torch.cat(y_true))
+        for t in range(d_true.shape[1]):
+            mae, rmse, mape = all_metrics(
+                d_pred[:, t], d_true[:, t],
+                self.args["mae_thresh"],
+                self.args["mape_thresh"]
+            )
+            self.logger.info(
+                f"Horizon {t+1:02d} MAE:{mae:.4f} RMSE:{rmse:.4f} MAPE:{mape:.4f}"
+            )
+        avg_mae, avg_rmse, avg_mape = all_metrics(d_pred, d_true, self.args["mae_thresh"], self.args["mape_thresh"])
+        self.logger.info(
+            f"AVG MAE:{avg_mae:.4f} AVG RMSE:{avg_rmse:.4f} AVG MAPE:{avg_mape:.4f}"
+        )
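The rewritten train() keeps the same patience-based early stop as the old version; a self-contained sketch of that logic with dummy validation losses (values are illustrative):

# patience resets on every improvement and triggers a stop when it
# reaches early_stop_patience, mirroring the loop above
val_losses = [1.0, 0.8, 0.9, 0.85, 0.86]
best, patience, early_stop_patience = float("inf"), 0, 2
for epoch, v in enumerate(val_losses, 1):
    if v < best:
        best, patience = v, 0
    else:
        patience += 1
        if patience == early_stop_patience:
            print(f"early stop at epoch {epoch}, best={best}")  # epoch 4, best=0.8
            break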