178 lines
6.3 KiB
Python
178 lines
6.3 KiB
Python
from __future__ import absolute_import
|
|
from __future__ import print_function
|
|
from __future__ import division
|
|
|
|
import torch
|
|
import numpy as np
|
|
import scipy.sparse as sp
|
|
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from torch_geometric.data import Data
|
|
|
|
from federatedscope.gfl.model import SAGE_Net
|
|
"""
|
|
https://proceedings.neurips.cc//paper/2021/file/ \
|
|
34adeb8e3242824038aa65460a47c29e-Paper.pdf
|
|
Fedsageplus models from the "Subgraph Federated Learning with Missing
|
|
Neighbor Generation" (FedSage+) paper, in NeurIPS'21
|
|
Source: https://github.com/zkhku/fedsage
|
|
"""
|
|
|
|
|
|
class Sampling(nn.Module):
    """Additive standard-normal noise layer.

    Injects element-wise N(0, 1) noise into its input; used by
    ``FeatGenerator`` to randomize the latent embedding before decoding
    it into generated neighbor features.
    """
    def __init__(self):
        super(Sampling, self).__init__()

    def forward(self, inputs):
        """Return ``inputs`` plus element-wise standard-normal noise.

        Args:
            inputs: tensor of arbitrary shape.

        Returns:
            Tensor with the same shape, dtype and device as ``inputs``.
        """
        # randn_like samples directly on inputs' device and dtype, avoiding
        # the CPU allocation + .to(device) round-trip of
        # torch.normal(0, 1, size=...).
        return inputs + torch.randn_like(inputs)
|
|
|
|
|
|
class FeatGenerator(nn.Module):
    """Decode a latent node embedding into generated-neighbor features.

    The latent vector is noised by ``Sampling`` and run through a small
    MLP whose flat output holds ``num_pred`` feature vectors of size
    ``feat_shape`` per node (reshaped by the caller).
    """
    def __init__(self, latent_dim, dropout, num_pred, feat_shape):
        super(FeatGenerator, self).__init__()
        self.num_pred = num_pred
        self.feat_shape = feat_shape
        self.dropout = dropout
        self.sample = Sampling()
        # Two hidden layers, then a single flat projection sized for all
        # predicted-neighbor features at once.
        self.fc1 = nn.Linear(latent_dim, 256)
        self.fc2 = nn.Linear(256, 2048)
        self.fc_flat = nn.Linear(2048, self.num_pred * self.feat_shape)

    def forward(self, x):
        """Return tanh-squashed features of shape (N, num_pred * feat_shape)."""
        hidden = self.sample(x)
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        hidden = F.dropout(hidden, self.dropout, training=self.training)
        # tanh keeps generated features in [-1, 1].
        return torch.tanh(self.fc_flat(hidden))
|
|
|
|
|
|
class NumPredictor(nn.Module):
    """Regress the number of missing neighbors from a latent node embedding."""

    def __init__(self, latent_dim):
        # Call Module.__init__ before assigning any attributes — consistent
        # with the sibling classes, and setting attributes on an
        # uninitialized nn.Module is fragile.
        super(NumPredictor, self).__init__()
        self.latent_dim = latent_dim
        self.reg_1 = nn.Linear(self.latent_dim, 1)

    def forward(self, x):
        """Return a non-negative degree prediction of shape (N, 1).

        Args:
            x: (N, latent_dim) latent embeddings.
        """
        # ReLU clamps the regression output to >= 0 (a count can't be negative).
        return F.relu(self.reg_1(x))
|
|
|
|
|
|
# Mend the graph via NeighGen
class MendGraph(nn.Module):
    """Graft generated neighbors onto a local subgraph.

    Given per-node predicted missing-neighbor counts and generated neighbor
    features, append the generated nodes after the original feature matrix
    and connect each source node to its generated neighbors. Holds no
    trainable parameters.
    """
    def __init__(self, num_pred):
        super(MendGraph, self).__init__()
        # Maximum number of generated neighbors grafted per node.
        self.num_pred = num_pred
        for param in self.parameters():
            param.requires_grad = False

    def mend_graph(self, x, edge_index, pred_degree, gen_feats):
        """Build the mended feature matrix and edge list.

        Args:
            x: (num_node, num_feature) original node features.
            edge_index: (2, E) original edges.
            pred_degree: (num_node,) predicted missing-neighbor counts
                (float; rounded, then clamped to [0, num_pred]).
            gen_feats: generated features, viewable as
                (num_node, num_pred, num_feature).

        Returns:
            fill_feats: (num_node * (1 + num_pred), num_feature) tensor with
                generated features appended after the original nodes.
            fill_edges: (2, E + num_new) int64 tensor on gen_feats' device.
        """
        device = gen_feats.device
        num_node, num_feature = x.shape
        gen_feats = gen_feats.view(-1, self.num_pred, num_feature)

        if pred_degree.device.type != 'cpu':
            pred_degree = pred_degree.cpu()
        # Round to integer degrees using the public .int() cast instead of
        # the private/deprecated torch._cast_Int.
        pred_degree = torch.round(pred_degree).int().detach()
        x = x.detach()
        fill_feats = torch.vstack((x, gen_feats.view(-1, num_feature)))

        new_edges = []
        for i in range(num_node):
            # Generated neighbor j of source node i lives at row
            # num_node + i * num_pred + j of fill_feats.
            for j in range(min(self.num_pred, max(0, int(pred_degree[i])))):
                new_edges.append(
                    np.asarray([i, num_node + i * self.num_pred + j]))

        new_edges = torch.tensor(np.asarray(new_edges).reshape((-1, 2)),
                                 dtype=torch.int64).T
        new_edges = new_edges.to(device)
        # Check the edge count (dim 1), not len(): after .T the tensor is
        # (2, num_new), so len(new_edges) is always 2 even with no new edges.
        if new_edges.shape[1] > 0:
            fill_edges = torch.hstack((edge_index, new_edges))
        else:
            fill_edges = torch.clone(edge_index)
        return fill_feats, fill_edges

    def forward(self, x, edge_index, pred_missing, gen_feats):
        fill_feats, fill_edges = self.mend_graph(x, edge_index, pred_missing,
                                                 gen_feats)
        return fill_feats, fill_edges
|
|
|
|
|
|
class LocalSage_Plus(nn.Module):
    """Local FedSage+ model for one client's subgraph.

    Pipeline: GraphSAGE encoder -> missing-degree regressor + neighbor
    feature generator -> graph mender -> GraphSAGE node classifier.
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 hidden,
                 gen_hidden,
                 dropout=0.5,
                 num_pred=5):
        super(LocalSage_Plus, self).__init__()

        # Hyper-parameters shared by both GraphSAGE towers.
        sage_kwargs = dict(hidden=hidden, max_depth=2, dropout=dropout)
        # Encoder maps raw features to the gen_hidden latent space.
        self.encoder_model = SAGE_Net(in_channels=in_channels,
                                      out_channels=gen_hidden,
                                      **sage_kwargs)
        self.reg_model = NumPredictor(latent_dim=gen_hidden)
        self.gen = FeatGenerator(latent_dim=gen_hidden,
                                 dropout=dropout,
                                 num_pred=num_pred,
                                 feat_shape=in_channels)
        self.mend_graph = MendGraph(num_pred)
        # Classifier runs on the mended graph and predicts node labels.
        self.classifier = SAGE_Net(in_channels=in_channels,
                                   out_channels=out_channels,
                                   **sage_kwargs)

    def forward(self, data):
        """Return (predicted degree, generated features, class logits
        restricted to the original nodes)."""
        latent = self.encoder_model(data)
        degree = self.reg_model(latent)
        gen_feat = self.gen(latent)
        mend_feats, mend_edge_index = self.mend_graph(data.x, data.edge_index,
                                                      degree, gen_feat)
        mended = Data(x=mend_feats, edge_index=mend_edge_index)
        nc_pred = self.classifier(mended)
        # Drop logits belonging to the appended generated nodes.
        return degree, gen_feat, nc_pred[:data.num_nodes]

    def inference(self, impared_data, raw_data):
        """Like forward, but encode the impaired graph while mending the
        raw graph's features and edges."""
        latent = self.encoder_model(impared_data)
        degree = self.reg_model(latent)
        gen_feat = self.gen(latent)
        mend_feats, mend_edge_index = self.mend_graph(raw_data.x,
                                                      raw_data.edge_index,
                                                      degree, gen_feat)
        mended = Data(x=mend_feats, edge_index=mend_edge_index)
        nc_pred = self.classifier(mended)
        return degree, gen_feat, nc_pred[:raw_data.num_nodes]
|
|
|
|
|
|
class FedSage_Plus(nn.Module):
    """Federated wrapper around a trained LocalSage_Plus.

    Shares the local model's sub-modules and freezes everything except the
    neighbor feature generator ``self.gen``.
    """
    def __init__(self, local_graph: LocalSage_Plus):
        super(FedSage_Plus, self).__init__()
        # Re-use the local model's sub-modules directly (shared weights).
        self.encoder_model = local_graph.encoder_model
        self.reg_model = local_graph.reg_model
        self.gen = local_graph.gen
        self.mend_graph = local_graph.mend_graph
        self.classifier = local_graph.classifier
        # Freeze every component except the generator.
        for frozen in (self.encoder_model, self.reg_model, self.mend_graph,
                       self.classifier):
            frozen.requires_grad_(False)

    def forward(self, data):
        """Return (predicted degree, generated features, class logits
        restricted to the original nodes)."""
        latent = self.encoder_model(data)
        degree = self.reg_model(latent)
        gen_feat = self.gen(latent)
        mend_feats, mend_edge_index = self.mend_graph(data.x, data.edge_index,
                                                      degree, gen_feat)
        nc_pred = self.classifier(
            Data(x=mend_feats, edge_index=mend_edge_index))
        return degree, gen_feat, nc_pred[:data.num_nodes]
|