195 lines
6.9 KiB
Python
195 lines
6.9 KiB
Python
"""The function partition_by_category and subgraphing are borrowed from
|
|
https://github.com/FedML-AI/FedGraphNN
|
|
|
|
Copyright [FedML] [Chaoyang He, Salman Avestimehr]
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
"""
|
|
|
|
import os
|
|
import os.path as osp
|
|
import networkx as nx
|
|
|
|
import torch
|
|
from torch_geometric.data import InMemoryDataset, download_url, Data
|
|
from torch_geometric.utils import from_networkx
|
|
|
|
|
|
# RecSys
|
|
def read_mapping(path, filename):
    """Read a two-column whitespace-separated id-mapping file.

    Each non-empty line is ``<src> <dst>``; extra columns are ignored.

    Args:
        path: Directory containing the mapping file.
        filename: Name of the mapping file inside ``path``.

    Returns:
        dict mapping ``int(src) -> int(dst)`` for every line.
    """
    with open(os.path.join(path, filename)) as src:
        fields = (line.split() for line in src)
        return {int(cols[0]): int(cols[1]) for cols in fields}
|
|
|
|
|
|
def partition_by_category(graph, mapping_item2category):
    """Assign each node the categories it is associated with.

    Every item node gets its own category; every neighbor of an item
    additionally collects that item's category.

    NOTE: an item appearing first as a neighbor and later as a key has its
    accumulated list *replaced* by its own category — this matches the
    original iteration-order-dependent behavior and is kept intentionally.

    Args:
        graph: Object exposing ``neighbors(node)`` (e.g. an ``nx.Graph``).
        mapping_item2category: dict ``item_node -> category_id``.

    Returns:
        dict ``node -> list of category ids``.
    """
    partition = {}
    for item, category in mapping_item2category.items():
        # Overwrite (not extend): the item's own category wins.
        partition[item] = [category]
        for neighbor in graph.neighbors(item):
            partition.setdefault(neighbor, []).append(category)
    return partition
|
|
|
|
|
|
def subgraphing(g, partion, mapping_item2category):
    """Split graph ``g`` into one PyG subgraph per item category.

    Args:
        g: Source ``nx.Graph``.
        partion: dict ``node -> list of category ids`` as produced by
            :func:`partition_by_category`. (Name kept for API compatibility.)
        mapping_item2category: dict ``item_node -> category_id``; its values
            define the set of category buckets.

    Returns:
        list of ``torch_geometric.data.Data`` objects, one per category that
        contains at least two nodes.
    """
    # Bug fix: the bucket list was previously sized by the number of item
    # *keys* rather than by the category ids themselves, which raises
    # IndexError whenever a category id >= number of items and over-allocates
    # otherwise. Size it to cover every category id actually used.
    num_categories = max(mapping_item2category.values()) + 1
    nodelist = [[] for _ in range(num_categories)]
    for node, categories in partion.items():
        for category in categories:
            nodelist[category].append(node)

    graphs = []
    for nodes in nodelist:
        # Single-node (or empty) categories cannot form an edge; skip them.
        if len(nodes) < 2:
            continue
        graphs.append(from_networkx(nx.subgraph(g, nodes)))
    return graphs
|
|
|
|
|
|
def read_RecSys(path, FL=False):
    """Load a recommender-system bipartite graph from raw dict/edge files.

    Reads ``user.dict``, ``item.dict`` and ``graph.txt`` under ``path`` and
    builds a rating graph whose edges carry an ``edge_type`` attribute.

    Args:
        path: Directory holding the raw files.
        FL: If True, partition the graph by item category (``category.dict``)
            and return one subgraph per category; otherwise return the whole
            graph as a single element.

    Returns:
        list of ``torch_geometric.data.Data`` graphs.
    """
    mapping_user = read_mapping(path, 'user.dict')
    mapping_item = read_mapping(path, 'item.dict')

    graph = nx.Graph()
    with open(osp.join(path, 'graph.txt')) as edges:
        for raw in edges:
            cols = [int(tok) for tok in raw.split()]
            # cols: user id, item id, rating (stored as edge_type).
            graph.add_edge(mapping_user[cols[0]],
                           mapping_item[cols[1]],
                           edge_type=cols[2])

    # Remember each node's original id before any re-ordering.
    nx.set_node_attributes(graph, {n: n for n in graph.nodes}, "index_orig")

    # Rebuild the graph with nodes inserted in sorted order so that node
    # indices are deterministic.
    ordered = nx.Graph()
    ordered.add_nodes_from(sorted(graph.nodes(data=True)))
    ordered.add_edges_from(graph.edges(data=True))

    if not FL:
        return [from_networkx(ordered)]

    mapping_item2category = read_mapping(path, "category.dict")
    partition = partition_by_category(ordered, mapping_item2category)
    return subgraphing(ordered, partition, mapping_item2category)
|
|
|
|
|
|
class RecSys(InMemoryDataset):
    r"""Recommender-system graph dataset from FedGraphNN.

    Arguments:
        root (string): Root directory where the dataset should be saved.
        name (string): The name of the dataset (:obj:`"epinions"`,
            :obj:`"ciao"`).
        FL (Bool): Federated setting or centralized setting.
        splits (sequence): Train/valid/test fractions applied to the edges
            of each graph. (default: :obj:`(0.8, 0.1, 0.1)`)
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
    """
    def __init__(self,
                 root,
                 name,
                 FL=False,
                 splits=(0.8, 0.1, 0.1),  # tuple: avoids mutable-default bug
                 transform=None,
                 pre_transform=None):
        self.FL = FL
        # Prefix the dataset name so federated and centralized copies are
        # cached in separate directories.
        self.name = ('FL' + name) if FL else name
        # Copy to a list so a caller-supplied sequence is never aliased.
        self._customized_splits = list(splits)
        super().__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        # Raw files expected in `self.raw_dir` (downloaded if missing).
        names = ['user.dict', 'item.dict', 'category.dict', 'graph.txt']
        return names

    @property
    def processed_file_names(self):
        return ['data.pt']

    @property
    def raw_dir(self):
        return osp.join(self.root, self.name, 'raw')

    @property
    def processed_dir(self):
        return osp.join(self.root, self.name, 'processed')

    def download(self):
        """
        Download raw files to `self.raw_dir` from FedGraphNN.
        Paper: https://arxiv.org/abs/2104.07145
        Repo: https://github.com/FedML-AI/FedGraphNN
        """
        url = 'https://raw.githubusercontent.com/FedML-AI/FedGraphNN' \
              '/82912342950e0cd1be2b683e48ef8bfd5cb0a276/data' \
              '/recommender_system/'
        # Strip the 'FL' prefix added in __init__; the upstream repo stores
        # both settings under the plain dataset name.
        if self.name.startswith('FL'):
            suffix = self.name[2:]
        else:
            suffix = self.name
        # Bug fix: URLs must be joined with '/', not osp.join, which would
        # insert a backslash on Windows.
        url = f'{url}{suffix}'
        for name in self.raw_file_names:
            download_url(f'{url}/{name}', self.raw_dir)

    def process(self):
        # Read data into huge `Data` list (one graph, or one per category).
        data_list = read_RecSys(self.raw_dir, self.FL)

        data_list_w_masks = []
        for data in data_list:
            if self.name.endswith('epinions'):
                # epinions ratings are 1-based; shift to 0-based labels.
                data.edge_type = data.edge_type - 1
            # Graphs without edges cannot be split; drop them.
            if data.num_edges == 0:
                continue
            # Random edge permutation drives the train/valid/test split.
            perm = torch.randperm(data.num_edges)
            n_train = round(self._customized_splits[0] * data.num_edges)
            n_valid = round((self._customized_splits[0] +
                             self._customized_splits[1]) * data.num_edges)
            data.train_edge_mask = torch.zeros(data.num_edges,
                                               dtype=torch.bool)
            data.train_edge_mask[perm[:n_train]] = True
            data.valid_edge_mask = torch.zeros(data.num_edges,
                                               dtype=torch.bool)
            data.valid_edge_mask[perm[n_train:n_valid]] = True
            data.test_edge_mask = torch.zeros(data.num_edges,
                                              dtype=torch.bool)
            data.test_edge_mask[perm[n_valid:]] = True
            data_list_w_masks.append(data)
        data_list = data_list_w_masks

        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])
|