impl Informer

commit c121912f03 (parent b46c16815e)
@@ -6,7 +6,7 @@ basic:
   seed: 2023
 
 data:
-  batch_size: 256
+  batch_size: 16
   column_wise: false
   days_per_week: 7
   horizon: 24
@@ -43,7 +43,7 @@ model:
 
 train:
-  batch_size: 256
+  batch_size: 16
   debug: false
   early_stop: true
   early_stop_patience: 15

@@ -6,7 +6,7 @@ basic:
   seed: 2023
 
 data:
-  batch_size: 2048
+  batch_size: 16
   column_wise: false
   days_per_week: 7
   horizon: 24
@@ -43,7 +43,7 @@ model:
 
 train:
-  batch_size: 2048
+  batch_size: 16
   debug: false
   early_stop: true
   early_stop_patience: 15

@@ -6,7 +6,7 @@ basic:
   seed: 2023
 
 data:
-  batch_size: 2048
+  batch_size: 16
   column_wise: false
   days_per_week: 7
   horizon: 24
@@ -43,7 +43,7 @@ model:
 
 train:
-  batch_size: 2048
+  batch_size: 16
   debug: false
   early_stop: true
   early_stop_patience: 15

@@ -6,7 +6,7 @@ basic:
   seed: 2023
 
 data:
-  batch_size: 256
+  batch_size: 16
   column_wise: false
   days_per_week: 7
   horizon: 24
@@ -22,7 +22,7 @@ data:
 model:
   activation: gelu
   seq_len: 24
-  label_len: 12
+  label_len: 24
   pred_len: 24
   d_model: 128
   d_ff: 2048
@@ -43,7 +43,7 @@ model:
 
 train:
-  batch_size: 256
+  batch_size: 16
   debug: false
   early_stop: true
   early_stop_patience: 15

@@ -6,7 +6,7 @@ basic:
   seed: 2023
 
 data:
-  batch_size: 256
+  batch_size: 16
   column_wise: false
   days_per_week: 7
   horizon: 24
@@ -43,7 +43,7 @@ model:
 
 train:
-  batch_size: 256
+  batch_size: 16
   debug: false
   early_stop: true
   early_stop_patience: 15

@@ -6,7 +6,7 @@ basic:
   seed: 2023
 
 data:
-  batch_size: 256
+  batch_size: 16
   column_wise: false
   days_per_week: 7
   horizon: 24
@@ -43,7 +43,7 @@ model:
 
 train:
-  batch_size: 256
+  batch_size: 16
   debug: false
   early_stop: true
   early_stop_patience: 15

@@ -6,7 +6,7 @@ basic:
   seed: 2023
 
 data:
-  batch_size: 2048
+  batch_size: 16
   column_wise: false
   days_per_week: 7
   horizon: 24
@@ -43,7 +43,7 @@ model:
 
 train:
-  batch_size: 2048
+  batch_size: 16
   debug: false
   early_stop: true
   early_stop_patience: 15

@@ -6,7 +6,7 @@ basic:
   seed: 2023
 
 data:
-  batch_size: 1024
+  batch_size: 16
   column_wise: false
   days_per_week: 7
   horizon: 24
@@ -43,7 +43,7 @@ model:
 
 train:
-  batch_size: 1024
+  batch_size: 16
   debug: false
   early_stop: true
   early_stop_patience: 15

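All of the dataset configs above reduce data.batch_size and train.batch_size to 16 (from 256, 1024, or 2048). As a hedged sanity-check sketch, this is how such a file would be read with yaml.safe_load, matching read_config in train.py further below; the file name and the nesting under basic/data/train are assumptions, since the config paths are not shown in this diff:

import yaml

# Assumed file name; the actual dataset config paths are not captured in the diff above.
with open("AirQuality.yaml") as f:
    cfg = yaml.safe_load(f)

# Keys taken from the hunks above; the nesting is assumed from the section headers.
assert cfg["basic"]["seed"] == 2023
assert cfg["data"]["batch_size"] == 16
assert cfg["train"]["batch_size"] == 16
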
@@ -1,129 +1,36 @@
+# model/InformerOnlyX/embed.py
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
 import math
 
 
 class PositionalEmbedding(nn.Module):
     def __init__(self, d_model, max_len=5000):
-        super(PositionalEmbedding, self).__init__()
-        # Compute the positional encodings once in log space.
-        pe = torch.zeros(max_len, d_model).float()
-        pe.require_grad = False
-
-        position = torch.arange(0, max_len).float().unsqueeze(1)
-        div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp()
-
+        super().__init__()
+        pe = torch.zeros(max_len, d_model)
+        position = torch.arange(0, max_len).unsqueeze(1).float()
+        div_term = torch.exp(
+            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
+        )
         pe[:, 0::2] = torch.sin(position * div_term)
         pe[:, 1::2] = torch.cos(position * div_term)
-
-        pe = pe.unsqueeze(0)
-        self.register_buffer('pe', pe)
+        self.register_buffer("pe", pe.unsqueeze(0))  # [1, L, D]
 
     def forward(self, x):
         return self.pe[:, :x.size(1)]
 
-class TokenEmbedding(nn.Module):
-    def __init__(self, c_in, d_model):
-        super(TokenEmbedding, self).__init__()
-        padding = 1 if torch.__version__>='1.5.0' else 2
-        self.tokenConv = nn.Conv1d(in_channels=c_in, out_channels=d_model,
-                                   kernel_size=3, padding=padding, padding_mode='circular')
-        for m in self.modules():
-            if isinstance(m, nn.Conv1d):
-                nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='leaky_relu')
-
-    def forward(self, x):
-        x = self.tokenConv(x.permute(0, 2, 1)).transpose(1, 2)
-        return x
-
-class FixedEmbedding(nn.Module):
-    def __init__(self, c_in, d_model):
-        super(FixedEmbedding, self).__init__()
-
-        w = torch.zeros(c_in, d_model).float()
-        w.require_grad = False
-
-        position = torch.arange(0, c_in).float().unsqueeze(1)
-        div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp()
-
-        w[:, 0::2] = torch.sin(position * div_term)
-        w[:, 1::2] = torch.cos(position * div_term)
-
-        self.emb = nn.Embedding(c_in, d_model)
-        self.emb.weight = nn.Parameter(w, requires_grad=False)
-
-    def forward(self, x):
-        return self.emb(x).detach()
-
-class TemporalEmbedding(nn.Module):
-    def __init__(self, d_model, embed_type='fixed', freq='h'):
-        super(TemporalEmbedding, self).__init__()
-
-        minute_size = 4; hour_size = 24
-        weekday_size = 7; day_size = 32; month_size = 13
-
-        Embed = FixedEmbedding if embed_type=='fixed' else nn.Embedding
-        if freq=='t':
-            self.minute_embed = Embed(minute_size, d_model)
-        self.hour_embed = Embed(hour_size, d_model)
-        self.weekday_embed = Embed(weekday_size, d_model)
-        self.day_embed = Embed(day_size, d_model)
-        self.month_embed = Embed(month_size, d_model)
-
-    def forward(self, x):
-        x = x.long()
-
-        # Check the size of x's last dimension to avoid index errors
-        last_dim = x.shape[-1]
-
-        minute_x = 0.
-        hour_x = 0.
-        weekday_x = 0.
-        day_x = 0.
-        month_x = 0.
-
-        # For our generated time features, we have only 2 dimensions: [day_of_week, hour]
-        # So we need to map them to the appropriate embedding layers
-        if last_dim > 0:
-            # Use the first dimension for hour
-            # Ensure hour is in the valid range [0, 23]
-            hour = torch.clamp(x[:, :, 0], 0, 23)
-            hour_x = self.hour_embed(hour)
-
-        if last_dim > 1:
-            # Use the second dimension for weekday
-            # Ensure weekday is in the valid range [0, 6]
-            weekday = torch.clamp(x[:, :, 1], 0, 6)
-            weekday_x = self.weekday_embed(weekday)
-
-        return hour_x + weekday_x + day_x + month_x + minute_x
-
-class TimeFeatureEmbedding(nn.Module):
-    def __init__(self, d_model, embed_type='timeF', freq='h'):
-        super(TimeFeatureEmbedding, self).__init__()
-
-        freq_map = {'h':4, 't':5, 's':6, 'm':1, 'a':1, 'w':2, 'd':3, 'b':3}
-        d_inp = freq_map[freq]
-        self.embed = nn.Linear(d_inp, d_model)
-
-    def forward(self, x):
-        return self.embed(x)
-
 class DataEmbedding(nn.Module):
-    def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1):
-        super(DataEmbedding, self).__init__()
-
-        self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
-        self.position_embedding = PositionalEmbedding(d_model=d_model)
-        self.temporal_embedding = TemporalEmbedding(d_model=d_model, embed_type=embed_type, freq=freq) if embed_type!='timeF' else TimeFeatureEmbedding(d_model=d_model, embed_type=embed_type, freq=freq)
-
-        self.dropout = nn.Dropout(p=dropout)
-
-    def forward(self, x, x_mark):
-        a = self.value_embedding(x)
-        b = self.position_embedding(x)
-        c = self.temporal_embedding(x_mark)
-        x = a + b + c
-
-        return self.dropout(x)
+    """
+    Informer-style embedding without time covariates
+    """
+    def __init__(self, c_in, d_model, dropout):
+        super().__init__()
+        self.value_embedding = nn.Linear(c_in, d_model)
+        self.position_embedding = PositionalEmbedding(d_model)
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, x):
+        x = self.value_embedding(x) + self.position_embedding(x)
+        return self.dropout(x)

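A minimal usage sketch of the simplified DataEmbedding above; the import path follows model.py further below, and the tensor sizes are illustrative assumptions:

import torch
from model.Informer.embed import DataEmbedding  # new signature: (c_in, d_model, dropout); forward takes only x

emb = DataEmbedding(c_in=7, d_model=128, dropout=0.1)
x = torch.randn(16, 24, 7)   # [B, L, C]; no x_mark / time covariates anymore
out = emb(x)                 # value (Linear) + positional embedding, then dropout
print(out.shape)             # torch.Size([16, 24, 128])
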
@@ -0,0 +1,25 @@
+# model/Informer/head.py
+import torch
+import torch.nn as nn
+
+
+class TemporalProjectionHead(nn.Module):
+    """
+    Project along temporal dimension
+    [B, L, D] -> [B, pred_len, C]
+    """
+
+    def __init__(self, d_model, pred_len, c_out):
+        super().__init__()
+        self.temporal_proj = nn.Linear(1, pred_len)
+        self.channel_proj = nn.Linear(d_model, c_out)
+
+    def forward(self, x):
+        # x: [B, L, D]
+        # Average over the sequence dimension and then project
+        x = x.mean(dim=1, keepdim=True)  # [B, 1, D]
+        x = x.transpose(1, 2)            # [B, D, 1]
+        x = self.temporal_proj(x)        # [B, D, pred_len]
+        x = x.transpose(1, 2)            # [B, pred_len, D]
+        x = self.channel_proj(x)         # [B, pred_len, C]
+        return x

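A quick shape-check sketch for the new TemporalProjectionHead; the concrete sizes are illustrative assumptions:

import torch
from model.Informer.head import TemporalProjectionHead

head = TemporalProjectionHead(d_model=128, pred_len=24, c_out=7)
x = torch.randn(16, 96, 128)  # encoder output [B, L, D]
y = head(x)                   # mean over L, project 1 -> pred_len, then d_model -> c_out
print(y.shape)                # torch.Size([16, 24, 7])
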
@@ -1,141 +1,79 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from model.Informer.encoder import Encoder, EncoderLayer, ConvLayer, EncoderStack
-from model.Informer.decoder import Decoder, DecoderLayer
+from model.Informer.encoder import Encoder, EncoderLayer, ConvLayer
 from model.Informer.attn import FullAttention, ProbAttention, AttentionLayer
 from model.Informer.embed import DataEmbedding
+from model.Informer.head import TemporalProjectionHead
 
-class Informer(nn.Module):
-    def __init__(self, args):
-        super(Informer, self).__init__()
-        self.pred_len = args['pred_len']
-        self.attn = args['attn']
-        self.output_attention = args['output_attention']
-
-        # Encoding
-        self.enc_embedding = DataEmbedding(args['enc_in'], args['d_model'], args['embed'], args['freq'], args['dropout'])
-        self.dec_embedding = DataEmbedding(args['dec_in'], args['d_model'], args['embed'], args['freq'], args['dropout'])
-        # Attention
-        Attn = ProbAttention if args['attn']=='prob' else FullAttention
-        # Encoder
-        self.encoder = Encoder(
-            [
-                EncoderLayer(
-                    AttentionLayer(Attn(False, args['factor'], attention_dropout=args['dropout'], output_attention=args['output_attention']),
-                                   args['d_model'], args['n_heads'], mix=False),
-                    args['d_model'],
-                    args['d_ff'],
-                    dropout=args['dropout'],
-                    activation=args['activation']
-                ) for l in range(args['e_layers'])
-            ],
-            [
-                ConvLayer(
-                    args['d_model']
-                ) for l in range(args['e_layers']-1)
-            ] if args['distil'] else None,
-            norm_layer=torch.nn.LayerNorm(args['d_model'])
-        )
-        # Decoder
-        self.decoder = Decoder(
-            [
-                DecoderLayer(
-                    AttentionLayer(Attn(True, args['factor'], attention_dropout=args['dropout'], output_attention=False),
-                                   args['d_model'], args['n_heads'], mix=args['mix']),
-                    AttentionLayer(FullAttention(False, args['factor'], attention_dropout=args['dropout'], output_attention=False),
-                                   args['d_model'], args['n_heads'], mix=False),
-                    args['d_model'],
-                    args['d_ff'],
-                    dropout=args['dropout'],
-                    activation=args['activation'],
-                )
-                for l in range(args['d_layers'])
-            ],
-            norm_layer=torch.nn.LayerNorm(args['d_model'])
-        )
-        self.projection = nn.Linear(args['d_model'], args['c_out'], bias=True)
-
-    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec,
-                enc_self_mask=None, dec_self_mask=None, dec_enc_mask=None):
-        enc_out = self.enc_embedding(x_enc, x_mark_enc)
-        enc_out, attns = self.encoder(enc_out, attn_mask=enc_self_mask)
-
-        dec_out = self.dec_embedding(x_dec, x_mark_dec)
-        dec_out = self.decoder(dec_out, enc_out, x_mask=dec_self_mask, cross_mask=dec_enc_mask)
-        dec_out = self.projection(dec_out)
-
-        if self.output_attention:
-            return dec_out[:,-self.pred_len:,:], attns
-        else:
-            return dec_out[:,-self.pred_len:,:] # [B, L, D]
-
-
-class InformerStack(nn.Module):
-    def __init__(self, args):
-        super(InformerStack, self).__init__()
-        self.pred_len = args['pred_len']
-        self.attn = args['attn']
-        self.output_attention = args['output_attention']
-
-        # Encoding
-        self.enc_embedding = DataEmbedding(args['enc_in'], args['d_model'], args['embed'], args['freq'], args['dropout'])
-        self.dec_embedding = DataEmbedding(args['dec_in'], args['d_model'], args['embed'], args['freq'], args['dropout'])
-        # Attention
-        Attn = ProbAttention if args['attn']=='prob' else FullAttention
-        # Encoder
-
-        inp_lens = list(range(len(args['e_layers']))) # [0,1,2,...] you can customize here
-        encoders = [
-            Encoder(
-                [
-                    EncoderLayer(
-                        AttentionLayer(Attn(False, args['factor'], attention_dropout=args['dropout'], output_attention=args['output_attention']),
-                                       args['d_model'], args['n_heads'], mix=False),
-                        args['d_model'],
-                        args['d_ff'],
-                        dropout=args['dropout'],
-                        activation=args['activation']
-                    ) for l in range(el)
-                ],
-                [
-                    ConvLayer(
-                        args['d_model']
-                    ) for l in range(el-1)
-                ] if args['distil'] else None,
-                norm_layer=torch.nn.LayerNorm(args['d_model'])
-            ) for el in args['e_layers']]
-        self.encoder = EncoderStack(encoders, inp_lens)
-        # Decoder
-        self.decoder = Decoder(
-            [
-                DecoderLayer(
-                    AttentionLayer(Attn(True, args['factor'], attention_dropout=args['dropout'], output_attention=False),
-                                   args['d_model'], args['n_heads'], mix=args['mix']),
-                    AttentionLayer(FullAttention(False, args['factor'], attention_dropout=args['dropout'], output_attention=False),
-                                   args['d_model'], args['n_heads'], mix=False),
-                    args['d_model'],
-                    args['d_ff'],
-                    dropout=args['dropout'],
-                    activation=args['activation'],
-                )
-                for l in range(args['d_layers'])
-            ],
-            norm_layer=torch.nn.LayerNorm(args['d_model'])
-        )
-        self.projection = nn.Linear(args['d_model'], args['c_out'], bias=True)
-
-    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec,
-                enc_self_mask=None, dec_self_mask=None, dec_enc_mask=None):
-        enc_out = self.enc_embedding(x_enc, x_mark_enc)
-        enc_out, attns = self.encoder(enc_out, attn_mask=enc_self_mask)
-
-        dec_out = self.dec_embedding(x_dec, x_mark_dec)
-        dec_out = self.decoder(dec_out, enc_out, x_mask=dec_self_mask, cross_mask=dec_enc_mask)
-        dec_out = self.projection(dec_out)
-
-        if self.output_attention:
-            return dec_out[:,-self.pred_len:,:], attns
-        else:
-            return dec_out[:,-self.pred_len:,:] # [B, L, D]
+class InformerEncoder(nn.Module):
+    """
+    Informer Encoder-only
+    - Only uses x
+    - No normalization
+    - Multi-channel friendly
+    """
+
+    def __init__(self, configs):
+        super().__init__()
+
+        self.seq_len = configs["seq_len"]
+        self.pred_len = configs["pred_len"]
+
+        Attn = ProbAttention if configs["attn"] == "prob" else FullAttention
+
+        # Embedding
+        self.embedding = DataEmbedding(
+            configs["enc_in"],
+            configs["d_model"],
+            configs["dropout"],
+        )
+
+        # Encoder (Informer)
+        self.encoder = Encoder(
+            [
+                EncoderLayer(
+                    AttentionLayer(
+                        Attn(
+                            False,
+                            configs["factor"],
+                            attention_dropout=configs["dropout"],
+                            output_attention=False,
+                        ),
+                        configs["d_model"],
+                        configs["n_heads"],
+                        mix=False,
+                    ),
+                    configs["d_model"],
+                    configs["d_ff"],
+                    dropout=configs["dropout"],
+                    activation=configs["activation"],
+                )
+                for _ in range(configs["e_layers"])
+            ],
+            [
+                ConvLayer(configs["d_model"])
+                for _ in range(configs["e_layers"] - 1)
+            ]
+            if configs.get("distil", False)
+            else None,
+            norm_layer=nn.LayerNorm(configs["d_model"]),
+        )
+
+        # Forecast Head
+        self.head = TemporalProjectionHead(
+            d_model=configs["d_model"],
+            pred_len=configs["pred_len"],
+            c_out=configs["c_out"],
+        )
+
+    def forward(self, x_enc):
+        """
+        x_enc: [B, L, C]
+        """
+        x = self.embedding(x_enc)
+        x, _ = self.encoder(x)
+        out = self.head(x)
+        return out[:, -self.pred_len :, :]

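For reference, a hedged sketch of constructing the new InformerEncoder; the keys are exactly those read in __init__ above, but the concrete values are assumptions rather than values taken from the config diffs:

import torch
from model.Informer.model import InformerEncoder

configs = {
    "seq_len": 24, "pred_len": 24,   # window sizes are illustrative
    "enc_in": 7, "c_out": 7,
    "d_model": 128, "d_ff": 2048, "n_heads": 8, "e_layers": 2,
    "attn": "prob", "factor": 5, "distil": False,
    "dropout": 0.1, "activation": "gelu",
}
model = InformerEncoder(configs)
x = torch.randn(16, configs["seq_len"], configs["enc_in"])  # [B, L, C]
print(model(x).shape)  # expected torch.Size([16, 24, 7])
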
@@ -2,6 +2,6 @@
 {
     "name": "Informer",
    "module": "model.Informer.model",
-    "entry": "Informer"
+    "entry": "InformerEncoder"
 }
]

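The registry entry only swaps the class name that gets resolved from the module path. A hedged sketch of how such an entry is typically consumed; load_model_class is a hypothetical helper, not code from this commit:

import importlib

def load_model_class(entry):
    # Hypothetical helper: resolve "module" + "entry" to a class object.
    module = importlib.import_module(entry["module"])  # model.Informer.model
    return getattr(module, entry["entry"])             # now InformerEncoder instead of Informer

ModelCls = load_model_class({
    "name": "Informer",
    "module": "model.Informer.model",
    "entry": "InformerEncoder",
})
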
train.py (8 lines changed)

@@ -12,7 +12,7 @@ def read_config(config_path):
         config = yaml.safe_load(file)
 
 # Global configuration
-device = "cpu"  # set device to cuda:0
+device = "cuda:0"  # set device to cuda:0
 seed = 2023  # random seed
 epochs = 1  # number of training epochs
@@ -102,10 +102,10 @@ def main(model_list, data, debug=False):
 if __name__ == "__main__":
     # for debugging
     # model_list = ["iTransformer", "PatchTST", "HI"]
-    model_list = ["D2STGNN"]
+    model_list = ["Informer"]
     # model_list = ["PatchTST"]
     # dataset_list = ["AirQuality"]
     # dataset_list = ["BJTaxi-InFlow", "BJTaxi-OutFlow"]
-    # dataset_list = ["AirQuality", "PEMS-BAY", "SolarEnergy", "NYCBike-InFlow", "NYCBike-OutFlow", "METR-LA"]
-    dataset_list = ["BJTaxi-OutFlow"]
+    dataset_list = ["AirQuality", "PEMS-BAY", "SolarEnergy", "NYCBike-InFlow", "NYCBike-OutFlow", "METR-LA"]
     # dataset_list = ["METR-LA"]
     main(model_list, dataset_list, debug=True)

@@ -20,7 +20,7 @@ def select_trainer(
         scaler, args, lr_scheduler
     )
 
-    if model_name in {"HI", "PatchTST", "iTransformer", "FPT"}:
+    if model_name in {"HI", "PatchTST", "iTransformer", "FPT", "Informer"}:
         return TSTrainer(*base_args)
 
     trainer_map = {
@@ -28,7 +28,6 @@ def select_trainer(
         "PDG2SEQ": PDG2SEQ_Trainer,
         "STMLP": STMLP_Trainer,
         "EXP": EXP_Trainer,
-        "Informer": InformerTrainer,
     }
 
     if model_name in {"STGNCDE", "STGNRDE"}:
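The net effect of the two hunks above is that "Informer" is now handled by the shared TSTrainer path instead of a dedicated InformerTrainer entry. A hedged, self-contained sketch of that dispatch pattern (the trainer classes are passed in here because their real definitions are not part of this diff):

def pick_trainer(model_name, base_args, ts_trainer, trainer_map):
    # Models that share the generic time-series training loop, now including Informer.
    if model_name in {"HI", "PatchTST", "iTransformer", "FPT", "Informer"}:
        return ts_trainer(*base_args)
    # Everything else is looked up in the model-specific trainer map.
    return trainer_map[model_name](*base_args)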