FS-TFP/federatedscope/gfl/fedsageplus/trainer.py

import torch
import copy
import numpy as np
import torch.nn.functional as F

from federatedscope.gfl.loss import GreedyLoss
from federatedscope.gfl.trainer.nodetrainer import NodeFullBatchTrainer
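

# Trainers for FedSage+ (Zhang et al., "Subgraph Federated Learning with
# Missing Neighbor Generation", NeurIPS 2021): LocalGenTrainer trains the
# local missing-neighbor generator (NeighGen); FedGenTrainer extends it
# with the cross-client gradient exchange that turns NeighGen into FedGen.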
class LocalGenTrainer(NodeFullBatchTrainer):
    def __init__(self,
                 model,
                 data,
                 device,
                 config,
                 only_for_eval=False,
                 monitor=None):
        super(LocalGenTrainer, self).__init__(model, data, device, config,
                                              only_for_eval, monitor)
        self.criterion_num = F.smooth_l1_loss
        self.criterion_feat = GreedyLoss
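
    # criterion_num regresses how many neighbors each node is missing;
    # criterion_feat is FedSage+'s greedy matching loss, which pairs each
    # generated neighbor feature with the closest ground-truth neighbor
    # feature before averaging the error.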

    def _hook_on_batch_forward(self, ctx):
        batch = ctx.data_batch.to(ctx.device)
        mask = batch['{}_mask'.format(ctx.cur_mode)]
        pred_missing, pred_feat, nc_pred = ctx.model(batch)
        pred_missing, pred_feat, nc_pred = (pred_missing[mask],
                                            pred_feat[mask], nc_pred[mask])
        loss_num = self.criterion_num(pred_missing, batch.num_missing[mask])
        loss_feat = self.criterion_feat(
            pred_feats=pred_feat,
            true_feats=batch.x_missing[mask],
            pred_missing=pred_missing,
            true_missing=batch.num_missing[mask],
            num_pred=self.cfg.fedsageplus.num_pred).requires_grad_()
        loss_clf = ctx.criterion(nc_pred, batch.y[mask])
        ctx.batch_size = torch.sum(mask).item()
        ctx.loss_batch = (self.cfg.fedsageplus.a * loss_num +
                          self.cfg.fedsageplus.b * loss_feat +
                          self.cfg.fedsageplus.c * loss_clf).float()
        ctx.y_true = batch.num_missing[mask]
        ctx.y_prob = pred_missing
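

# The a/b/c weights and num_pred above are read from the fedsageplus
# section of the FederatedScope config. A minimal YAML sketch (values are
# illustrative, not required defaults):
#
#     fedsageplus:
#       num_pred: 5    # neighbor features generated per node
#       a: 1.0         # weight of the missing-count regression loss
#       b: 1.0         # weight of the greedy feature-matching loss
#       c: 1.0         # weight of the node-classification loss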
class FedGenTrainer(LocalGenTrainer):
    def _hook_on_batch_forward(self, ctx):
        batch = ctx.data_batch.to(ctx.device)
        mask = batch['{}_mask'.format(ctx.cur_mode)]
        pred_missing, pred_feat, nc_pred = ctx.model(batch)
        pred_missing, pred_feat, nc_pred = (pred_missing[mask],
                                            pred_feat[mask], nc_pred[mask])
        loss_num = self.criterion_num(pred_missing, batch.num_missing[mask])
        loss_feat = self.criterion_feat(
            pred_feats=pred_feat,
            true_feats=batch.x_missing[mask],
            pred_missing=pred_missing,
            true_missing=batch.num_missing[mask],
            num_pred=self.cfg.fedsageplus.num_pred)
        loss_clf = ctx.criterion(nc_pred, batch.y[mask])
        ctx.batch_size = torch.sum(mask).item()
        ctx.loss_batch = (self.cfg.fedsageplus.a * loss_num +
                          self.cfg.fedsageplus.b * loss_feat +
                          self.cfg.fedsageplus.c *
                          loss_clf).float() / self.cfg.federate.client_num
        ctx.y_true = batch.num_missing[mask]
        ctx.y_prob = pred_missing
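
    # The loss above is scaled by 1 / client_num because every other client
    # contributes an equally scaled feature-loss gradient for this generator
    # through cal_grad/update_by_grad below, so the per-client contributions
    # are effectively averaged.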

    def update_by_grad(self, grads):
        """
        Arguments:
            grads: gradients computed by other clients, used to optimize
                the local generator
        :returns:
            the state_dict of the generation model
        """
        for key in grads.keys():
            if isinstance(grads[key], list):
                grads[key] = torch.FloatTensor(grads[key]).to(self.ctx.device)
        for key, value in self.ctx.model.named_parameters():
            # Accumulate the remote gradients on top of the local ones
            # (assumes a local backward pass has already populated .grad)
            value.grad += grads[key]
        self.ctx.optimizer.step()
        return self.ctx.model.cpu().state_dict()
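
    # A minimal sketch of the exchange built from update_by_grad above and
    # cal_grad below (hypothetical driver code; in FederatedScope the
    # actual message passing lives in the fedsageplus worker, not here):
    #
    #     # Client A computes gradients of its feature loss w.r.t. client
    #     # B's generator, using B's parameters and hidden embeddings:
    #     grads = trainer_A.cal_grad(raw_data_A, model_para_B,
    #                                embedding_B, true_missing_B)
    #     # Client B then folds A's gradients into its own update:
    #     new_para_B = trainer_B.update_by_grad(grads)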

    def cal_grad(self, raw_data, model_para, embedding, true_missing):
        """
        Arguments:
            raw_data (PyG.Data): the raw local graph
            model_para: model parameters of another client's generator
            embedding: output embeddings of that client's local encoder
            true_missing: the number of missing nodes
        :returns:
            grads: gradients to optimize the other client's model
        """
        # Back up the local parameters so they can be restored afterwards
        para_backup = copy.deepcopy(self.ctx.model.cpu().state_dict())
        for key in model_para.keys():
            if isinstance(model_para[key], list):
                model_para[key] = torch.FloatTensor(model_para[key])
        self.ctx.model.load_state_dict(model_para)
        self.ctx.model = self.ctx.model.to(self.ctx.device)
        self.ctx.model.train()
        raw_data = raw_data.to(self.ctx.device)
        embedding = torch.FloatTensor(embedding).to(self.ctx.device)
        true_missing = true_missing.long().to(self.ctx.device)
        pred_missing = self.ctx.model.reg_model(embedding)
        pred_feat = self.ctx.model.gen(embedding)
        # Randomly pick local nodes and use their neighbors' features as
        # surrogate ground truth for the predicted neighbors
        choice = np.random.choice(raw_data.num_nodes, embedding.shape[0])
        global_target_feat = []
        for c_i in choice:
            neighbors_ids = raw_data.edge_index[1][torch.where(
                raw_data.edge_index[0] == c_i)[0]]
            # Resample until we find a node with at least one neighbor
            while len(neighbors_ids) == 0:
                id_i = np.random.choice(raw_data.num_nodes, 1)[0]
                neighbors_ids = raw_data.edge_index[1][torch.where(
                    raw_data.edge_index[0] == id_i)[0]]
            choice_i = np.random.choice(neighbors_ids.detach().cpu().numpy(),
                                        self.cfg.fedsageplus.num_pred)
            for ch_i in choice_i:
                global_target_feat.append(
                    raw_data.x[ch_i].detach().cpu().numpy())
        global_target_feat = np.asarray(global_target_feat).reshape(
            (embedding.shape[0], self.cfg.fedsageplus.num_pred,
             raw_data.num_node_features))
        loss_feat = self.criterion_feat(
            pred_feats=pred_feat,
            true_feats=global_target_feat,
            pred_missing=pred_missing,
            true_missing=true_missing,
            num_pred=self.cfg.fedsageplus.num_pred)
        loss = self.cfg.fedsageplus.b * loss_feat
        loss = (1.0 / self.cfg.federate.client_num * loss).requires_grad_()
        loss.backward()
        grads = {
            key: value.grad
            for key, value in self.ctx.model.named_parameters()
        }
        # Roll back to the saved local parameters
        self.ctx.model.load_state_dict(para_backup)
        return grads
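
    # Since clients never see each other's missing neighbors, cal_grad
    # scores the generated features against randomly sampled local
    # neighborhoods: with N = embedding.shape[0] hidden vectors and
    # F-dimensional node features, global_target_feat has shape
    # (N, num_pred, F), one surrogate target per predicted neighbor slot.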

    @torch.no_grad()
    def embedding(self):
        # Hidden representations of the local graph produced by the local
        # encoder, moved to CPU so they can be shared with other clients
        model = self.ctx.model.to(self.ctx.device)
        data = self.ctx.data['data'].to(self.ctx.device)
        return model.encoder_model(data).to('cpu')
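

# Usage sketch (hypothetical names): embedding() yields what a client
# shares with its peers, who feed it back as the `embedding` argument of
# cal_grad, e.g.
#
#     emb_B = trainer_B.embedding().numpy()
#     grads_for_B = trainer_A.cal_grad(data_A, para_B, emb_B, missing_B)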