import logging
from federatedscope.mf.trainer.trainer import MFTrainer
from federatedscope.core.trainers.utils import get_random
from typing import Type

import numpy as np
import torch

logger = logging.getLogger(__name__)

# Modifications:
#
# 1. Fix issue where `embed_user.shape` is no longer available; use
#    `embed_user.weight.shape` instead (see `hook_on_batch_backward`).
#    It may be caused by a torch.nn.Embedding update.
#    (2024-9-8, czzhangheng)


def wrap_MFTrainer(base_trainer: Type[MFTrainer]) -> Type[MFTrainer]:
    """Build `SGDMFTrainer` in a plug-in manner, by registering new
    functions into the given `MFTrainer`.
    """
    # ---------------- attribute-level plug-in -----------------------
    init_sgdmf_ctx(base_trainer)

    # ---------------- action-level plug-in -----------------------
    base_trainer.replace_hook_in_train(
        new_hook=hook_on_batch_backward,
        target_trigger="on_batch_backward",
        target_hook_name="_hook_on_batch_backward")

    return base_trainer


def init_sgdmf_ctx(base_trainer):
    """Init the necessary attributes used in SGDMF; new attributes are
    attached to the context (e.g. with the `sgdmf` prefix) to avoid
    namespace pollution.
    """
    ctx = base_trainer.ctx
    cfg = base_trainer.cfg

    sample_ratio = float(cfg.dataloader.batch_size) / cfg.model.num_user
    # Noise multiplier derived from the privacy budget (epsilon, delta)
    tmp = cfg.sgdmf.constant * np.power(sample_ratio, 2) * (
        cfg.federate.total_round_num * ctx.num_total_train_batch) * np.log(
            1. / cfg.sgdmf.delta)
    noise_multiplier = np.sqrt(tmp / np.power(cfg.sgdmf.epsilon, 2))
    ctx.scale = max(cfg.dataloader.theta, 1.) * noise_multiplier * np.power(
        cfg.sgdmf.R, 1.5)
    logger.info("Inject noise: (loc=0, scale={})".format(ctx.scale))

    ctx.sgdmf_R = cfg.sgdmf.R


def embedding_clip(param, R: int):
    """Clip the embedding matrix according to $R$.

    Arguments:
        param (tensor): The embedding matrix
        R (int): The upper bound of ratings
    """
    # Turn all negative entries of param into 0
    param.data = (torch.abs(param.data) + param.data) * 0.5
    # Clip rows whose norm exceeds the threshold sqrt(R)
    norms = torch.linalg.norm(param.data, dim=1)
    threshold = np.sqrt(R)
    param.data[norms > threshold] *= (threshold /
                                      norms[norms > threshold]).reshape(
                                          (-1, 1))
    param.data[param.data < 0] = 0.


def hook_on_batch_backward(ctx):
    """Private (differentially private) local updates in SGDMF."""
    ctx.optimizer.zero_grad()
    ctx.loss_task.backward()

    # Densify sparse embedding gradients before injecting noise
    if ctx.model.embed_user.weight.grad.is_sparse:
        dense_user_grad = ctx.model.embed_user.weight.grad.to_dense()
    else:
        dense_user_grad = ctx.model.embed_user.weight.grad
    if ctx.model.embed_item.weight.grad.is_sparse:
        dense_item_grad = ctx.model.embed_item.weight.grad.to_dense()
    else:
        dense_item_grad = ctx.model.embed_item.weight.grad

    # Inject Gaussian noise with the pre-computed scale
    dense_user_grad.data += get_random(
        "Normal",
        sample_shape=ctx.model.embed_user.weight.shape,
        params={
            "loc": 0,
            "scale": ctx.scale
        },
        device=ctx.model.embed_user.weight.device)
    dense_item_grad.data += get_random(
        "Normal",
        sample_shape=ctx.model.embed_item.weight.shape,
        params={
            "loc": 0,
            "scale": ctx.scale
        },
        device=ctx.model.embed_item.weight.device)
    ctx.model.embed_user.weight.grad = dense_user_grad.to_sparse()
    ctx.model.embed_item.weight.grad = dense_item_grad.to_sparse()

    ctx.optimizer.step()
    # Embedding clipping after the optimizer step
    with torch.no_grad():
        embedding_clip(ctx.model.embed_user.weight, ctx.sgdmf_R)
        embedding_clip(ctx.model.embed_item.weight, ctx.sgdmf_R)
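

# ----------------------------------------------------------------------
# Minimal standalone sketch (not part of the original trainer): it only
# illustrates the post-step clipping behaviour of `embedding_clip` on a
# toy embedding table. The table size and the rating bound below are
# illustrative assumptions, and `torch.randn` stands in for real,
# unconstrained embedding values.
# ----------------------------------------------------------------------
if __name__ == "__main__":
    toy_R = 5  # hypothetical upper bound of ratings
    toy_embed = torch.nn.Embedding(10, 4)  # hypothetical 10 rows, dim 4
    with torch.no_grad():
        # Start from an unconstrained random table with large entries
        toy_embed.weight.copy_(torch.randn(10, 4) * 10.)
        embedding_clip(toy_embed.weight, toy_R)
        # After clipping: all entries are non-negative and every row norm
        # is at most sqrt(R) (up to floating-point error)
        row_norms = torch.linalg.norm(toy_embed.weight, dim=1)
        assert (toy_embed.weight >= 0).all()
        assert (row_norms <= np.sqrt(toy_R) + 1e-6).all()
        print("row norms after clipping:", row_norms.tolist())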