import torch
import torch.nn as nn
from torch import Tensor
from math import sqrt


class ReplicationPad1d(nn.Module):
    # Pads the last (time) dimension of a 4-D tensor by repeating its final
    # value `padding[-1]` times.
    def __init__(self, padding) -> None:
        super(ReplicationPad1d, self).__init__()
        self.padding = padding

    def forward(self, input: Tensor) -> Tensor:
        replicate_padding = input[:, :, :, -1].unsqueeze(-1).repeat(1, 1, 1, self.padding[-1])
        output = torch.cat([input, replicate_padding], dim=-1)
        return output


class TokenEmbedding(nn.Module):
    # Embeds patched series into d_model channels: a circular Conv1d over the
    # patch-length axis followed by a linear "confusion" layer that fuses the
    # feature*patch_num axis down to a single value per channel.
    def __init__(self, c_in, d_model, patch_num, input_dim):
        super(TokenEmbedding, self).__init__()
        padding = 1
        self.tokenConv = nn.Conv1d(in_channels=c_in, out_channels=d_model,
                                   kernel_size=3, padding=padding,
                                   padding_mode='circular', bias=False)
        self.confusion_layer = nn.Linear(patch_num * input_dim, 1)
        # if air_quality
        # self.confusion_layer = nn.Linear(42, 1)
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_in', nonlinearity='leaky_relu')

    def forward(self, x):
        b, n, m, pn, pl = x.shape  # batch, node, feature, patch_num, patch_len
        # 768,64,25
        # NOTE: the reshape below flattens (m, pn, pl) directly into (pl, m*pn)
        # without a permute, i.e. the axes are reinterpreted rather than transposed.
        x = self.tokenConv(x.reshape(b * n, pl, m * pn))  # batch*node, patch_len, feature*patch_num
        x = self.confusion_layer(x)
        return x.reshape(b, n, -1)


class PatchEmbedding(nn.Module):
    # Splits each series into patches along the time axis (after replication
    # padding by `stride`) and embeds them into one d_model token per node.
    def __init__(self, d_model, patch_len, stride, dropout, patch_num, input_dim):
        super(PatchEmbedding, self).__init__()
        # Patching
        self.patch_len = patch_len
        self.stride = stride
        self.padding_patch_layer = ReplicationPad1d((0, stride))
        self.value_embedding = TokenEmbedding(patch_len, d_model, patch_num, input_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        n_vars = x.shape[2]
        x = self.padding_patch_layer(x)
        x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride)
        x_value_embed = self.value_embedding(x)
        return self.dropout(x_value_embed), n_vars


class ReprogrammingLayer(nn.Module):
    # Multi-head cross-attention that maps target embeddings (width d_model)
    # onto a bank of source embeddings (width d_llm); the output lives in d_llm.
    def __init__(self, d_model, n_heads, d_keys=None, d_llm=None, attention_dropout=0.1):
        super(ReprogrammingLayer, self).__init__()
        d_keys = d_keys or (d_model // n_heads)

        self.query_projection = nn.Linear(d_model, d_keys * n_heads)
        self.key_projection = nn.Linear(d_llm, d_keys * n_heads)
        self.value_projection = nn.Linear(d_llm, d_keys * n_heads)
        self.out_projection = nn.Linear(d_keys * n_heads, d_llm)
        self.n_heads = n_heads
        self.dropout = nn.Dropout(attention_dropout)

    def forward(self, target_embedding, source_embedding, value_embedding):
        B, L, _ = target_embedding.shape
        S, _ = source_embedding.shape
        H = self.n_heads

        target_embedding = self.query_projection(target_embedding).view(B, L, H, -1)
        source_embedding = self.key_projection(source_embedding).view(S, H, -1)
        value_embedding = self.value_projection(value_embedding).view(S, H, -1)

        out = self.reprogramming(target_embedding, source_embedding, value_embedding)
        out = out.reshape(B, L, -1)
        return self.out_projection(out)

    def reprogramming(self, target_embedding, source_embedding, value_embedding):
        B, L, H, E = target_embedding.shape

        scale = 1. / sqrt(E)
        scores = torch.einsum("blhe,she->bhls", target_embedding, source_embedding)
        A = self.dropout(torch.softmax(scale * scores, dim=-1))
        reprogramming_embedding = torch.einsum("bhls,she->blhe", A, value_embedding)

        return reprogramming_embedding
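

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch showing how the modules above fit together. All
# values below (batch=2, nodes=4, features=3, seq_len=24, patch_len=6,
# stride=6, d_model=16, n_heads=4, d_llm=32, num_prototypes=10) are assumed
# for illustration only and are not taken from the original configuration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    b, n, m, seq_len = 2, 4, 3, 24              # batch, nodes, features, time steps
    patch_len, stride, d_model = 6, 6, 16
    # Number of patches produced after replication padding by `stride`:
    patch_num = (seq_len + stride - patch_len) // stride + 1   # = 5 here

    patch_embed = PatchEmbedding(d_model, patch_len, stride,
                                 dropout=0.1, patch_num=patch_num, input_dim=m)
    x = torch.randn(b, n, m, seq_len)
    tokens, n_vars = patch_embed(x)             # tokens: (b, n, d_model)
    print(tokens.shape, n_vars)                 # torch.Size([2, 4, 16]) 3

    # Reprogram the patch tokens against a hypothetical bank of source
    # embeddings (e.g. text prototypes of width d_llm).
    d_llm, num_prototypes = 32, 10
    reprogram = ReprogrammingLayer(d_model, n_heads=4, d_llm=d_llm)
    source = torch.randn(num_prototypes, d_llm)
    out = reprogram(tokens, source, source)     # (b, n, d_llm)
    print(out.shape)                            # torch.Size([2, 4, 32])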