"""Utils from: https://github.com/Oxfordblue7/GCFL"""

import torch
import numpy as np
import networkx as nx
from dtaidistance import dtw


def norm(w):
    """Return the norm of all tensors in ``w`` treated as one flat vector.

    Args:
        w: Mapping whose values are tensors (e.g. a state dict or a dict of
            gradients). Keys are ignored.

    Returns:
        The ``torch.norm`` of the concatenated, flattened values as a Python
        float.
    """
    flat = torch.cat([v.flatten() for v in w.values()])
    return torch.norm(flat).item()


def compute_pairwise_distances(seqs, standardize=False):
    """Compute the pairwise DTW distances between sequences (used by GCFL+).

    Args:
        seqs: Collection of sequences accepted by ``dtw.distance_matrix``.
            When ``standardize`` is true it must be convertible to a 2-D
            array with one sequence per row.
        standardize: If true, divide every sequence by its own standard
            deviation first so that only the trends are compared.

    Returns:
        The distance matrix produced by ``dtw.distance_matrix``.
    """
    if standardize:
        seqs = np.array(seqs)
        scale = seqs.std(axis=1).reshape(-1, 1)
        # A constant sequence has zero spread; dividing by it would fill the
        # row with inf/NaN and poison every distance involving it. Leave such
        # rows unscaled instead.
        scale = np.where(scale == 0, 1.0, scale)
        seqs = seqs / scale
    return dtw.distance_matrix(seqs)


def min_cut(similarity, cluster):
    """Split ``cluster`` in two along the minimum cut of its similarity graph.

    Args:
        similarity: Square, indexable matrix where ``similarity[i][j]`` is the
            edge weight between members ``i`` and ``j``.
        cluster: Sequence mapping each row/column index of ``similarity`` to
            the identifier of the corresponding member.

    Returns:
        A pair of numpy arrays holding the members of ``cluster`` on each
        side of the Stoer-Wagner minimum cut.
    """
    n = len(similarity)
    g = nx.Graph()
    # Insert every (i, j) pair in row-major order, exactly like a nested
    # loop would: for a repeated undirected edge the later write wins.
    g.add_weighted_edges_from(
        (i, j, similarity[i][j]) for i in range(n) for j in range(n)
    )
    _cut_value, (side_a, side_b) = nx.stoer_wagner(g)
    c1 = np.array([cluster[x] for x in side_a])
    c2 = np.array([cluster[x] for x in side_b])
    return c1, c2