TrafficWheel/model/STEP/similarity.py

22 lines
738 B
Python

import math
import torch
def batch_cosine_similarity(x, y, *, eps=1e-7):
    """Return the pairwise cosine-similarity matrix between the rows of x and y.

    Any number of leading batch dimensions is accepted (the axes are addressed
    from the end). For 3-D inputs the result is the same as before.

    Args:
        x: Tensor of shape (..., N, D). The original author's comment describes
            3-D inputs as num_heads x batch_size x hidden_dim — confirm with callers.
        y: Tensor of shape (..., M, D); leading dims must broadcast with x's.
        eps: Added to every L2 norm, so an all-zero row yields similarity 0
            instead of NaN (0 / 0).

    Returns:
        Tensor of shape (..., N, M) whose entry [..., i, j] is
        <x_i, y_j> / ((||x_i|| + eps) * (||y_j|| + eps)).
    """
    # Denominator: outer product of per-row L2 norms, (..., N, 1) * (..., 1, M).
    norm_x = torch.norm(x, dim=-1, p=2) + eps
    norm_y = torch.norm(y, dim=-1, p=2) + eps
    denom = norm_x.unsqueeze(-1) * norm_y.unsqueeze(-2)
    # Numerator: all pairwise dot products, (..., N, M).
    numer = torch.matmul(x, y.transpose(-1, -2))
    return numer / denom
def batch_dot_similarity(x, y):
    """Return softmax-normalised scaled dot-product similarities.

    Computes softmax(x @ y^T / sqrt(D)) along the last dimension. This uses
    torch.matmul rather than torch.bmm, so any number of broadcastable leading
    batch dimensions is accepted. For 3-D inputs the result matches the bmm
    version.

    Args:
        x: Tensor of shape (..., N, D).
        y: Tensor of shape (..., M, D).

    Returns:
        Tensor of shape (..., N, M) whose rows (last dim) each sum to 1.
    """
    # Scale by sqrt of the feature size, as in scaled dot-product attention.
    scale = math.sqrt(x.shape[-1])
    scores = torch.matmul(x, y.transpose(-1, -2)) / scale
    return torch.softmax(scores, dim=-1)