# TrafficWheel/model/ST_SSL/aug.py
#
# 104 lines
# 4.1 KiB
# Python

import copy
import numpy as np
import torch
def sim_global(flow_data, sim_type='cos'):
    """Calculate the global (node-to-node) similarity of traffic flow data.

    :param flow_data: tensor, original flow [n,l,v,c] or location embedding [n,v,c]
    :param sim_type: str, type of similarity, attention or cosine. ['att', 'cos']
    :return sim: tensor, node similarity, [v,v]. For 'cos' it is the cosine
        similarity (symmetric; nodes whose features are all zero get similarity 0
        instead of NaN). For 'att' it is the scaled dot product passed through a
        softmax over the last dimension, so every row sums to 1.
    :raises ValueError: if flow_data is neither 3-D nor 4-D, or sim_type is unknown.
    """
    if flow_data.dim() == 4:
        n, l, _, c = flow_data.shape
        att_scaling = n * l * c
        node_dims = (0, 1, 3)  # every axis except the node axis
        sim = torch.einsum('btnc, btmc->nm', flow_data, flow_data)
    elif flow_data.dim() == 3:
        n, _, c = flow_data.shape
        att_scaling = n * c
        node_dims = (0, 2)  # every axis except the node axis
        sim = torch.einsum('bnc, bmc->nm', flow_data, flow_data)
    else:
        raise ValueError('sim_global only support shape length in [3, 4] but got {}.'.format(len(flow_data.shape)))

    if sim_type == 'cos':
        # cosine similarity: divide the dot products by the 2-norm of each node
        node_norm = torch.norm(flow_data, p=2, dim=node_dims)
        # A zero-norm node would give 1/0 = inf and then 0 * inf = NaN below.
        # Use an inverse norm of 0 for such nodes so their similarity is 0.
        is_zero = node_norm == 0
        inv_norm = node_norm.masked_fill(is_zero, 1.0) ** -1
        inv_norm = inv_norm.masked_fill(is_zero, 0.0)
        scaling = torch.einsum('i, j->ij', inv_norm, inv_norm)
        sim = sim * scaling
    elif sim_type == 'att':
        # scaled dot product similarity
        scaling = float(att_scaling) ** -0.5
        sim = torch.softmax(sim * scaling, dim=-1)
    else:
        raise ValueError('sim_global only support sim_type in [att, cos].')
    return sim
def aug_topology(sim_mx, input_graph, percent=0.2):
    """Generate the data augmentation from topology (graph structure) perspective
    for undirected graph without self-loop.

    First, existing edges are sampled for removal with probability proportional
    to ``1 - softmax(similarity)`` taken over the existing edges. Then node pairs
    are sampled for connection with probability ``softmax(similarity)`` taken over
    all node pairs. Each step draws ``int(edge_num * percent / 4)`` indices with
    replacement (``edge_num`` = number of undirected edges), so the number of
    distinct edges changed can be smaller than that. Every change is mirrored to
    keep the matrix symmetric. The input graph itself is not modified.

    :param sim_mx: tensor, symmetric similarity, [v,v]
    :param input_graph: tensor, adjacency matrix without self-loop, [v,v]
    :param percent: float, controls how many edges are sampled (see above)
    :return aug_graph: tensor, augmented adjacency matrix, [v,v]
    """
    aug_graph = copy.deepcopy(input_graph)
    # Sampling probabilities are plain numbers: no gradient is needed for them.
    sim = sim_mx.detach()

    # The strict lower triangle lists every undirected edge exactly once.
    edge_mask = (input_graph > 0).tril(diagonal=-1)
    # Row-major [row_idx, col_idx] list, i.e. the same order in which
    # sim[edge_mask] enumerates the edges. Sampled positions must index THIS
    # list so that each probability is applied to the edge it was computed for.
    edge_index = edge_mask.nonzero()
    edge_num = edge_index.shape[0]

    drop_percent = percent / 2
    add_drop_num = int(edge_num * drop_percent / 2)
    if add_drop_num == 0:
        # nothing to sample: return the untouched copy
        return aug_graph

    ## edge dropping starts here
    edge_sim = sim[edge_mask.to(sim.device)].double()
    # the less similar the two end nodes, the more likely the edge is dropped
    drop_prob = (1. - torch.softmax(edge_sim, dim=0)).cpu().numpy()
    drop_total = drop_prob.sum()
    if drop_total == 0:
        # e.g. a single edge: softmax is exactly 1, so fall back to uniform
        drop_prob = np.full(edge_num, 1. / edge_num)
    else:
        drop_prob /= drop_total
    drop_list = np.random.choice(edge_num, size=add_drop_num, p=drop_prob)
    drop_index = edge_index[torch.as_tensor(drop_list, dtype=torch.long, device=edge_index.device)]
    aug_graph[drop_index[:, 0], drop_index[:, 1]] = 0
    aug_graph[drop_index[:, 1], drop_index[:, 0]] = 0

    ## edge adding starts here
    pair_mask = torch.ones(sim.size(), dtype=torch.bool, device=sim.device).tril(diagonal=-1)
    pair_index = pair_mask.nonzero()  # all node pairs (row > col), row-major like sim[pair_mask]
    # the more similar two nodes, the more likely they get connected
    add_prob = torch.softmax(sim[pair_mask].double(), dim=0).cpu().numpy()
    add_prob /= add_prob.sum()  # remove floating point drift before sampling
    add_list = np.random.choice(pair_index.shape[0], size=add_drop_num, p=add_prob)
    add_index = pair_index[torch.as_tensor(add_list, dtype=torch.long, device=pair_index.device)]
    add_index = add_index.to(aug_graph.device)
    aug_graph[add_index[:, 0], add_index[:, 1]] = 1
    aug_graph[add_index[:, 1], add_index[:, 0]] = 1
    return aug_graph
def aug_traffic(t_sim_mx, flow_data, percent=0.2):
    """Generate the data augmentation from traffic (node attribute) perspective.

    Positions (sample, time step, node) are sampled with probability
    proportional to ``1 - similarity`` and all channels of the sampled positions
    are set to zero. ``int(n * l * v * percent)`` positions are drawn with
    replacement, so the number of distinct masked positions can be smaller.
    The input flow itself is not modified.

    :param t_sim_mx: temporal similarity matrix after softmax, [l,n,v]
    :param flow_data: input flow data, [n,l,v,c]
    :param percent: float, fraction of the n*l*v positions to draw for masking
    :return aug_flow: masked copy of flow_data, [n,l,v,c]
    """
    l, n, v = t_sim_mx.shape
    mask_num = int(n * l * v * percent)
    aug_flow = copy.deepcopy(flow_data)
    if mask_num == 0:
        # nothing to sample: return the untouched copy
        return aug_flow

    # Reorder the similarity to [n,l,v] so that its flattened order matches the
    # leading axes of flow_data; detach because only plain numbers are needed.
    mask_prob = 1. - t_sim_mx.detach().permute(1, 0, 2).reshape(-1)
    mask_prob = mask_prob.double().cpu().numpy()
    prob_total = mask_prob.sum()
    if prob_total == 0:
        # every similarity is exactly 1: fall back to uniform sampling
        mask_prob = np.full(n * l * v, 1. / (n * l * v))
    else:
        mask_prob /= prob_total
    mask_list = np.random.choice(n * l * v, size=mask_num, p=mask_prob)
    # decode the flat positions back into (sample, time step, node) indices
    n_idx, l_idx, v_idx = np.unravel_index(mask_list, (n, l, v))
    aug_flow[n_idx, l_idx, v_idx] = 0
    return aug_flow