Implement iTransformer
parent e4a7884c98
commit 4984d24506
@@ -2089,6 +2089,14 @@
             "program": "run.py",
             "console": "integratedTerminal",
             "args": "--config ./config/STGCN/PEMSD7.yaml"
-        }
+        },
+        {
+            "name": "iTransformer: METR-LA",
+            "type": "debugpy",
+            "request": "launch",
+            "program": "run.py",
+            "console": "integratedTerminal",
+            "args": "--config ./config/iTransformer/METR-LA.yaml"
+        }
     ]
 }
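
Selecting this entry in the debugger is effectively the same as running run.py with --config ./config/iTransformer/METR-LA.yaml from the repository root.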
config/iTransformer/METR-LA.yaml
@@ -0,0 +1,52 @@
basic:
  dataset: METR-LA
  device: cuda:0
  mode: train
  model: iTransformer
  seed: 2023

data:
  batch_size: 16
  column_wise: false
  days_per_week: 7
  horizon: 24
  input_dim: 1
  lag: 24
  normalizer: std
  num_nodes: 207
  steps_per_day: 288
  test_ratio: 0.2
  val_ratio: 0.2

model:
  activation: gelu
  seq_len: 24
  pred_len: 24
  d_model: 128
  d_ff: 2048
  dropout: 0.1
  e_layers: 2
  n_heads: 8
  output_attention: False

train:
  batch_size: 16
  debug: false
  early_stop: true
  early_stop_patience: 15
  epochs: 100
  grad_norm: false
  log_step: 1000
  loss_func: mae
  lr_decay: true
  lr_decay_rate: 0.3
  lr_decay_step: 5,20,40,70
  lr_init: 0.003
  mae_thresh: None
  mape_thresh: 0.001
  max_grad_norm: 5
  output_dim: 1
  plot: false
  real_value: true
  weight_decay: 0
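
For orientation, a minimal sketch (not part of the commit) of how a config like this is consumed: pyyaml, already in the requirements, parses it into the nested dict that the code below indexes as args["basic"], args["data"], args["model"], and args["train"].

import yaml

# Illustrative only: load the YAML into the nested config dict.
with open("./config/iTransformer/METR-LA.yaml") as f:
    args = yaml.safe_load(f)

print(args["basic"]["model"])    # iTransformer
print(args["data"]["lag"])       # 24
print(args["model"]["d_model"])  # 128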
@@ -1,9 +1,9 @@
-from utils.normalization import normalize_dataset
-from dataloader.data_selector import load_st_dataset
 import numpy as np
 import torch

+from dataloader.data_selector import load_st_dataset
+from utils.normalization import normalize_dataset


 def get_dataloader(args, normalizer="std", single=True):
     data = load_st_dataset(args)
@@ -152,7 +152,7 @@ def add_window_y(data, window=3, horizon=1, single=False):
     offset = window if not single else window + horizon - 1
     return _generate_windows(data, window=1 if single else horizon, horizon=horizon, offset=offset)


-if __name__ == "__main__":
-    from dataloader.data_selector import load_st_dataset
-    res = load_st_dataset({"dataset": "SD"})
-    print(f"Dataset shape: {res.shape}")
+# if __name__ == "__main__":
+#     from dataloader.data_selector import load_st_dataset
+#     res = load_st_dataset({"dataset": "SD"})
+#     print(f"Dataset shape: {res.shape}")
dataloader/TSloader.py
@@ -0,0 +1,216 @@
from dataloader.data_selector import load_st_dataset
from utils.normalization import normalize_dataset

import numpy as np
import torch


def get_dataloader(args, normalizer="std", single=True):
    data = load_st_dataset(args)

    args = args["data"]
    L, N, F = data.shape
    data = data.reshape(L, N * F)  # [L, N*F]

    # Generate sliding windows for main data and add time features
    x, y = _prepare_data_with_windows(data, args, single)

    # Split data
    split_fn = split_data_by_days if args["test_ratio"] > 1 else split_data_by_ratio
    x_train, x_val, x_test = split_fn(x, args["val_ratio"], args["test_ratio"])
    y_train, y_val, y_test = split_fn(y, args["val_ratio"], args["test_ratio"])

    # Normalize x and y using the same scaler
    scaler = _normalize_data(x_train, x_val, x_test, args, normalizer)
    _apply_existing_scaler(y_train, y_val, y_test, scaler, args)

    # Create dataloaders
    return (
        _create_dataloader(x_train, y_train, args["batch_size"], True, False),
        _create_dataloader(x_val, y_val, args["batch_size"], False, False),
        _create_dataloader(x_test, y_test, args["batch_size"], False, False),
        scaler,
    )


def _prepare_data_with_windows(data, args, single):
    # Generate sliding windows for main data
    x = add_window_x(data, args["lag"], args["horizon"], single)
    y = add_window_y(data, args["lag"], args["horizon"], single)

    # Generate time features
    time_features = _generate_time_features(data.shape[0], args)

    # Add time features to x and y
    x = _add_time_features(x, time_features, args["lag"], args["horizon"], single, add_window_x)
    y = _add_time_features(y, time_features, args["lag"], args["horizon"], single, add_window_y)

    return x, y


def _generate_time_features(L, args):
    # For time series data, we generate time features for each time step:
    # [L, 1] -> [L, T, 1] by repeating across the time dimension
    T = args.get("time_dim", 1)  # Get time dimension size if available

    time_in_day = [i % args["steps_per_day"] / args["steps_per_day"] for i in range(L)]
    time_in_day = np.array(time_in_day)[:, None, None]  # [L, 1, 1]
    time_in_day = np.tile(time_in_day, (1, T, 1))  # [L, T, 1]

    day_in_week = [(i // args["steps_per_day"]) % args["days_per_week"] for i in range(L)]
    day_in_week = np.array(day_in_week)[:, None, None]  # [L, 1, 1]
    day_in_week = np.tile(day_in_week, (1, T, 1))  # [L, T, 1]

    return time_in_day, day_in_week
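# Illustrative values (not part of the commit): with steps_per_day=288 and
# days_per_week=7, index 0 -> time_in_day 0.0, index 144 -> 0.5, index 287
# -> ~0.997, and day_in_week advances 0, 1, ..., 6 once every 288 indices.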


def _add_time_features(data, time_features, lag, horizon, single, window_fn):
    time_in_day, day_in_week = time_features
    time_day = window_fn(time_in_day, lag, horizon, single)
    time_week = window_fn(day_in_week, lag, horizon, single)
    return np.concatenate([data, time_day, time_week], axis=-1)


def _normalize_data(train_data, val_data, test_data, args, normalizer):
    scaler = normalize_dataset(train_data[..., : args["input_dim"]], normalizer, args["column_wise"])

    for data in [train_data, val_data, test_data]:
        data[..., : args["input_dim"]] = scaler.transform(data[..., : args["input_dim"]])

    return scaler


def _apply_existing_scaler(train_data, val_data, test_data, scaler, args):
    for data in [train_data, val_data, test_data]:
        data[..., : args["input_dim"]] = scaler.transform(data[..., : args["input_dim"]])


def _create_dataloader(X_data, Y_data, batch_size, shuffle, drop_last):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    X_tensor = torch.tensor(X_data, dtype=torch.float32, device=device)
    Y_tensor = torch.tensor(Y_data, dtype=torch.float32, device=device)
    dataset = torch.utils.data.TensorDataset(X_tensor, Y_tensor)
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last)
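# Note (not part of the commit): the tensors above are materialized on CUDA
# (when available) up front, so batches come off the DataLoader already on
# the target device.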


def split_data_by_days(data, val_days, test_days, interval=30):
    t = int((24 * 60) / interval)  # samples per day at the given interval (minutes)
    test_data = data[-t * int(test_days) :]
    val_data = data[-t * int(test_days + val_days) : -t * int(test_days)]
    train_data = data[: -t * int(test_days + val_days)]
    return train_data, val_data, test_data


def split_data_by_ratio(data, val_ratio, test_ratio):
    data_len = data.shape[0]
    test_data = data[-int(data_len * test_ratio) :]
    val_data = data[
        -int(data_len * (test_ratio + val_ratio)) : -int(data_len * test_ratio)
    ]
    train_data = data[: -int(data_len * (test_ratio + val_ratio))]
    return train_data, val_data, test_data
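# Illustrative (not part of the commit): with data_len=1000, val_ratio=0.1
# and test_ratio=0.2, this yields train=data[:700], val=data[700:800],
# test=data[800:].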


def _generate_windows(data, window=3, horizon=1, offset=0):
    """
    Internal helper function to generate sliding windows.

    :param data: Input data, shape [L, T, C]
    :param window: Window size
    :param horizon: Horizon size
    :param offset: Offset from window start
    :return: Windowed data, shape [num_windows, window, T, C]
    """
    length = len(data)
    end_index = length - horizon - window + 1
    windows = []
    index = 0

    if end_index <= 0:
        raise ValueError(f"end_index is non-positive: {end_index}, length={length}, horizon={horizon}, window={window}")

    while index < end_index:
        window_data = data[index + offset : index + offset + window]
        windows.append(window_data)
        index += 1

    if not windows:
        raise ValueError("No windows generated")

    # Check window shapes
    first_shape = windows[0].shape
    for i, w in enumerate(windows):
        if w.shape != first_shape:
            raise ValueError(f"Window {i} has shape {w.shape}, expected {first_shape}")

    return np.array(windows)


def add_window_x(data, window=3, horizon=1, single=False):
    """
    Generate windowed X values from the input data.

    :param data: Input data, shape [L, T, C]
    :param window: Size of the sliding window
    :param horizon: Horizon size
    :param single: If True, generate single-step windows, else multi-step
    :return: X with shape [num_windows, window, T, C]
    """
    return _generate_windows(data, window, horizon, offset=0)


def add_window_y(data, window=3, horizon=1, single=False):
    """
    Generate windowed Y values from the input data.

    :param data: Input data, shape [L, T, C]
    :param window: Size of the sliding window
    :param horizon: Horizon size
    :param single: If True, generate single-step windows, else multi-step
    :return: Y with shape [num_windows, horizon, T, C]
    """
    offset = window if not single else window + horizon - 1
    return _generate_windows(data, window=1 if single else horizon, horizon=horizon, offset=offset)
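# Illustrative (not part of the commit), with window == horizon as in the
# METR-LA config (lag=24, horizon=24): for length L=10, window=3, horizon=3,
# single=False, add_window_x yields 5 windows data[i:i+3] and add_window_y
# yields the 5 aligned targets data[i+3:i+6].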


if __name__ == "__main__":
    # Test with a dummy config using the METR-LA dataset
    dummy_args = {
        "basic": {
            "dataset": "METR-LA"
        },
        "data": {
            "lag": 3,
            "horizon": 1,
            "val_ratio": 0.1,
            "test_ratio": 0.2,
            "steps_per_day": 288,
            "days_per_week": 7,
            "input_dim": 1,
            "column_wise": False,
            "batch_size": 32,
            "time_dim": 1  # Add time dimension parameter
        }
    }

    try:
        # Load data
        data = load_st_dataset(dummy_args)
        print(f"Original data shape: {data.shape}")

        # Get dataloader
        train_loader, val_loader, test_loader, scaler = get_dataloader(dummy_args)

        # Test data loader
        for batch_x, batch_y in train_loader:
            print(f"Batch X shape: {batch_x.shape}")
            print(f"Batch Y shape: {batch_y.shape}")
            break

        print("Test passed successfully!")

    except Exception as e:
        print(f"Test failed with error: {e}")
        import traceback
        traceback.print_exc()
@@ -3,6 +3,7 @@ from dataloader.PeMSDdataloader import get_dataloader as normal_loader
 from dataloader.DCRNNdataloader import get_dataloader as DCRNN_loader
 from dataloader.EXPdataloader import get_dataloader as EXP_loader
 from dataloader.cde_loader.cdeDataloader import get_dataloader as nrde_loader
+from dataloader.TSloader import get_dataloader as TS_loader


 def get_dataloader(config, normalizer, single):
@@ -16,5 +17,7 @@ def get_dataloader(config, normalizer, single):
             return DCRNN_loader(config, normalizer, single)
         case "EXP":
             return EXP_loader(config, normalizer, single)
+        case "iTransformer":
+            return TS_loader(config, normalizer, single)
         case _:
             return normal_loader(config, normalizer, single)
model/iTransformer/iTransformer.py
@@ -0,0 +1,43 @@
import torch
import torch.nn as nn
from model.iTransformer.layers.Transformer_EncDec import Encoder, EncoderLayer
from model.iTransformer.layers.SelfAttn import FullAttention, AttentionLayer
from model.iTransformer.layers.Embed import DataEmbedding_inverted


class iTransformer(nn.Module):
    """
    Paper link: https://arxiv.org/abs/2310.06625
    """

    def __init__(self, args):
        super(iTransformer, self).__init__()
        self.pred_len = args['pred_len']
        # Embedding
        self.enc_embedding = DataEmbedding_inverted(args['seq_len'], args['d_model'], args['dropout'])
        # Encoder-only architecture
        self.encoder = Encoder(
            [
                EncoderLayer(
                    AttentionLayer(
                        FullAttention(False, attention_dropout=args['dropout'],
                                      output_attention=args['output_attention']),
                        args['d_model'], args['n_heads']),
                    args['d_model'],
                    args['d_ff'],
                    dropout=args['dropout'],
                    activation=args['activation']
                ) for l in range(args['e_layers'])
            ],
            norm_layer=torch.nn.LayerNorm(args['d_model'])
        )
        self.projector = nn.Linear(args['d_model'], args['pred_len'], bias=True)

    def forecast(self, x_enc, x_mark_enc):
        _, _, N = x_enc.shape  # B, T, C
        enc_out = self.enc_embedding(x_enc, x_mark_enc)
        enc_out, attns = self.encoder(enc_out, attn_mask=None)
        dec_out = self.projector(enc_out).permute(0, 2, 1)[:, :, :N]  # filter the covariates
        return dec_out, attns

    def forward(self, x_enc, x_mark_enc):
        dec_out, attns = self.forecast(x_enc, x_mark_enc)
        return dec_out[:, -self.pred_len:, :]  # [B, T, C]
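
As a sanity check, a minimal forward-pass sketch (illustrative, not part of the commit); the sizes follow the METR-LA config above.

import torch
from model.iTransformer.iTransformer import iTransformer

cfg = {"seq_len": 24, "pred_len": 24, "d_model": 128, "d_ff": 2048,
       "dropout": 0.1, "e_layers": 2, "n_heads": 8,
       "output_attention": False, "activation": "gelu"}
model = iTransformer(cfg)
x = torch.randn(2, 24, 207)  # [batch, seq_len, num_nodes]
out = model(x, None)         # x_mark_enc=None -> no covariate tokens
print(out.shape)             # torch.Size([2, 24, 207])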

model/iTransformer/layers/Embed.py
@@ -0,0 +1,19 @@
import torch
import torch.nn as nn


class DataEmbedding_inverted(nn.Module):
    def __init__(self, c_in, d_model, dropout=0.1):
        super(DataEmbedding_inverted, self).__init__()
        self.value_embedding = nn.Linear(c_in, d_model)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, x_mark):
        x = x.permute(0, 2, 1)
        # x: [Batch Variate Time]
        if x_mark is None:
            x = self.value_embedding(x)
        else:
            # the potential to take covariates (e.g. timestamps) as tokens
            x = self.value_embedding(torch.cat([x, x_mark.permute(0, 2, 1)], 1))
        # x: [Batch Variate d_model]
        return self.dropout(x)
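
A quick shape check (illustrative, not part of the commit), assuming the class above is in scope: inverting [B, T, N] to [B, N, T] turns each variate's whole series into one token, embedded from seq_len to d_model.

import torch

emb = DataEmbedding_inverted(c_in=24, d_model=128)
tokens = emb(torch.randn(2, 24, 207), x_mark=None)  # [B, T, N] -> [B, N, d_model]
print(tokens.shape)  # torch.Size([2, 207, 128])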

model/iTransformer/layers/SelfAttn.py
@@ -0,0 +1,82 @@
import torch
import torch.nn as nn
import numpy as np
from math import sqrt


class FullAttention(nn.Module):
    def __init__(self, mask_flag=True, scale=None, attention_dropout=0.1, output_attention=False):
        super(FullAttention, self).__init__()
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)

    def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
        B, L, H, E = queries.shape
        _, S, _, D = values.shape
        scale = self.scale or 1. / sqrt(E)

        scores = torch.einsum("blhe,bshe->bhls", queries, keys)

        if self.mask_flag:
            if attn_mask is None:
                attn_mask = TriangularCausalMask(B, L, device=queries.device)

            scores.masked_fill_(attn_mask.mask, -np.inf)

        A = self.dropout(torch.softmax(scale * scores, dim=-1))
        V = torch.einsum("bhls,bshd->blhd", A, values)

        if self.output_attention:
            return V.contiguous(), A
        else:
            return V.contiguous(), None


class AttentionLayer(nn.Module):
    def __init__(self, attention, d_model, n_heads, d_keys=None, d_values=None):
        super(AttentionLayer, self).__init__()

        d_keys = d_keys or (d_model // n_heads)
        d_values = d_values or (d_model // n_heads)

        self.inner_attention = attention
        self.query_projection = nn.Linear(d_model, d_keys * n_heads)
        self.key_projection = nn.Linear(d_model, d_keys * n_heads)
        self.value_projection = nn.Linear(d_model, d_values * n_heads)
        self.out_projection = nn.Linear(d_values * n_heads, d_model)
        self.n_heads = n_heads

    def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
        B, L, _ = queries.shape
        _, S, _ = keys.shape
        H = self.n_heads

        queries = self.query_projection(queries).view(B, L, H, -1)
        keys = self.key_projection(keys).view(B, S, H, -1)
        values = self.value_projection(values).view(B, S, H, -1)

        out, attn = self.inner_attention(
            queries,
            keys,
            values,
            attn_mask,
            tau=tau,
            delta=delta
        )
        out = out.view(B, L, -1)

        return self.out_projection(out), attn


class TriangularCausalMask:
    def __init__(self, B, L, device="cpu"):
        mask_shape = [B, 1, L, L]
        with torch.no_grad():
            self._mask = torch.triu(torch.ones(mask_shape, dtype=torch.bool), diagonal=1).to(device)

    @property
    def mask(self):
        return self._mask
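
The two einsums implement standard scaled dot-product attention over variate tokens; in isolation (illustrative, not part of the commit):

import torch

q = torch.randn(2, 207, 8, 16)  # queries [B, L, H, E]
k = torch.randn(2, 207, 8, 16)  # keys    [B, S, H, E]
scores = torch.einsum("blhe,bshe->bhls", q, k)
print(scores.shape)  # torch.Size([2, 8, 207, 207])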

model/iTransformer/layers/Transformer_EncDec.py
@@ -0,0 +1,57 @@
import torch.nn as nn
import torch.nn.functional as F


class EncoderLayer(nn.Module):
    def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"):
        super(EncoderLayer, self).__init__()
        d_ff = d_ff or 4 * d_model
        self.attention = attention
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = F.relu if activation == "relu" else F.gelu

    def forward(self, x, attn_mask=None, tau=None, delta=None):
        new_x, attn = self.attention(
            x, x, x,
            attn_mask=attn_mask,
            tau=tau, delta=delta
        )
        x = x + self.dropout(new_x)

        y = x = self.norm1(x)
        y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
        y = self.dropout(self.conv2(y).transpose(-1, 1))

        return self.norm2(x + y), attn


class Encoder(nn.Module):
    def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
        super(Encoder, self).__init__()
        self.attn_layers = nn.ModuleList(attn_layers)
        self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None
        self.norm = norm_layer

    def forward(self, x, attn_mask=None, tau=None, delta=None):
        # x [B, L, D]
        attns = []
        if self.conv_layers is not None:
            for i, (attn_layer, conv_layer) in enumerate(zip(self.attn_layers, self.conv_layers)):
                delta = delta if i == 0 else None
                x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta)
                x = conv_layer(x)
                attns.append(attn)
            x, attn = self.attn_layers[-1](x, tau=tau, delta=None)
            attns.append(attn)
        else:
            for attn_layer in self.attn_layers:
                x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta)
                attns.append(attn)

        if self.norm is not None:
            x = self.norm(x)

        return x, attns
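
The kernel-size-1 Conv1d pair in EncoderLayer is a position-wise feed-forward network; transpose(-1, 1) moves d_model onto the channel axis Conv1d expects (illustrative, not part of the commit):

import torch
import torch.nn as nn

x = torch.randn(2, 207, 128)  # [B, tokens, d_model]
ffn = nn.Conv1d(in_channels=128, out_channels=2048, kernel_size=1)
y = ffn(x.transpose(-1, 1))   # Conv1d expects [B, channels, length]
print(y.shape)                # torch.Size([2, 2048, 207])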

@@ -26,6 +26,7 @@ from model.REPST.repst import repst as REPST
 from model.ASTRA.astra import ASTRA as ASTRA
 from model.ASTRA.astrav2 import ASTRA as ASTRAv2
 from model.ASTRA.astrav3 import ASTRA as ASTRAv3
+from model.iTransformer.iTransformer import iTransformer
@@ -89,3 +90,5 @@ def model_selector(config):
             return ASTRAv2(model_config)
         case "ASTRA_v3":
             return ASTRAv3(model_config)
+        case "iTransformer":
+            return iTransformer(model_config)
@@ -1,3 +1,4 @@
+numpy
 pyyaml
 tqdm
 statsmodels
run.py
@@ -11,36 +11,28 @@ from trainer.trainer_selector import select_trainer


 def main():
+    # Read the config
     args = parse_args()

+    # Initialize device, seed, model, data, trainer
     args = init.init_device(args)
     init.init_seed(args["basic"]["seed"])

-    # Load model
     model = init.init_model(args)

-    # Load dataset
     train_loader, val_loader, test_loader, scaler, *extra_data = get_dataloader(
         args, normalizer=args["data"]["normalizer"], single=False
     )

     loss = init.init_loss(args, scaler)
     optimizer, lr_scheduler = init.init_optimizer(model, args["train"])
     init.create_logs(args)

-    # Start training or testing
     trainer = select_trainer(
         model,
-        loss,
-        optimizer,
-        train_loader,
-        val_loader,
-        test_loader,
-        scaler,
+        loss, optimizer,
+        train_loader, val_loader, test_loader, scaler,
         args,
-        lr_scheduler,
-        extra_data,
+        lr_scheduler, extra_data,
     )

+    # Start training
     match args["basic"]["mode"]:
         case "train":
             trainer.train()
@@ -54,9 +46,7 @@ def main():
             )
             trainer.test(
                 model.to(args["basic"]["device"]),
-                trainer.args,
-                test_loader,
-                scaler,
+                trainer.args, test_loader, scaler,
                 trainer.logger,
             )
         case _: