FS-TFP/federatedscope/gfl/gcflplus/utils.py

35 lines
966 B
Python

import torch
import numpy as np
import networkx as nx
from dtaidistance import dtw
"""
Utils from: https://github.com/Oxfordblue7/GCFL
"""
def norm(w):
return torch.norm(torch.cat([v.flatten() for v in w.values()])).item()
def compute_pairwise_distances(seqs, standardize=False):
""" computes DTW distances for gcfl+"""
if standardize:
# standardize to only focus on the trends
seqs = np.array(seqs)
seqs = seqs / seqs.std(axis=1).reshape(-1, 1)
distances = dtw.distance_matrix(seqs)
else:
distances = dtw.distance_matrix(seqs)
return distances
def min_cut(similarity, cluster):
g = nx.Graph()
for i in range(len(similarity)):
for j in range(len(similarity)):
g.add_edge(i, j, weight=similarity[i][j])
cut, partition = nx.stoer_wagner(g)
c1 = np.array([cluster[x] for x in partition[0]])
c2 = np.array([cluster[x] for x in partition[1]])
return c1, c2