import numpy as np
import torch
from torch.nn import Module


class BasicMFNet(Module):
    """Basic model for the MF (matrix factorization) task.

    Arguments:
        num_user (int): the number of users
        num_item (int): the number of items
        num_hidden (int): the dimension of the embedding vectors
    """

    def __init__(self, num_user, num_item, num_hidden):
        super(BasicMFNet, self).__init__()
        self.num_user, self.num_item = num_user, num_item
        # Sparse embeddings: only the rows used in a batch receive
        # gradient updates.
        self.embed_user = torch.nn.Embedding(num_user, num_hidden, sparse=True)
        self.embed_item = torch.nn.Embedding(num_item, num_hidden, sparse=True)

    def forward(self, indices, ratings):
        device = self.embed_user.weight.device

        # indices is a 2 x batch array-like: row 0 holds user ids,
        # row 1 holds item ids.
        indices = torch.tensor(np.array(indices)).to(device)

        user_embedding = self.embed_user(indices[0])
        item_embedding = self.embed_item(indices[1])

        # The predicted rating is the dot product of the two embeddings.
        pred = (user_embedding * item_embedding).sum(dim=1)

        label = torch.tensor(np.array(ratings)).to(device)

        return pred, label

    def load_state_dict(self, state_dict, strict: bool = True):
        # Re-insert the locally kept embedding so strict loading succeeds.
        state_dict[self.name_reserve + '.weight'] = getattr(
            getattr(self, self.name_reserve), 'weight')
        super().load_state_dict(state_dict, strict)

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        state_dict = super().state_dict(destination, prefix, keep_vars)
        # Mask the reserved embedding so it is not shared.
        del state_dict[self.name_reserve + '.weight']
        return state_dict
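

# name_reserve (defined only on the subclasses below) names the embedding
# table that stays local: state_dict() above removes its weight before the
# parameters leave the model, and load_state_dict() re-inserts the local
# copy so that loading the masked dict back still succeeds with strict=True.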


class VMFNet(BasicMFNet):
    """MF model for vertical federated learning.

    embed_item is reserved: it is excluded from state_dict() and
    therefore never shared.
    """
    name_reserve = "embed_item"


class HMFNet(BasicMFNet):
    """MF model for horizontal federated learning.

    embed_user is reserved: it is excluded from state_dict() and
    therefore stays on the local client.
    """
    name_reserve = "embed_user"
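

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module; the sizes and
    # data below are made up for illustration. It exercises HMFNet and shows
    # that the reserved embedding is excluded from the shared state dict.
    model = HMFNet(num_user=10, num_item=20, num_hidden=8)

    # A batch of 4 ratings: row 0 of indices holds user ids, row 1 item ids.
    indices = [[0, 1, 2, 3], [5, 6, 7, 8]]
    ratings = [4.0, 3.0, 5.0, 1.0]
    pred, label = model(indices, ratings)
    print(pred.shape, label.shape)  # torch.Size([4]) torch.Size([4])

    # 'embed_user.weight' is masked out; only 'embed_item.weight' is shared.
    print(list(model.state_dict().keys()))  # ['embed_item.weight']

    # load_state_dict() re-inserts the local user embedding, so loading the
    # masked dict back succeeds even with strict=True.
    model.load_state_dict(model.state_dict())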