129 lines
4.6 KiB
Python
129 lines
4.6 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
import math
|
|
|
|
class PositionalEmbedding(nn.Module):
    """Fixed sinusoidal positional-encoding table.

    A ``(1, max_len, d_model)`` table is computed once at construction and
    stored as the buffer ``pe`` (not a parameter, so it is never trained).
    Even columns hold ``sin(position * div_term)`` and odd columns hold
    ``cos(position * div_term)``.

    Args:
        d_model: Width of each encoding vector. Odd values are supported.
        max_len: Number of positions to precompute.
    """

    def __init__(self, d_model, max_len=5000):
        super(PositionalEmbedding, self).__init__()
        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model).float()

        position = torch.arange(0, max_len).float().unsqueeze(1)
        div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp()
        angles = position * div_term

        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even
        # columns, so trim the angles to the number of cosine slots.
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])

        # A buffer follows the module across devices and into state_dict
        # without being exposed to the optimizer.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return the first ``x.size(1)`` rows of the table.

        Only the size of dimension 1 of ``x`` is read; its values are
        ignored. The result has shape ``(1, min(x.size(1), max_len), d_model)``.
        """
        return self.pe[:, :x.size(1)]
|
|
|
|
class TokenEmbedding(nn.Module):
    """Project input channels to ``d_model`` with a circular 1-D convolution.

    Args:
        c_in: Number of input channels (size of the last input dimension).
        d_model: Number of output channels.
    """

    def __init__(self, c_in, d_model):
        super(TokenEmbedding, self).__init__()
        # Compare the version numerically. A plain string comparison would
        # sort '1.10.0' before '1.5.0' and pick the wrong padding.
        try:
            major, minor = (int(part) for part in str(torch.__version__).split('.')[:2])
            padding = 1 if (major, minor) >= (1, 5) else 2
        except ValueError:
            # Unparseable version string: use the padding of current releases.
            padding = 1
        self.tokenConv = nn.Conv1d(in_channels=c_in, out_channels=d_model,
                                   kernel_size=3, padding=padding, padding_mode='circular')
        # The convolution is the only Conv1d in this module, so initialise
        # it directly.
        nn.init.kaiming_normal_(self.tokenConv.weight, mode='fan_in', nonlinearity='leaky_relu')

    def forward(self, x):
        """Embed ``x``.

        ``x`` is permuted with ``(0, 2, 1)`` so that its last dimension
        becomes the convolution's channel dimension, then the result is
        transposed back so channels are last again.
        """
        x = self.tokenConv(x.permute(0, 2, 1)).transpose(1, 2)
        return x
|
|
|
|
class FixedEmbedding(nn.Module):
    """Embedding lookup whose table is a frozen sinusoidal encoding.

    Row ``i`` of the table holds ``sin(i * div_term)`` in its even columns
    and ``cos(i * div_term)`` in its odd columns. The table is installed as
    a parameter with ``requires_grad=False``.

    Args:
        c_in: Number of rows in the table (valid indices are ``0..c_in-1``).
        d_model: Width of each embedding vector. Odd values are supported.
    """

    def __init__(self, c_in, d_model):
        super(FixedEmbedding, self).__init__()

        w = torch.zeros(c_in, d_model).float()

        position = torch.arange(0, c_in).float().unsqueeze(1)
        div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp()
        angles = position * div_term

        w[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even
        # columns, so trim the angles to the number of cosine slots.
        w[:, 1::2] = torch.cos(angles[:, :d_model // 2])

        self.emb = nn.Embedding(c_in, d_model)
        # Replace the randomly initialised weights with the fixed table.
        self.emb.weight = nn.Parameter(w, requires_grad=False)

    def forward(self, x):
        """Look up the rows indexed by ``x`` and return them detached."""
        return self.emb(x).detach()
|
|
|
|
class TemporalEmbedding(nn.Module):
    """Sum of per-field embeddings for integer calendar features.

    Lookup tables are created for hour, weekday, day and month (and for
    minute when ``freq == 't'``). ``forward`` currently looks up only the
    hour and weekday tables; the others contribute a constant ``0.``.

    Args:
        d_model: Width of each embedding vector.
        embed_type: ``'fixed'`` selects ``FixedEmbedding``; any other value
            selects a trainable ``nn.Embedding``.
        freq: ``'t'`` additionally creates the minute table.
    """

    def __init__(self, d_model, embed_type='fixed', freq='h'):
        super(TemporalEmbedding, self).__init__()

        # Table sizes (number of distinct indices) for each calendar field.
        minute_size, hour_size = 4, 24
        weekday_size, day_size, month_size = 7, 32, 13

        if embed_type == 'fixed':
            Embed = FixedEmbedding
        else:
            Embed = nn.Embedding

        if freq == 't':
            self.minute_embed = Embed(minute_size, d_model)
        self.hour_embed = Embed(hour_size, d_model)
        self.weekday_embed = Embed(weekday_size, d_model)
        self.day_embed = Embed(day_size, d_model)
        self.month_embed = Embed(month_size, d_model)

    def forward(self, x):
        """Embed the time-feature tensor ``x``.

        ``x`` is cast to long and indexed as ``x[:, :, k]``, so it must be
        3-D. Column 0 is clamped to ``[0, 23]`` and looked up as the hour;
        column 1, when present, is clamped to ``[0, 6]`` and looked up as
        the weekday. Further columns are ignored.

        NOTE(review): confirm that the upstream feature order really is
        (hour, weekday); this method assumes it.
        """
        idx = x.long()

        # The number of feature columns decides which lookups are safe.
        n_features = idx.shape[-1]

        hour_x = self.hour_embed(idx[:, :, 0].clamp(0, 23)) if n_features > 0 else 0.
        weekday_x = self.weekday_embed(idx[:, :, 1].clamp(0, 6)) if n_features > 1 else 0.

        # These fields are not looked up and only add zero to the sum.
        day_x = month_x = minute_x = 0.

        return hour_x + weekday_x + day_x + month_x + minute_x
|
|
|
|
class TimeFeatureEmbedding(nn.Module):
    """Linear projection of time features to ``d_model``.

    Args:
        d_model: Output width of the projection.
        embed_type: Accepted but not used by this class.
        freq: Frequency code that selects the input width of the
            projection. An unknown code raises ``KeyError``.
    """

    def __init__(self, d_model, embed_type='timeF', freq='h'):
        super(TimeFeatureEmbedding, self).__init__()

        # Input width of the linear layer for each frequency code.
        inputs_per_freq = {'h': 4, 't': 5, 's': 6, 'm': 1, 'a': 1, 'w': 2, 'd': 3, 'b': 3}
        self.embed = nn.Linear(inputs_per_freq[freq], d_model)

    def forward(self, x):
        """Apply the linear projection to the last dimension of ``x``."""
        projected = self.embed(x)
        return projected
|
|
|
|
class DataEmbedding(nn.Module):
    """Combine value, positional and temporal embeddings, then apply dropout.

    Args:
        c_in: Number of input channels passed to ``TokenEmbedding``.
        d_model: Embedding width shared by all three components.
        embed_type: ``'timeF'`` selects ``TimeFeatureEmbedding`` for the
            time marks; any other value selects ``TemporalEmbedding``.
        freq: Frequency code forwarded to the temporal component.
        dropout: Dropout probability applied to the summed embedding.
    """

    def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1):
        super(DataEmbedding, self).__init__()

        self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
        self.position_embedding = PositionalEmbedding(d_model=d_model)

        # Pick the time-mark embedding from the requested embedding type.
        if embed_type != 'timeF':
            temporal_cls = TemporalEmbedding
        else:
            temporal_cls = TimeFeatureEmbedding
        self.temporal_embedding = temporal_cls(d_model=d_model, embed_type=embed_type, freq=freq)

        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, x_mark):
        """Return ``dropout(value(x) + position(x) + temporal(x_mark))``."""
        combined = (
            self.value_embedding(x)
            + self.position_embedding(x)
            + self.temporal_embedding(x_mark)
        )
        return self.dropout(combined)