164 lines
5.2 KiB
Python
164 lines
5.2 KiB
Python
from scipy.sparse import csc_matrix
|
|
from scipy.sparse import coo_matrix
|
|
from numpy.random import shuffle
|
|
|
|
import numpy as np
|
|
|
|
import collections
|
|
import importlib
|
|
|
|
# Supported matrix-factorization datasets: lower-cased ``config.data.type``
# value -> name of the dataset class that ``load_mf_dataset`` looks up with
# ``getattr`` on the imported dataset module.
MFDATA_CLASS_DICT = {
    # MovieLens variants
    "vflmovielens1m": "VFLMovieLens1M",
    "vflmovielens10m": "VFLMovieLens10M",
    "hflmovielens1m": "HFLMovieLens1M",
    "hflmovielens10m": "HFLMovieLens10M",
    # Netflix variants
    "vflnetflix": "VFLNetflix",
    "hflnetflix": "HFLNetflix",
}
|
|
|
|
|
|
def load_mf_dataset(config=None, client_cfgs=None):
    """Return the dataset of matrix factorization

    The dataset class is selected by ``config.data.type`` (matched
    case-insensitively against ``MFDATA_CLASS_DICT``), imported lazily and
    instantiated with ``download=True``. ``config`` is then updated in place
    with the ``model.num_user`` / ``model.num_item`` values reported by the
    dataset.

    Format:
        {
            'client_id': {
                'train': DataLoader(),
                'test': DataLoader(),
                'val': DataLoader()
            }
        }

    Args:
        config: configuration object; the function reads ``data.type``,
            ``data.root``, ``data.splits[0]`` and ``federate.client_num``
            and calls ``merge_from_list`` on it. Required, despite the
            ``None`` default kept for signature compatibility.
        client_cfgs: unused by this function; kept for signature
            compatibility.

    Returns:
        tuple: ``(data_dict, config)`` where ``data_dict`` is a
        ``collections.defaultdict(dict)`` holding the entries of
        ``dataset.data`` keyed by client index, and ``config`` is the
        (modified) input configuration.

    Raises:
        ValueError: if ``config`` is ``None``.
        NotImplementedError: if ``config.data.type`` is not a supported
            dataset.
    """
    if config is None:
        # Fail with a clear message instead of an AttributeError on
        # ``None.data`` further down.
        raise ValueError(
            "load_mf_dataset requires a config object, but got None.")

    # Normalise once; used for the support check, the module choice and the
    # class-name lookup.
    data_type = config.data.type.lower()
    if data_type not in MFDATA_CLASS_DICT:
        raise NotImplementedError("Dataset {} is not implemented.".format(
            config.data.type))

    # Dataset: imported lazily so only the selected module is loaded.
    if data_type in ('vflnetflix', 'hflnetflix'):
        mpath = "federatedscope.mf.dataset.netflix"
    else:
        mpath = "federatedscope.mf.dataset.movielens"
    dataset_cls = getattr(importlib.import_module(mpath),
                          MFDATA_CLASS_DICT[data_type])
    dataset = dataset_cls(root=config.data.root,
                          num_client=config.federate.client_num,
                          train_portion=config.data.splits[0],
                          download=True)

    data_dict = collections.defaultdict(dict)
    data_dict.update(dataset.data.items())

    # Modify config
    config.merge_from_list(['model.num_user', dataset.n_user])
    config.merge_from_list(['model.num_item', dataset.n_item])

    return data_dict, config
|
|
|
|
|
|
class MFDataLoader(object):
    """DataLoader for MF dataset

    Iterating over the loader yields mini-batches of the form
    ``((rows, cols), ratings)``: three equally long arrays taken from the
    COO representation of the (optionally trimmed) rating matrix.

    Args:
        data (csc_matrix): sparse MF dataset
        batch_size (int): the size of batch data; must be positive
        shuffle (bool): shuffle the dataset
        drop_last (bool): drop the last batch if True
        theta (int): the maximal number of ratings for each user

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """
    def __init__(self,
                 data: csc_matrix,
                 batch_size: int,
                 shuffle=True,
                 drop_last=False,
                 theta=None):
        super(MFDataLoader, self).__init__()
        if batch_size <= 0:
            # A zero batch size would make ``__len__`` divide by zero and
            # ``__next__`` yield empty batches forever.
            raise ValueError(
                "batch_size must be a positive integer, got {}".format(
                    batch_size))

        self.dataset = self._trim_data(data, theta)
        self.shuffle = shuffle
        self.batch_size = batch_size
        self.drop_last = drop_last

        self.n_row = self.dataset.shape[0]
        self.n_col = self.dataset.shape[1]
        # NOTE(review): ``count_nonzero`` ignores explicitly stored zeros,
        # while batches index the stored entries of the COO view; the two
        # only agree when no explicit zero is stored -- confirm that rating
        # matrices never contain stored zeros.
        self.n_rating = self.dataset.count_nonzero()

        self._idx_samples = None
        self._idx_cur = None
        # (dataset, coo_view) pair filled lazily by ``_coo_view``.
        self._coo_cache = None

        self._reset()

    def _trim_data(self, data, theta=None):
        """Trim rating data by parameter theta (per-user privacy)

        Arguments:
            data (csc_matrix): the dataset
            theta (int): The maximal number of ratings for each user

        Returns:
            ``data`` itself when ``theta`` is ``None`` or not positive;
            otherwise a new ``csc_matrix`` of the same shape in which every
            row (user) keeps at most ``theta`` randomly chosen ratings.
        """
        if theta is None or theta <= 0:
            return data

        # Each user has at most $theta$ items
        coo = data.tocoo()
        user2items = collections.defaultdict(list)
        for idx, user_id in enumerate(coo.row.tolist()):
            user2items[user_id].append(idx)

        # sample theta each; users are visited in order of first appearance
        # so the sequence of random draws is stable for a given seed
        idx_select = []
        for items in user2items.values():
            if len(items) > theta:
                idx_select.extend(
                    np.random.choice(items, theta, replace=False).tolist())
            else:
                idx_select.extend(items)

        # Explicit integer dtype keeps the fancy indexing valid even when
        # nothing was selected (empty matrix).
        keep = np.asarray(idx_select, dtype=np.intp)
        return coo_matrix((coo.data[keep], (coo.row[keep], coo.col[keep])),
                          shape=coo.shape).tocsc()

    def _coo_view(self):
        """Return the COO form of ``self.dataset``, converting only once.

        The conversion costs time proportional to the number of stored
        entries, so it is cached instead of being redone for every batch.
        The cache is keyed on the identity of ``self.dataset`` and is rebuilt
        if that attribute is rebound to another matrix.
        """
        cache = getattr(self, "_coo_cache", None)
        if cache is None or cache[0] is not self.dataset:
            # ``copy=False`` lets the view share buffers with the dataset;
            # batches are built by fancy indexing, which copies.
            cache = (self.dataset, self.dataset.tocoo(copy=False))
            self._coo_cache = cache
        return cache[1]

    def _reset(self):
        """Rewind to the first batch and reshuffle if shuffling is enabled."""
        self._idx_cur = 0
        if self._idx_samples is None:
            self._idx_samples = np.arange(self.n_rating)
        if self.shuffle:
            # In-place shuffle with numpy's global random state.
            np.random.shuffle(self._idx_samples)

    def _sample_data(self, sampled_rating_idx):
        """Gather the ratings at the given positions of the COO view.

        Args:
            sampled_rating_idx: integer index array into the stored entries.

        Returns:
            tuple: ``((rows, cols), data)`` arrays for the selected entries.
        """
        coo = self._coo_view()
        rows = coo.row[sampled_rating_idx]
        cols = coo.col[sampled_rating_idx]
        data = coo.data[sampled_rating_idx]
        return (rows, cols), data

    def __len__(self):
        """The number of batches within an epoch

        """
        n_full, remainder = divmod(int(self.n_rating), self.batch_size)
        if self.drop_last or remainder == 0:
            return int(n_full)
        # One extra, shorter batch holds the remaining ratings.
        return int(n_full) + 1

    def __next__(self, theta=None):
        """Get the next batch of data

        Args:
            theta (int): the maximal number of ratings for each user
                (not used by this method; kept for interface compatibility)

        Returns:
            tuple: ``((rows, cols), data)`` for the next batch.

        Raises:
            StopIteration: when all samples were consumed, or when
                ``drop_last`` is set and only a partial batch remains.
        """
        n_samples = len(self._idx_samples)
        idx_start = self._idx_cur
        idx_end = idx_start + self.batch_size
        if idx_start >= n_samples or (self.drop_last
                                      and idx_end > n_samples):
            raise StopIteration

        idx_end = min(idx_end, n_samples)
        self._idx_cur = idx_end

        return self._sample_data(self._idx_samples[idx_start:idx_end])

    def __iter__(self):
        """Start a new epoch (rewind, reshuffle if enabled); return self."""
        self._reset()
        return self
|