22 lines
738 B
Python
22 lines
738 B
Python
import math
|
|
|
|
import torch
|
|
|
|
|
|
def batch_cosine_similarity(x, y):
    """Return the pairwise cosine-similarity matrix between rows of ``x`` and ``y``.

    Similarity is taken along the last axis. All leading axes are treated as
    batch axes and broadcast as in ``torch.matmul``. Supported inputs:

    * 3-D: ``x`` is ``(B, N, D)`` and ``y`` is ``(B, M, D)``, giving ``(B, N, M)``.
    * 2-D (unbatched): works the same way.
    * Higher rank (e.g. an extra leading heads axis): works the same way.

    Args:
        x: Tensor of shape ``(..., N, D)``.
        y: Tensor of shape ``(..., M, D)``.

    Returns:
        Tensor of shape ``(..., N, M)``. Entry ``[..., i, j]`` equals
        ``<x_i, y_j> / ((||x_i|| + eps) * (||y_j|| + eps))`` with ``eps = 1e-7``.
    """
    # Added to each norm (not clamped), so an all-zero row yields 0 instead of NaN.
    eps = 1e-7

    # Denominator: outer product of the L2 norms,
    # (..., N, 1) * (..., 1, M) -> (..., N, M).
    x_norm = torch.norm(x, p=2, dim=-1) + eps
    y_norm = torch.norm(y, p=2, dim=-1) + eps
    denom = x_norm.unsqueeze(-1) * y_norm.unsqueeze(-2)

    # Numerator: all pairwise dot products, (..., N, D) @ (..., D, M).
    numer = torch.matmul(x, y.transpose(-1, -2))

    # Cosine-similarity affinity matrix.
    return numer / denom
|
|
|
|
def batch_dot_similarity(x, y):
    """Return softmax-normalised scaled dot-product weights between ``x`` and ``y``.

    Computes ``softmax(x @ y^T / sqrt(D), dim=-1)``, where ``D`` is the size of
    the last axis of ``x``. Leading axes are batch axes and broadcast as in
    ``torch.matmul``. Supported inputs:

    * 3-D: ``(B, N, D)`` and ``(B, M, D)``.
    * 2-D (unbatched): works the same way.
    * Higher rank (e.g. an extra heads axis): works the same way.

    Args:
        x: Tensor of shape ``(..., N, D)``.
        y: Tensor of shape ``(..., M, D)``.

    Returns:
        Tensor of shape ``(..., N, M)``. Each slice along the last axis sums to 1.
    """
    # Same 1/sqrt(D) scaling as in scaled dot-product attention.
    scale = math.sqrt(x.shape[-1])
    scores = torch.matmul(x, y.transpose(-1, -2)) / scale
    return torch.softmax(scores, dim=-1)
|