FS-TFP/federatedscope/mf/trainer/trainer_sgdmf.py

import logging
from federatedscope.mf.trainer.trainer import MFTrainer
from federatedscope.core.trainers.utils import get_random
from typing import Type
import numpy as np
import torch
logger = logging.getLogger(__name__)

# Modifications:
#
# 1. Fix the issue that `embed_user.shape` is deprecated; use
#    `embed_user.weight.shape` instead (line 91). This may have been caused
#    by an update to `torch.nn.Embedding`.
#    (2024-9-8, czzhangheng)

def wrap_MFTrainer(base_trainer: Type[MFTrainer]) -> Type[MFTrainer]:
    """Build `SGDMFTrainer` in a plug-in manner by registering new
    functions into the given `MFTrainer`.
    """
    # ---------------- attribute-level plug-in -----------------------
    init_sgdmf_ctx(base_trainer)

    # ---------------- action-level plug-in -----------------------
    base_trainer.replace_hook_in_train(
        new_hook=hook_on_batch_backward,
        target_trigger="on_batch_backward",
        target_hook_name="_hook_on_batch_backward")

    return base_trainer
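

# A minimal usage sketch (illustrative only): `my_model`, `my_data`,
# `my_device` and `my_cfg` are hypothetical placeholders, and the constructor
# call assumes MFTrainer follows the usual FederatedScope trainer signature
# (model, data, device, config).
#
#   base = MFTrainer(model=my_model, data=my_data, device=my_device,
#                    config=my_cfg)
#   trainer = wrap_MFTrainer(base)
#   # `_hook_on_batch_backward` is now replaced by `hook_on_batch_backward`,
#   # so every local update injects Gaussian noise and clips the embeddings.

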
def init_sgdmf_ctx(base_trainer):
    """Init the attributes used by SGDMF; the new attributes are prefixed
    with `sgdmf` to avoid namespace pollution.
    """
    ctx = base_trainer.ctx
    cfg = base_trainer.cfg

    sample_ratio = float(cfg.dataloader.batch_size) / cfg.model.num_user
    # Noise multiplier
    tmp = cfg.sgdmf.constant * np.power(sample_ratio, 2) * (
        cfg.federate.total_round_num * ctx.num_total_train_batch) * np.log(
            1. / cfg.sgdmf.delta)
    noise_multiplier = np.sqrt(tmp / np.power(cfg.sgdmf.epsilon, 2))
    ctx.scale = max(cfg.dataloader.theta, 1.) * noise_multiplier * np.power(
        cfg.sgdmf.R, 1.5)
    logger.info("Inject noise: (loc=0, scale={})".format(ctx.scale))
    ctx.sgdmf_R = cfg.sgdmf.R
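    # For reference, the scale computed above is
    #     sigma     = sqrt(constant * q^2 * T * log(1 / delta) / epsilon^2)
    #     ctx.scale = max(theta, 1) * sigma * R^(3/2)
    # where q = batch_size / num_user is the sampling ratio and
    # T = total_round_num * num_total_train_batch is the total number of
    # local updates. (This only restates the computation above.)

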
def embedding_clip(param, R: int):
    """Clip the embedding matrix according to $R$.

    Arguments:
        param (tensor): The embedding matrix
        R (int): The upper bound of the ratings
    """
    # Turn all negative entries into 0
    param.data = (torch.abs(param.data) + param.data) * 0.5
    # Clip each row so that its l2-norm does not exceed sqrt(R)
    norms = torch.linalg.norm(param.data, dim=1)
    threshold = np.sqrt(R)
    param.data[norms > threshold] *= (threshold /
                                      norms[norms > threshold]).reshape(
                                          (-1, 1))
    param.data[param.data < 0] = 0.
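
    # Illustrative example (a sketch, not part of the original code): with
    # R = 4 the threshold is sqrt(4) = 2, so a row [3., -1., 0.] first becomes
    # non-negative, [3., 0., 0.], and is then rescaled to norm 2, [2., 0., 0.].

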
def hook_on_batch_backward(ctx):
    """Private local updates in SGDMF.
    """
    ctx.optimizer.zero_grad()
    ctx.loss_task.backward()
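
    # The gradients of nn.Embedding are sparse when the layer is created with
    # sparse=True; densify them so that Gaussian noise can be added
    # element-wise, then convert back to sparse before the optimizer step.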
    if ctx.model.embed_user.weight.grad.is_sparse:
        dense_user_grad = ctx.model.embed_user.weight.grad.to_dense()
    else:
        dense_user_grad = ctx.model.embed_user.weight.grad
    if ctx.model.embed_item.weight.grad.is_sparse:
        dense_item_grad = ctx.model.embed_item.weight.grad.to_dense()
    else:
        dense_item_grad = ctx.model.embed_item.weight.grad
    # Inject noise
    dense_user_grad.data += get_random(
        "Normal",
        sample_shape=ctx.model.embed_user.weight.shape,
        params={
            "loc": 0,
            "scale": ctx.scale
        },
        device=ctx.model.embed_user.weight.device)
    dense_item_grad.data += get_random(
        "Normal",
        sample_shape=ctx.model.embed_item.weight.shape,
        params={
            "loc": 0,
            "scale": ctx.scale
        },
        device=ctx.model.embed_item.weight.device)
    ctx.model.embed_user.weight.grad = dense_user_grad.to_sparse()
    ctx.model.embed_item.weight.grad = dense_item_grad.to_sparse()
    ctx.optimizer.step()
    # Embedding clipping
    with torch.no_grad():
        embedding_clip(ctx.model.embed_user.weight, ctx.sgdmf_R)
        embedding_clip(ctx.model.embed_item.weight, ctx.sgdmf_R)
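

# Config keys consumed by this wrapper (values shown are illustrative only;
# the actual defaults live in the project's config definitions):
#
#   cfg.sgdmf.R               # upper bound of the ratings, e.g. 5
#   cfg.sgdmf.epsilon         # DP privacy budget, e.g. 4.
#   cfg.sgdmf.delta           # DP relaxation term, e.g. 0.5
#   cfg.sgdmf.constant        # constant in the noise variance, e.g. 1.
#   cfg.dataloader.batch_size
#   cfg.dataloader.theta      # clipping factor; max(theta, 1.) is used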