from math import sqrt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


class ReplicationPad1d(nn.Module):
    """Pad the last dimension by repeating its final value `padding[-1]` times."""

    def __init__(self, padding) -> None:
        super(ReplicationPad1d, self).__init__()
        self.padding = padding

    def forward(self, input: Tensor) -> Tensor:
        # Repeat the last time step along the sequence dimension and append it.
        replicate_padding = input[:, :, :, -1].unsqueeze(-1).repeat(1, 1, 1, self.padding[-1])
        output = torch.cat([input, replicate_padding], dim=-1)
        return output


class TokenEmbedding(nn.Module):
    """Embed patched series with a circular Conv1d, then fuse patches via a linear layer."""

    def __init__(self, c_in, d_model):
        super(TokenEmbedding, self).__init__()
        padding = 1
        self.tokenConv = nn.Conv1d(in_channels=c_in, out_channels=d_model,
                                   kernel_size=3, padding=padding, padding_mode='circular', bias=False)
        # Fuses the n_vars * patch_num axis (hard-coded to 12) down to a single token.
        self.confusion_layer = nn.Linear(12, 1)
        # if air_quality
        # self.confusion_layer = nn.Linear(42, 1)

        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_in', nonlinearity='leaky_relu')

    def forward(self, x):
        # x: (batch, nodes, n_vars, patch_num, patch_len)
        b, n, m, pn, pl = x.shape
        # Treat patch_len as the channel dimension and n_vars * patch_num as the length.
        x = self.tokenConv(x.reshape(b * n, pl, m * pn))
        x = self.confusion_layer(x)
        return x.reshape(b, n, -1)


class PatchEmbedding(nn.Module):
    """Split each series into overlapping patches and embed them with TokenEmbedding."""

    def __init__(self, d_model, patch_len, stride, dropout):
        super(PatchEmbedding, self).__init__()
        # Patching
        self.patch_len = patch_len
        self.stride = stride
        self.padding_patch_layer = ReplicationPad1d((0, stride))
        self.value_embedding = TokenEmbedding(patch_len, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # x: (batch, nodes, n_vars, seq_len)
        n_vars = x.shape[2]
        x = self.padding_patch_layer(x)
        # -> (batch, nodes, n_vars, patch_num, patch_len)
        x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride)
        x_value_embed = self.value_embedding(x)

        return self.dropout(x_value_embed), n_vars


class ReprogrammingLayer(nn.Module):
    """Cross-attention that reprograms time-series embeddings against LLM word embeddings."""

    def __init__(self, d_model, n_heads, d_keys=None, d_llm=None, attention_dropout=0.1):
        super(ReprogrammingLayer, self).__init__()

        d_keys = d_keys or (d_model // n_heads)

        self.query_projection = nn.Linear(d_model, d_keys * n_heads)
        self.key_projection = nn.Linear(d_llm, d_keys * n_heads)
        self.value_projection = nn.Linear(d_llm, d_keys * n_heads)
        self.out_projection = nn.Linear(d_keys * n_heads, d_llm)
        self.n_heads = n_heads
        self.dropout = nn.Dropout(attention_dropout)

    def forward(self, target_embedding, source_embedding, value_embedding):
        # target_embedding: (B, L, d_model); source/value_embedding: (S, d_llm)
        B, L, _ = target_embedding.shape
        S, _ = source_embedding.shape
        H = self.n_heads

        target_embedding = self.query_projection(target_embedding).view(B, L, H, -1)
        source_embedding = self.key_projection(source_embedding).view(S, H, -1)
        value_embedding = self.value_projection(value_embedding).view(S, H, -1)

        out = self.reprogramming(target_embedding, source_embedding, value_embedding)
        out = out.reshape(B, L, -1)

        return self.out_projection(out)

    def reprogramming(self, target_embedding, source_embedding, value_embedding):
        B, L, H, E = target_embedding.shape

        # Scaled dot-product attention between target queries and source keys/values.
        scale = 1. / sqrt(E)

        scores = torch.einsum("blhe,she->bhls", target_embedding, source_embedding)

        A = self.dropout(torch.softmax(scale * scores, dim=-1))
        reprogramming_embedding = torch.einsum("bhls,she->blhe", A, value_embedding)

        return reprogramming_embedding
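

if __name__ == "__main__":
    # Minimal shape-check sketch (illustrative, not part of the original module).
    # The sizes below are assumptions chosen so that n_vars * patch_num == 12,
    # which the hard-coded confusion_layer (nn.Linear(12, 1)) in TokenEmbedding requires;
    # the vocabulary size (1000) and d_llm (768) are placeholder values.
    batch, n_nodes, n_vars, seq_len = 2, 7, 1, 96
    d_model, patch_len, stride = 32, 16, 8

    patch_embedding = PatchEmbedding(d_model, patch_len, stride, dropout=0.1)
    x = torch.randn(batch, n_nodes, n_vars, seq_len)
    enc_out, n_vars_out = patch_embedding(x)
    print(enc_out.shape)  # torch.Size([2, 7, 32])

    # Reprogram the patch embeddings against a bank of (here random) word embeddings.
    d_llm, n_heads = 768, 4
    reprogramming = ReprogrammingLayer(d_model, n_heads, d_llm=d_llm)
    word_bank = torch.randn(1000, d_llm)
    out = reprogramming(enc_out, word_bank, word_bank)
    print(out.shape)  # torch.Size([2, 7, 768])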