实现iTransformer
This commit is contained in:
parent
e4a7884c98
commit
4984d24506
|
|
@ -2089,6 +2089,14 @@
|
|||
"program": "run.py",
|
||||
"console": "integratedTerminal",
|
||||
"args": "--config ./config/STGCN/PEMSD7.yaml"
|
||||
},
|
||||
{
|
||||
"name": "iTransformer: METR-LA",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "run.py",
|
||||
"console": "integratedTerminal",
|
||||
"args": "--config ./config/iTransformer/METR-LA.yaml"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
|
@ -0,0 +1,52 @@
|
|||
basic:
|
||||
dataset: METR-LA
|
||||
device: cuda:0
|
||||
mode: train
|
||||
model: iTransformer
|
||||
seed: 2023
|
||||
|
||||
data:
|
||||
batch_size: 16
|
||||
column_wise: false
|
||||
days_per_week: 7
|
||||
horizon: 24
|
||||
input_dim: 1
|
||||
lag: 24
|
||||
normalizer: std
|
||||
num_nodes: 207
|
||||
steps_per_day: 288
|
||||
test_ratio: 0.2
|
||||
val_ratio: 0.2
|
||||
|
||||
model:
|
||||
activation: gelu
|
||||
seq_len: 24
|
||||
pred_len: 24
|
||||
d_model: 128
|
||||
d_ff: 2048
|
||||
dropout: 0.1
|
||||
e_layers: 2
|
||||
n_heads: 8
|
||||
output_attention: False
|
||||
|
||||
|
||||
train:
|
||||
batch_size: 16
|
||||
debug: false
|
||||
early_stop: true
|
||||
early_stop_patience: 15
|
||||
epochs: 100
|
||||
grad_norm: false
|
||||
log_step: 1000
|
||||
loss_func: mae
|
||||
lr_decay: true
|
||||
lr_decay_rate: 0.3
|
||||
lr_decay_step: 5,20,40,70
|
||||
lr_init: 0.003
|
||||
mae_thresh: None
|
||||
mape_thresh: 0.001
|
||||
max_grad_norm: 5
|
||||
output_dim: 1
|
||||
plot: false
|
||||
real_value: true
|
||||
weight_decay: 0
|
||||
|
|
@ -1,9 +1,9 @@
|
|||
from utils.normalization import normalize_dataset
|
||||
from dataloader.data_selector import load_st_dataset
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from dataloader.data_selector import load_st_dataset
|
||||
from utils.normalization import normalize_dataset
|
||||
|
||||
|
||||
def get_dataloader(args, normalizer="std", single=True):
|
||||
data = load_st_dataset(args)
|
||||
|
|
@ -152,7 +152,7 @@ def add_window_y(data, window=3, horizon=1, single=False):
|
|||
offset = window if not single else window + horizon - 1
|
||||
return _generate_windows(data, window=1 if single else horizon, horizon=horizon, offset=offset)
|
||||
|
||||
if __name__ == "__main__":
|
||||
from dataloader.data_selector import load_st_dataset
|
||||
res = load_st_dataset({"dataset": "SD"})
|
||||
print(f"Dataset shape: {res.shape}")
|
||||
# if __name__ == "__main__":
|
||||
# from dataloader.data_selector import load_st_dataset
|
||||
# res = load_st_dataset({"dataset": "SD"})
|
||||
# print(f"Dataset shape: {res.shape}")
|
||||
|
|
|
|||
|
|
@ -0,0 +1,216 @@
|
|||
from dataloader.data_selector import load_st_dataset
|
||||
from utils.normalization import normalize_dataset
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def get_dataloader(args, normalizer="std", single=True):
|
||||
data = load_st_dataset(args)
|
||||
|
||||
args = args["data"]
|
||||
L, N, F = data.shape
|
||||
data = data.reshape(L, N*F) # [L, N*F]
|
||||
|
||||
# Generate sliding windows for main data and add time features
|
||||
x, y = _prepare_data_with_windows(data, args, single)
|
||||
|
||||
# Split data
|
||||
split_fn = split_data_by_days if args["test_ratio"] > 1 else split_data_by_ratio
|
||||
x_train, x_val, x_test = split_fn(x, args["val_ratio"], args["test_ratio"])
|
||||
y_train, y_val, y_test = split_fn(y, args["val_ratio"], args["test_ratio"])
|
||||
|
||||
# Normalize x and y using the same scaler
|
||||
scaler = _normalize_data(x_train, x_val, x_test, args, normalizer)
|
||||
_apply_existing_scaler(y_train, y_val, y_test, scaler, args)
|
||||
|
||||
# Create dataloaders
|
||||
return (
|
||||
_create_dataloader(x_train, y_train, args["batch_size"], True, False),
|
||||
_create_dataloader(x_val, y_val, args["batch_size"], False, False),
|
||||
_create_dataloader(x_test, y_test, args["batch_size"], False, False),
|
||||
scaler
|
||||
)
|
||||
|
||||
|
||||
def _prepare_data_with_windows(data, args, single):
|
||||
# Generate sliding windows for main data
|
||||
x = add_window_x(data, args["lag"], args["horizon"], single)
|
||||
y = add_window_y(data, args["lag"], args["horizon"], single)
|
||||
|
||||
# Generate time features
|
||||
time_features = _generate_time_features(data.shape[0], args)
|
||||
|
||||
# Add time features to x and y
|
||||
x = _add_time_features(x, time_features, args["lag"], args["horizon"], single, add_window_x)
|
||||
y = _add_time_features(y, time_features, args["lag"], args["horizon"], single, add_window_y)
|
||||
|
||||
return x, y
|
||||
|
||||
|
||||
def _generate_time_features(L, args):
|
||||
# For time series data, we generate time features for each time step
|
||||
# [L, 1] -> [L, T, 1] by repeating across time dimension
|
||||
T = args.get("time_dim", 1) # Get time dimension size if available
|
||||
|
||||
time_in_day = [i % args["steps_per_day"] / args["steps_per_day"] for i in range(L)]
|
||||
time_in_day = np.array(time_in_day)[:, None, None] # [L, 1, 1]
|
||||
time_in_day = np.tile(time_in_day, (1, T, 1)) # [L, T, 1]
|
||||
|
||||
day_in_week = [(i // args["steps_per_day"]) % args["days_per_week"] for i in range(L)]
|
||||
day_in_week = np.array(day_in_week)[:, None, None] # [L, 1, 1]
|
||||
day_in_week = np.tile(day_in_week, (1, T, 1)) # [L, T, 1]
|
||||
|
||||
return time_in_day, day_in_week
|
||||
|
||||
|
||||
|
||||
def _add_time_features(data, time_features, lag, horizon, single, window_fn):
|
||||
time_in_day, day_in_week = time_features
|
||||
time_day = window_fn(time_in_day, lag, horizon, single)
|
||||
time_week = window_fn(day_in_week, lag, horizon, single)
|
||||
return np.concatenate([data, time_day, time_week], axis=-1)
|
||||
|
||||
|
||||
def _normalize_data(train_data, val_data, test_data, args, normalizer):
|
||||
scaler = normalize_dataset(train_data[..., : args["input_dim"]], normalizer, args["column_wise"])
|
||||
|
||||
for data in [train_data, val_data, test_data]:
|
||||
data[..., : args["input_dim"]] = scaler.transform(data[..., : args["input_dim"]])
|
||||
|
||||
return scaler
|
||||
|
||||
|
||||
def _apply_existing_scaler(train_data, val_data, test_data, scaler, args):
|
||||
for data in [train_data, val_data, test_data]:
|
||||
data[..., : args["input_dim"]] = scaler.transform(data[..., : args["input_dim"]])
|
||||
|
||||
|
||||
def _create_dataloader(X_data, Y_data, batch_size, shuffle, drop_last):
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
X_tensor = torch.tensor(X_data, dtype=torch.float32, device=device)
|
||||
Y_tensor = torch.tensor(Y_data, dtype=torch.float32, device=device)
|
||||
dataset = torch.utils.data.TensorDataset(X_tensor, Y_tensor)
|
||||
return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last)
|
||||
|
||||
|
||||
def split_data_by_days(data, val_days, test_days, interval=30):
|
||||
t = int((24 * 60) / interval)
|
||||
test_data = data[-t * int(test_days) :]
|
||||
val_data = data[-t * int(test_days + val_days) : -t * int(test_days)]
|
||||
train_data = data[: -t * int(test_days + val_days)]
|
||||
return train_data, val_data, test_data
|
||||
|
||||
|
||||
def split_data_by_ratio(data, val_ratio, test_ratio):
|
||||
data_len = data.shape[0]
|
||||
test_data = data[-int(data_len * test_ratio) :]
|
||||
val_data = data[
|
||||
-int(data_len * (test_ratio + val_ratio)) : -int(data_len * test_ratio)
|
||||
]
|
||||
train_data = data[: -int(data_len * (test_ratio + val_ratio))]
|
||||
return train_data, val_data, test_data
|
||||
|
||||
|
||||
|
||||
|
||||
def _generate_windows(data, window=3, horizon=1, offset=0):
|
||||
"""
|
||||
Internal helper function to generate sliding windows.
|
||||
|
||||
:param data: Input data, shape [L, T, C]
|
||||
:param window: Window size
|
||||
:param horizon: Horizon size
|
||||
:param offset: Offset from window start
|
||||
:return: Windowed data, shape [num_windows, window, T, C]
|
||||
"""
|
||||
length = len(data)
|
||||
end_index = length - horizon - window + 1
|
||||
windows = []
|
||||
index = 0
|
||||
|
||||
if end_index <= 0:
|
||||
raise ValueError(f"end_index is non-positive: {end_index}, length={length}, horizon={horizon}, window={window}")
|
||||
|
||||
while index < end_index:
|
||||
window_data = data[index + offset : index + offset + window]
|
||||
windows.append(window_data)
|
||||
index += 1
|
||||
|
||||
if not windows:
|
||||
raise ValueError("No windows generated")
|
||||
|
||||
# Check window shapes
|
||||
first_shape = windows[0].shape
|
||||
for i, w in enumerate(windows):
|
||||
if w.shape != first_shape:
|
||||
raise ValueError(f"Window {i} has shape {w.shape}, expected {first_shape}")
|
||||
|
||||
return np.array(windows)
|
||||
|
||||
def add_window_x(data, window=3, horizon=1, single=False):
|
||||
"""
|
||||
Generate windowed X values from the input data.
|
||||
|
||||
:param data: Input data, shape [L, T, C]
|
||||
:param window: Size of the sliding window
|
||||
:param horizon: Horizon size
|
||||
:param single: If True, generate single-step windows, else multi-step
|
||||
:return: X with shape [num_windows, window, T, C]
|
||||
"""
|
||||
return _generate_windows(data, window, horizon, offset=0)
|
||||
|
||||
def add_window_y(data, window=3, horizon=1, single=False):
|
||||
"""
|
||||
Generate windowed Y values from the input data.
|
||||
|
||||
:param data: Input data, shape [L, T, C]
|
||||
:param window: Size of the sliding window
|
||||
:param horizon: Horizon size
|
||||
:param single: If True, generate single-step windows, else multi-step
|
||||
:return: Y with shape [num_windows, horizon, T, C]
|
||||
"""
|
||||
offset = window if not single else window + horizon - 1
|
||||
return _generate_windows(data, window=1 if single else horizon, horizon=horizon, offset=offset)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
# Test with a dummy config using METR-LA dataset
|
||||
dummy_args = {
|
||||
"basic": {
|
||||
"dataset": "METR-LA"
|
||||
},
|
||||
"data": {
|
||||
"lag": 3,
|
||||
"horizon": 1,
|
||||
"val_ratio": 0.1,
|
||||
"test_ratio": 0.2,
|
||||
"steps_per_day": 288,
|
||||
"days_per_week": 7,
|
||||
"input_dim": 1,
|
||||
"column_wise": False,
|
||||
"batch_size": 32,
|
||||
"time_dim": 1 # Add time dimension parameter
|
||||
}
|
||||
}
|
||||
|
||||
try:
|
||||
# Load data
|
||||
data = load_st_dataset(dummy_args)
|
||||
print(f"Original data shape: {data.shape}")
|
||||
|
||||
# Get dataloader
|
||||
train_loader, val_loader, test_loader, scaler = get_dataloader(dummy_args)
|
||||
|
||||
# Test data loader
|
||||
for batch_x, batch_y in train_loader:
|
||||
print(f"Batch X shape: {batch_x.shape}")
|
||||
print(f"Batch Y shape: {batch_y.shape}")
|
||||
break
|
||||
|
||||
print("Test passed successfully!")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Test failed with error: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
|
@ -3,6 +3,7 @@ from dataloader.PeMSDdataloader import get_dataloader as normal_loader
|
|||
from dataloader.DCRNNdataloader import get_dataloader as DCRNN_loader
|
||||
from dataloader.EXPdataloader import get_dataloader as EXP_loader
|
||||
from dataloader.cde_loader.cdeDataloader import get_dataloader as nrde_loader
|
||||
from dataloader.TSloader import get_dataloader as TS_loader
|
||||
|
||||
|
||||
def get_dataloader(config, normalizer, single):
|
||||
|
|
@ -16,5 +17,7 @@ def get_dataloader(config, normalizer, single):
|
|||
return DCRNN_loader(config, normalizer, single)
|
||||
case "EXP":
|
||||
return EXP_loader(config, normalizer, single)
|
||||
case "iTransformer":
|
||||
return TS_loader(config, normalizer, single)
|
||||
case _:
|
||||
return normal_loader(config, normalizer, single)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,43 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
from model.iTransformer.layers.Transformer_EncDec import Encoder, EncoderLayer
|
||||
from model.iTransformer.layers.SelfAttn import FullAttention, AttentionLayer
|
||||
from model.iTransformer.layers.Embed import DataEmbedding_inverted
|
||||
|
||||
class iTransformer(nn.Module):
|
||||
"""
|
||||
Paper link: https://arxiv.org/abs/2310.06625
|
||||
"""
|
||||
|
||||
def __init__(self, args):
|
||||
super(iTransformer, self).__init__()
|
||||
self.pred_len = args['pred_len']
|
||||
# Embedding
|
||||
self.enc_embedding = DataEmbedding_inverted(args['seq_len'], args['d_model'], args['dropout'])
|
||||
# Encoder-only architecture
|
||||
self.encoder = Encoder(
|
||||
[
|
||||
EncoderLayer(
|
||||
AttentionLayer(
|
||||
FullAttention(False, attention_dropout=args['dropout'],
|
||||
output_attention=args['output_attention']), args['d_model'], args['n_heads']),
|
||||
args['d_model'],
|
||||
args['d_ff'],
|
||||
dropout=args['dropout'],
|
||||
activation=args['activation']
|
||||
) for l in range(args['e_layers'])
|
||||
],
|
||||
norm_layer=torch.nn.LayerNorm(args['d_model'])
|
||||
)
|
||||
self.projector = nn.Linear(args['d_model'], args['pred_len'], bias=True)
|
||||
|
||||
def forecast(self, x_enc, x_mark_enc):
|
||||
_, _, N = x_enc.shape # B, T, C
|
||||
enc_out = self.enc_embedding(x_enc, x_mark_enc)
|
||||
enc_out, attns = self.encoder(enc_out, attn_mask=None)
|
||||
dec_out = self.projector(enc_out).permute(0, 2, 1)[:, :, :N] # filter the covariates
|
||||
return dec_out, attns
|
||||
|
||||
def forward(self, x_enc, x_mark_enc):
|
||||
dec_out, attns = self.forecast(x_enc, x_mark_enc)
|
||||
return dec_out[:, -self.pred_len:, :] # [B, T, C]
|
||||
|
|
@ -0,0 +1,19 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
class DataEmbedding_inverted(nn.Module):
|
||||
def __init__(self, c_in, d_model, dropout=0.1):
|
||||
super(DataEmbedding_inverted, self).__init__()
|
||||
self.value_embedding = nn.Linear(c_in, d_model)
|
||||
self.dropout = nn.Dropout(p=dropout)
|
||||
|
||||
def forward(self, x, x_mark):
|
||||
x = x.permute(0, 2, 1)
|
||||
# x: [Batch Variate Time]
|
||||
if x_mark is None:
|
||||
x = self.value_embedding(x)
|
||||
else:
|
||||
# the potential to take covariates (e.g. timestamps) as tokens
|
||||
x = self.value_embedding(torch.cat([x, x_mark.permute(0, 2, 1)], 1))
|
||||
# x: [Batch Variate d_model]
|
||||
return self.dropout(x)
|
||||
|
|
@ -0,0 +1,82 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from math import sqrt
|
||||
|
||||
|
||||
class FullAttention(nn.Module):
|
||||
def __init__(self, mask_flag=True, scale=None, attention_dropout=0.1, output_attention=False):
|
||||
super(FullAttention, self).__init__()
|
||||
self.scale = scale
|
||||
self.mask_flag = mask_flag
|
||||
self.output_attention = output_attention
|
||||
self.dropout = nn.Dropout(attention_dropout)
|
||||
|
||||
def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
|
||||
B, L, H, E = queries.shape
|
||||
_, S, _, D = values.shape
|
||||
scale = self.scale or 1. / sqrt(E)
|
||||
|
||||
scores = torch.einsum("blhe,bshe->bhls", queries, keys)
|
||||
|
||||
if self.mask_flag:
|
||||
if attn_mask is None:
|
||||
attn_mask = TriangularCausalMask(B, L, device=queries.device)
|
||||
|
||||
scores.masked_fill_(attn_mask.mask, -np.inf)
|
||||
|
||||
A = self.dropout(torch.softmax(scale * scores, dim=-1))
|
||||
V = torch.einsum("bhls,bshd->blhd", A, values)
|
||||
|
||||
if self.output_attention:
|
||||
return V.contiguous(), A
|
||||
else:
|
||||
return V.contiguous(), None
|
||||
|
||||
class AttentionLayer(nn.Module):
|
||||
def __init__(self, attention, d_model, n_heads, d_keys=None,
|
||||
d_values=None):
|
||||
super(AttentionLayer, self).__init__()
|
||||
|
||||
d_keys = d_keys or (d_model // n_heads)
|
||||
d_values = d_values or (d_model // n_heads)
|
||||
|
||||
self.inner_attention = attention
|
||||
self.query_projection = nn.Linear(d_model, d_keys * n_heads)
|
||||
self.key_projection = nn.Linear(d_model, d_keys * n_heads)
|
||||
self.value_projection = nn.Linear(d_model, d_values * n_heads)
|
||||
self.out_projection = nn.Linear(d_values * n_heads, d_model)
|
||||
self.n_heads = n_heads
|
||||
|
||||
def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
|
||||
B, L, _ = queries.shape
|
||||
_, S, _ = keys.shape
|
||||
H = self.n_heads
|
||||
|
||||
queries = self.query_projection(queries).view(B, L, H, -1)
|
||||
keys = self.key_projection(keys).view(B, S, H, -1)
|
||||
values = self.value_projection(values).view(B, S, H, -1)
|
||||
|
||||
out, attn = self.inner_attention(
|
||||
queries,
|
||||
keys,
|
||||
values,
|
||||
attn_mask,
|
||||
tau=tau,
|
||||
delta=delta
|
||||
)
|
||||
out = out.view(B, L, -1)
|
||||
|
||||
return self.out_projection(out), attn
|
||||
|
||||
|
||||
class TriangularCausalMask:
|
||||
def __init__(self, B, L, device="cpu"):
|
||||
mask_shape = [B, 1, L, L]
|
||||
with torch.no_grad():
|
||||
self._mask = torch.triu(torch.ones(mask_shape, dtype=torch.bool), diagonal=1).to(device)
|
||||
|
||||
@property
|
||||
def mask(self):
|
||||
return self._mask
|
||||
|
||||
|
|
@ -0,0 +1,57 @@
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
class EncoderLayer(nn.Module):
|
||||
def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"):
|
||||
super(EncoderLayer, self).__init__()
|
||||
d_ff = d_ff or 4 * d_model
|
||||
self.attention = attention
|
||||
self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
|
||||
self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
|
||||
self.norm1 = nn.LayerNorm(d_model)
|
||||
self.norm2 = nn.LayerNorm(d_model)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.activation = F.relu if activation == "relu" else F.gelu
|
||||
|
||||
def forward(self, x, attn_mask=None, tau=None, delta=None):
|
||||
new_x, attn = self.attention(
|
||||
x, x, x,
|
||||
attn_mask=attn_mask,
|
||||
tau=tau, delta=delta
|
||||
)
|
||||
x = x + self.dropout(new_x)
|
||||
|
||||
y = x = self.norm1(x)
|
||||
y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
|
||||
y = self.dropout(self.conv2(y).transpose(-1, 1))
|
||||
|
||||
return self.norm2(x + y), attn
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
|
||||
super(Encoder, self).__init__()
|
||||
self.attn_layers = nn.ModuleList(attn_layers)
|
||||
self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None
|
||||
self.norm = norm_layer
|
||||
|
||||
def forward(self, x, attn_mask=None, tau=None, delta=None):
|
||||
# x [B, L, D]
|
||||
attns = []
|
||||
if self.conv_layers is not None:
|
||||
for i, (attn_layer, conv_layer) in enumerate(zip(self.attn_layers, self.conv_layers)):
|
||||
delta = delta if i == 0 else None
|
||||
x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta)
|
||||
x = conv_layer(x)
|
||||
attns.append(attn)
|
||||
x, attn = self.attn_layers[-1](x, tau=tau, delta=None)
|
||||
attns.append(attn)
|
||||
else:
|
||||
for attn_layer in self.attn_layers:
|
||||
x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta)
|
||||
attns.append(attn)
|
||||
|
||||
if self.norm is not None:
|
||||
x = self.norm(x)
|
||||
|
||||
return x, attns
|
||||
|
|
@ -26,6 +26,7 @@ from model.REPST.repst import repst as REPST
|
|||
from model.ASTRA.astra import ASTRA as ASTRA
|
||||
from model.ASTRA.astrav2 import ASTRA as ASTRAv2
|
||||
from model.ASTRA.astrav3 import ASTRA as ASTRAv3
|
||||
from model.iTransformer.iTransformer import iTransformer
|
||||
|
||||
|
||||
|
||||
|
|
@ -89,3 +90,5 @@ def model_selector(config):
|
|||
return ASTRAv2(model_config)
|
||||
case "ASTRA_v3":
|
||||
return ASTRAv3(model_config)
|
||||
case "iTransformer":
|
||||
return iTransformer(model_config)
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
numpy
|
||||
pyyaml
|
||||
tqdm
|
||||
statsmodels
|
||||
|
|
|
|||
26
run.py
26
run.py
|
|
@ -11,36 +11,28 @@ from trainer.trainer_selector import select_trainer
|
|||
|
||||
|
||||
def main():
|
||||
# 读取配置
|
||||
args = parse_args()
|
||||
|
||||
# 初始化 device, seed, model, data, trainer
|
||||
args = init.init_device(args)
|
||||
init.init_seed(args["basic"]["seed"])
|
||||
|
||||
# Load model
|
||||
model = init.init_model(args)
|
||||
|
||||
# Load dataset
|
||||
train_loader, val_loader, test_loader, scaler, *extra_data = get_dataloader(
|
||||
args, normalizer=args["data"]["normalizer"], single=False
|
||||
)
|
||||
|
||||
loss = init.init_loss(args, scaler)
|
||||
optimizer, lr_scheduler = init.init_optimizer(model, args["train"])
|
||||
init.create_logs(args)
|
||||
|
||||
# Start training or testing
|
||||
trainer = select_trainer(
|
||||
model,
|
||||
loss,
|
||||
optimizer,
|
||||
train_loader,
|
||||
val_loader,
|
||||
test_loader,
|
||||
scaler,
|
||||
loss, optimizer,
|
||||
train_loader, val_loader, test_loader, scaler,
|
||||
args,
|
||||
lr_scheduler,
|
||||
extra_data,
|
||||
lr_scheduler, extra_data,
|
||||
)
|
||||
|
||||
# 开始训练
|
||||
match args["basic"]["mode"]:
|
||||
case "train":
|
||||
trainer.train()
|
||||
|
|
@ -54,9 +46,7 @@ def main():
|
|||
)
|
||||
trainer.test(
|
||||
model.to(args["basic"]["device"]),
|
||||
trainer.args,
|
||||
test_loader,
|
||||
scaler,
|
||||
trainer.args, test_loader, scaler,
|
||||
trainer.logger,
|
||||
)
|
||||
case _:
|
||||
|
|
|
|||
Loading…
Reference in New Issue