import logging
from typing import Type

import numpy as np
import torch

from federatedscope.mf.trainer.trainer import MFTrainer
from federatedscope.core.trainers.utils import get_random

logger = logging.getLogger(__name__)

# Modifications:
#
# 1. Fix issue where `embed_user.shape` is deprecated; use
#    `embed_user.weight.shape` instead (Line 91). It may be caused by a
#    `torch.nn.Embedding` update.
#    (2024-9-8, czzhangheng)
def wrap_MFTrainer(base_trainer: Type[MFTrainer]) -> Type[MFTrainer]:
    """Build an `SGDMFTrainer` in a plug-in manner by registering new
    functions into the given `MFTrainer`.
    """
    # ---------------- attribute-level plug-in -----------------------
    init_sgdmf_ctx(base_trainer)

    # ---------------- action-level plug-in -----------------------
    base_trainer.replace_hook_in_train(
        new_hook=hook_on_batch_backward,
        target_trigger="on_batch_backward",
        target_hook_name="_hook_on_batch_backward")

    return base_trainer
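# Usage sketch (illustrative, not part of the original module): assuming an
# `MFTrainer` has already been constructed elsewhere, the private SGDMF
# behaviour is enabled by wrapping it. The constructor arguments shown below
# are assumptions for illustration only.
#
#     trainer = MFTrainer(model=model, data=data, device=device, config=cfg)
#     trainer = wrap_MFTrainer(trainer)  # per-batch updates are now noisy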
def init_sgdmf_ctx(base_trainer):
    """Init the attributes required by SGDMF; some new attributes are given
    an `sgdmf` prefix to avoid namespace pollution.
    """
    ctx = base_trainer.ctx
    cfg = base_trainer.cfg

    # Sampling ratio q = batch_size / num_user
    sample_ratio = float(cfg.dataloader.batch_size) / cfg.model.num_user
    # Noise multiplier: sigma^2 = constant * q^2 * T * log(1/delta) / eps^2,
    # where T is the total number of local batches over all rounds
    tmp = cfg.sgdmf.constant * np.power(sample_ratio, 2) * (
        cfg.federate.total_round_num * ctx.num_total_train_batch) * np.log(
            1. / cfg.sgdmf.delta)
    noise_multiplier = np.sqrt(tmp / np.power(cfg.sgdmf.epsilon, 2))
    # The final noise scale is additionally enlarged by theta and R^{1.5}
    ctx.scale = max(cfg.dataloader.theta, 1.) * noise_multiplier * np.power(
        cfg.sgdmf.R, 1.5)
    logger.info("Inject noise: (loc=0, scale={})".format(ctx.scale))
    ctx.sgdmf_R = cfg.sgdmf.R
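# Illustrative configuration sketch (key names follow the cfg fields read
# above; the values are made up). Such a config calibrates the noise scale
# computed in `init_sgdmf_ctx`:
#
#     sgdmf:
#       R: 5            # upper bound of ratings
#       epsilon: 4.     # privacy budget
#       delta: 0.5
#       constant: 1.
#     dataloader:
#       batch_size: 32
#       theta: -1       # values <= 1 leave the scale unchanged
#     federate:
#       total_round_num: 20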
def embedding_clip(param, R: int):
    """Clip an embedding matrix according to the rating bound $R$.

    Arguments:
        param (tensor): The embedding matrix
        R (int): The upper bound of ratings
    """
    # Turn all negative entries into 0
    param.data = (torch.abs(param.data) + param.data) * 0.5
    # Project each row onto the L2 ball of radius sqrt(R)
    norms = torch.linalg.norm(param.data, dim=1)
    threshold = np.sqrt(R)
    param.data[norms > threshold] *= (threshold /
                                      norms[norms > threshold]).reshape(
                                          (-1, 1))
    param.data[param.data < 0] = 0.
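# Behaviour sketch (illustrative, not part of the original module): with
# R = 4, `embedding_clip` zeroes out negative entries and rescales any row
# whose L2 norm exceeds sqrt(4) = 2, e.g.
#
#     emb = torch.nn.Embedding(3, 2)
#     with torch.no_grad():
#         embedding_clip(emb.weight, 4)
#         assert emb.weight.min() >= 0
#         assert torch.linalg.norm(emb.weight, dim=1).max() <= 2 + 1e-6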
def hook_on_batch_backward(ctx):
    """Private local update in SGDMF: inject Gaussian noise into the
    embedding gradients and clip the embeddings after the step.
    """
    ctx.optimizer.zero_grad()
    ctx.loss_task.backward()

    # Densify the (possibly sparse) embedding gradients so that Gaussian
    # noise can be added element-wise
    if ctx.model.embed_user.weight.grad.is_sparse:
        dense_user_grad = ctx.model.embed_user.weight.grad.to_dense()
    else:
        dense_user_grad = ctx.model.embed_user.weight.grad

    if ctx.model.embed_item.weight.grad.is_sparse:
        dense_item_grad = ctx.model.embed_item.weight.grad.to_dense()
    else:
        dense_item_grad = ctx.model.embed_item.weight.grad

    # Inject noise
    dense_user_grad.data += get_random(
        "Normal",
        sample_shape=ctx.model.embed_user.weight.shape,
        params={
            "loc": 0,
            "scale": ctx.scale
        },
        device=ctx.model.embed_user.weight.device)
    dense_item_grad.data += get_random(
        "Normal",
        sample_shape=ctx.model.embed_item.weight.shape,
        params={
            "loc": 0,
            "scale": ctx.scale
        },
        device=ctx.model.embed_item.weight.device)

    # Convert the gradients back to sparse before the optimizer step
    ctx.model.embed_user.weight.grad = dense_user_grad.to_sparse()
    ctx.model.embed_item.weight.grad = dense_item_grad.to_sparse()
    ctx.optimizer.step()

    # Embedding clipping
    with torch.no_grad():
        embedding_clip(ctx.model.embed_user.weight, ctx.sgdmf_R)
        embedding_clip(ctx.model.embed_item.weight, ctx.sgdmf_R)