modifications on original FS
modifications on original FS
This commit is contained in:
parent
c9acf63692
commit
6ea133716f
|
|
@ -9,6 +9,10 @@ import federatedscope.register as register
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Modifications:
|
||||||
|
# 1. Add FedDGCN support. Line 203
|
||||||
|
# (2024-10-8, czzhangheng)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from federatedscope.contrib.data import *
|
from federatedscope.contrib.data import *
|
||||||
except ImportError as error:
|
except ImportError as error:
|
||||||
|
|
@ -23,13 +27,14 @@ TRANS_DATA_MAP = {
|
||||||
'.*?@.*?', 'hiv', 'proteins', 'imdb-binary', 'bbbp', 'tox21', 'bace',
|
'.*?@.*?', 'hiv', 'proteins', 'imdb-binary', 'bbbp', 'tox21', 'bace',
|
||||||
'sider', 'clintox', 'esol', 'freesolv', 'lipo', 'cifar4cl', 'cifar4lp'
|
'sider', 'clintox', 'esol', 'freesolv', 'lipo', 'cifar4cl', 'cifar4lp'
|
||||||
],
|
],
|
||||||
|
# Add trafficflow to 'DummyDataTranslator'
|
||||||
'DummyDataTranslator': [
|
'DummyDataTranslator': [
|
||||||
'toy', 'quadratic', 'femnist', 'celeba', 'shakespeare', 'twitter',
|
'toy', 'quadratic', 'femnist', 'celeba', 'shakespeare', 'twitter',
|
||||||
'subreddit', 'synthetic', 'ciao', 'epinions', '.*?vertical_fl_data.*?',
|
'subreddit', 'synthetic', 'ciao', 'epinions', '.*?vertical_fl_data.*?',
|
||||||
'.*?movielens.*?', '.*?netflix.*?', '.*?cikmcup.*?',
|
'.*?movielens.*?', '.*?netflix.*?', '.*?cikmcup.*?',
|
||||||
'graph_multi_domain.*?', 'cora', 'citeseer', 'pubmed', 'dblp_conf',
|
'graph_multi_domain.*?', 'cora', 'citeseer', 'pubmed', 'dblp_conf',
|
||||||
'dblp_org', 'csbm.*?', 'fb15k-237', 'wn18', 'adult', 'abalone',
|
'dblp_org', 'csbm.*?', 'fb15k-237', 'wn18', 'adult', 'abalone',
|
||||||
'credit', 'blog'
|
'credit', 'blog', 'trafficflow'
|
||||||
], # Dummy for FL dataset
|
], # Dummy for FL dataset
|
||||||
'RawDataTranslator': ['hetero_nlp_tasks'],
|
'RawDataTranslator': ['hetero_nlp_tasks'],
|
||||||
}
|
}
|
||||||
|
|
@ -111,6 +116,7 @@ def get_data(config, client_cfgs=None):
|
||||||
HFLMovieLens10M Recommendation
|
HFLMovieLens10M Recommendation
|
||||||
VFLNetflix Recommendation
|
VFLNetflix Recommendation
|
||||||
HFLNetflix Recommendation
|
HFLNetflix Recommendation
|
||||||
|
trafficflow Traffic Flow Prediction
|
||||||
================================== ===========================
|
================================== ===========================
|
||||||
"""
|
"""
|
||||||
# Fix the seed for data generation
|
# Fix the seed for data generation
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,10 @@ import federatedscope.register as register
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Modifications:
|
||||||
|
# 1. Do a my_gcn demo Line 75
|
||||||
|
# (2024-9-1, czzhangheng)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from federatedscope.contrib.model import *
|
from federatedscope.contrib.model import *
|
||||||
except ImportError as error:
|
except ImportError as error:
|
||||||
|
|
@ -72,7 +76,10 @@ def get_shape_from_data(data, model_config, backend='torch'):
|
||||||
elif backend == 'torch':
|
elif backend == 'torch':
|
||||||
import torch
|
import torch
|
||||||
if issubclass(type(data_representative), torch.utils.data.DataLoader):
|
if issubclass(type(data_representative), torch.utils.data.DataLoader):
|
||||||
x, _ = next(iter(data_representative))
|
if model_config.type == 'my_gcn':
|
||||||
|
x = next(iter(data_representative))
|
||||||
|
return x.x.shape
|
||||||
|
x = next(iter(data_representative))
|
||||||
if isinstance(x, list):
|
if isinstance(x, list):
|
||||||
return x[0].shape
|
return x[0].shape
|
||||||
return x.shape
|
return x.shape
|
||||||
|
|
@ -197,6 +204,9 @@ def get_model(model_config, local_data=None, backend='torch'):
|
||||||
elif model_config.type.lower() in ['atc_model']:
|
elif model_config.type.lower() in ['atc_model']:
|
||||||
from federatedscope.nlp.hetero_tasks.model import ATCModel
|
from federatedscope.nlp.hetero_tasks.model import ATCModel
|
||||||
model = ATCModel(model_config)
|
model = ATCModel(model_config)
|
||||||
|
elif model_config.type.lower() in ['feddgcn']:
|
||||||
|
from federatedscope.trafficflow.model.FedDGCN import FedDGCN
|
||||||
|
model = FedDGCN(model_config)
|
||||||
else:
|
else:
|
||||||
raise ValueError('Model {} is not provided'.format(model_config.type))
|
raise ValueError('Model {} is not provided'.format(model_config.type))
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -75,6 +75,10 @@ def get_splitter(config):
|
||||||
elif config.data.splitter == 'iid':
|
elif config.data.splitter == 'iid':
|
||||||
from federatedscope.core.splitters.generic import IIDSplitter
|
from federatedscope.core.splitters.generic import IIDSplitter
|
||||||
splitter = IIDSplitter(client_num)
|
splitter = IIDSplitter(client_num)
|
||||||
|
# Add traffic flow splitter
|
||||||
|
elif config.data.splitter == 'trafficflowprediction':
|
||||||
|
from federatedscope.trafficflow.splitters.trafficSplitter import TrafficSplitter
|
||||||
|
splitter = TrafficSplitter(client_num)
|
||||||
else:
|
else:
|
||||||
logger.warning(f'Splitter {config.data.splitter} not found or not '
|
logger.warning(f'Splitter {config.data.splitter} not found or not '
|
||||||
f'used.')
|
f'used.')
|
||||||
|
|
|
||||||
|
|
@ -176,6 +176,14 @@ def get_trainer(model=None,
|
||||||
data=data,
|
data=data,
|
||||||
device=device,
|
device=device,
|
||||||
monitor=monitor)
|
monitor=monitor)
|
||||||
|
# Add traffic flow trainer
|
||||||
|
elif config.trainer.type.lower() in ['trafficflowtrainer']:
|
||||||
|
from federatedscope.trafficflow.trainer.trafficflow_trainer import call_trafficflow_trainer
|
||||||
|
trainer = call_trafficflow_trainer(config=config,
|
||||||
|
model=model,
|
||||||
|
data=data,
|
||||||
|
device=device,
|
||||||
|
monitor=monitor)
|
||||||
else:
|
else:
|
||||||
# try to find user registered trainer
|
# try to find user registered trainer
|
||||||
trainer = None
|
trainer = None
|
||||||
|
|
|
||||||
|
|
@ -82,6 +82,25 @@ def extend_data_cfg(cfg):
|
||||||
cfg.data.num_contrast = 0
|
cfg.data.num_contrast = 0
|
||||||
cfg.data.is_debug = False
|
cfg.data.is_debug = False
|
||||||
|
|
||||||
|
# Traffic Flow data parameters, These are only default values.
|
||||||
|
# Please modify the specific parameters directly in the YAML files.
|
||||||
|
cfg.data.root = 'data/trafficflow/PeMS04'
|
||||||
|
cfg.data.type = 'trafficflow'
|
||||||
|
cfg.data.num_nodes = 307
|
||||||
|
cfg.data.lag = 12
|
||||||
|
cfg.data.horizon = 12
|
||||||
|
cfg.data.val_ratio = 0.2
|
||||||
|
cfg.data.test_ratio = 0.2
|
||||||
|
cfg.data.tod = False
|
||||||
|
cfg.data.normalizer = 'std'
|
||||||
|
cfg.data.column_wise = False
|
||||||
|
cfg.data.default_graph = True
|
||||||
|
cfg.data.add_time_in_day = True
|
||||||
|
cfg.data.add_day_in_week = True
|
||||||
|
cfg.data.steps_per_day = 288
|
||||||
|
cfg.data.days_per_week = 7
|
||||||
|
cfg.data.scaler = [0,0]
|
||||||
|
|
||||||
# feature engineering
|
# feature engineering
|
||||||
cfg.feat_engr = CN()
|
cfg.feat_engr = CN()
|
||||||
cfg.feat_engr.type = ''
|
cfg.feat_engr.type = ''
|
||||||
|
|
|
||||||
|
|
@ -49,6 +49,21 @@ def extend_model_cfg(cfg):
|
||||||
cfg.model.contrast_topk = 100
|
cfg.model.contrast_topk = 100
|
||||||
cfg.model.contrast_temp = 1.0
|
cfg.model.contrast_temp = 1.0
|
||||||
|
|
||||||
|
# Traffic Flow model parameters, These are only default values.
|
||||||
|
# Please modify the specific parameters directly in the baselines/YAML files.
|
||||||
|
cfg.model.num_nodes = 0
|
||||||
|
cfg.model.rnn_units = 64
|
||||||
|
cfg.model.dropout = 0.1
|
||||||
|
cfg.model.horizon = 12
|
||||||
|
cfg.model.input_dim = 1 # If 0, model will be built by data.shape
|
||||||
|
cfg.model.output_dim = 1
|
||||||
|
cfg.model.embed_dim = 10
|
||||||
|
cfg.model.num_layers = 1 # In GPR-GNN, K = layer
|
||||||
|
cfg.model.cheb_order = 1 # A tuple, e.g., (in_channel, h, w)
|
||||||
|
cfg.model.use_day = True
|
||||||
|
cfg.model.use_week = True
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------- #
|
# ---------------------------------------------------------------------- #
|
||||||
# Criterion related options
|
# Criterion related options
|
||||||
# ---------------------------------------------------------------------- #
|
# ---------------------------------------------------------------------- #
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,57 @@
|
||||||
|
from federatedscope.core.configs.config import CN
|
||||||
|
from federatedscope.register import register_config
|
||||||
|
|
||||||
|
"""
|
||||||
|
The parameter settings for traffic flow prediction are located in the YAML files under
|
||||||
|
the baseline folder within the trafficflow package. These are only default values.
|
||||||
|
Please modify the specific parameters directly in the YAML files.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def extend_trafficflow_cfg(cfg):
|
||||||
|
# ---------------------------------------------------------------------- #
|
||||||
|
# Model related options
|
||||||
|
# ---------------------------------------------------------------------- #
|
||||||
|
cfg.model = CN()
|
||||||
|
|
||||||
|
cfg.model.model_num_per_trainer = 1 # some methods may leverage more
|
||||||
|
# than one model in each trainer
|
||||||
|
cfg.model.type = 'trafficflow'
|
||||||
|
cfg.model.use_bias = True
|
||||||
|
cfg.model.task = 'trafficflowprediction'
|
||||||
|
cfg.model.num_nodes = 0
|
||||||
|
cfg.model.rnn_units = 64
|
||||||
|
cfg.model.dropout = 0.1
|
||||||
|
cfg.model.horizon = 12
|
||||||
|
cfg.model.input_dim = 1 # If 0, model will be built by data.shape
|
||||||
|
cfg.model.output_dim = 1
|
||||||
|
cfg.model.embed_dim = 10
|
||||||
|
cfg.model.num_layers = 1 # In GPR-GNN, K = layer
|
||||||
|
cfg.model.cheb_order = 1 # A tuple, e.g., (in_channel, h, w)
|
||||||
|
cfg.model.use_day = True
|
||||||
|
cfg.model.use_week = True
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------- #
|
||||||
|
# Criterion related options
|
||||||
|
# ---------------------------------------------------------------------- #
|
||||||
|
cfg.criterion = CN()
|
||||||
|
|
||||||
|
cfg.criterion.type = 'L1Loss'
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------- #
|
||||||
|
# regularizer related options
|
||||||
|
# ---------------------------------------------------------------------- #
|
||||||
|
cfg.regularizer = CN()
|
||||||
|
|
||||||
|
cfg.regularizer.type = ''
|
||||||
|
cfg.regularizer.mu = 0.
|
||||||
|
|
||||||
|
# --------------- register corresponding check function ----------
|
||||||
|
cfg.register_cfg_check_fun(assert_model_cfg)
|
||||||
|
|
||||||
|
|
||||||
|
def assert_model_cfg(cfg):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
register_config("trafficflow", extend_trafficflow_cfg)
|
||||||
|
|
@ -24,6 +24,7 @@ def extend_training_cfg(cfg):
|
||||||
# atc (TODO: merge later)
|
# atc (TODO: merge later)
|
||||||
cfg.trainer.disp_freq = 50
|
cfg.trainer.disp_freq = 50
|
||||||
cfg.trainer.val_freq = 100000000 # eval freq across batches
|
cfg.trainer.val_freq = 100000000 # eval freq across batches
|
||||||
|
cfg.trainer.log_dir = ''
|
||||||
|
|
||||||
# ---------------------------------------------------------------------- #
|
# ---------------------------------------------------------------------- #
|
||||||
# Training related options
|
# Training related options
|
||||||
|
|
@ -38,6 +39,23 @@ def extend_training_cfg(cfg):
|
||||||
cfg.train.optimizer.type = 'SGD'
|
cfg.train.optimizer.type = 'SGD'
|
||||||
cfg.train.optimizer.lr = 0.1
|
cfg.train.optimizer.lr = 0.1
|
||||||
|
|
||||||
|
# trafficflow
|
||||||
|
cfg.train.loss_func = 'mae'
|
||||||
|
cfg.train.seed = 10
|
||||||
|
cfg.train.batch_size = 64
|
||||||
|
cfg.train.epochs = 300
|
||||||
|
cfg.train.lr_init = 0.003
|
||||||
|
cfg.train.weight_decay = 0
|
||||||
|
cfg.train.lr_decay = False
|
||||||
|
cfg.train.lr_decay_rate = 0.3
|
||||||
|
cfg.train.lr_decay_step = [5, 20, 40, 70]
|
||||||
|
cfg.train.early_stop = True
|
||||||
|
cfg.train.early_stop_patience = 15
|
||||||
|
cfg.train.grad_norm = False
|
||||||
|
cfg.train.max_grad_norm = 5
|
||||||
|
cfg.train.real_value = True
|
||||||
|
|
||||||
|
|
||||||
# you can add new arguments 'aa' by `cfg.train.scheduler.aa = 'bb'`
|
# you can add new arguments 'aa' by `cfg.train.scheduler.aa = 'bb'`
|
||||||
cfg.train.scheduler = CN(new_allowed=True)
|
cfg.train.scheduler = CN(new_allowed=True)
|
||||||
cfg.train.scheduler.type = ''
|
cfg.train.scheduler.type = ''
|
||||||
|
|
@ -91,6 +109,9 @@ def extend_training_cfg(cfg):
|
||||||
# Early stop when no improve to last `patience` round, in ['mean', 'best']
|
# Early stop when no improve to last `patience` round, in ['mean', 'best']
|
||||||
cfg.early_stop.improve_indicator_mode = 'best'
|
cfg.early_stop.improve_indicator_mode = 'best'
|
||||||
|
|
||||||
|
# TODO:trafficflow
|
||||||
|
|
||||||
|
|
||||||
# --------------- register corresponding check function ----------
|
# --------------- register corresponding check function ----------
|
||||||
cfg.register_cfg_check_fun(assert_training_cfg)
|
cfg.register_cfg_check_fun(assert_training_cfg)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,19 +1,20 @@
|
||||||
import copy
|
|
||||||
import inspect
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import os.path as osp
|
||||||
import re
|
import re
|
||||||
import ssl
|
import ssl
|
||||||
import urllib.request
|
import urllib.request
|
||||||
|
from collections import defaultdict
|
||||||
|
from random import shuffle
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import os.path as osp
|
|
||||||
|
|
||||||
from random import shuffle
|
|
||||||
from collections import defaultdict
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
"""
|
||||||
|
Modification Record: This file has been reformatted.
|
||||||
|
"""
|
||||||
|
|
||||||
class RegexInverseMap:
|
class RegexInverseMap:
|
||||||
def __init__(self, n_dic, val):
|
def __init__(self, n_dic, val):
|
||||||
|
|
@ -57,16 +58,16 @@ def load_dataset(config, client_cfgs=None):
|
||||||
from federatedscope.cl.dataloader import load_cifar_dataset
|
from federatedscope.cl.dataloader import load_cifar_dataset
|
||||||
dataset, modified_config = load_cifar_dataset(config)
|
dataset, modified_config = load_cifar_dataset(config)
|
||||||
elif config.data.type.lower() in [
|
elif config.data.type.lower() in [
|
||||||
'shakespeare', 'twitter', 'subreddit', 'synthetic'
|
'shakespeare', 'twitter', 'subreddit', 'synthetic'
|
||||||
]:
|
]:
|
||||||
from federatedscope.nlp.dataloader import load_nlp_dataset
|
from federatedscope.nlp.dataloader import load_nlp_dataset
|
||||||
dataset, modified_config = load_nlp_dataset(config)
|
dataset, modified_config = load_nlp_dataset(config)
|
||||||
elif config.data.type.lower() in [
|
elif config.data.type.lower() in [
|
||||||
'cora',
|
'cora',
|
||||||
'citeseer',
|
'citeseer',
|
||||||
'pubmed',
|
'pubmed',
|
||||||
'dblp_conf',
|
'dblp_conf',
|
||||||
'dblp_org',
|
'dblp_org',
|
||||||
] or config.data.type.lower().startswith('csbm'):
|
] or config.data.type.lower().startswith('csbm'):
|
||||||
from federatedscope.gfl.dataloader import load_nodelevel_dataset
|
from federatedscope.gfl.dataloader import load_nodelevel_dataset
|
||||||
dataset, modified_config = load_nodelevel_dataset(config)
|
dataset, modified_config = load_nodelevel_dataset(config)
|
||||||
|
|
@ -74,13 +75,13 @@ def load_dataset(config, client_cfgs=None):
|
||||||
from federatedscope.gfl.dataloader import load_linklevel_dataset
|
from federatedscope.gfl.dataloader import load_linklevel_dataset
|
||||||
dataset, modified_config = load_linklevel_dataset(config)
|
dataset, modified_config = load_linklevel_dataset(config)
|
||||||
elif config.data.type.lower() in [
|
elif config.data.type.lower() in [
|
||||||
'hiv', 'proteins', 'imdb-binary', 'bbbp', 'tox21', 'bace', 'sider',
|
'hiv', 'proteins', 'imdb-binary', 'bbbp', 'tox21', 'bace', 'sider',
|
||||||
'clintox', 'esol', 'freesolv', 'lipo', 'cikmcup'
|
'clintox', 'esol', 'freesolv', 'lipo', 'cikmcup'
|
||||||
] or config.data.type.startswith('graph_multi_domain'):
|
] or config.data.type.startswith('graph_multi_domain'):
|
||||||
from federatedscope.gfl.dataloader import load_graphlevel_dataset
|
from federatedscope.gfl.dataloader import load_graphlevel_dataset
|
||||||
dataset, modified_config = load_graphlevel_dataset(config)
|
dataset, modified_config = load_graphlevel_dataset(config)
|
||||||
elif config.data.type.lower() in [
|
elif config.data.type.lower() in [
|
||||||
'synthetic_vfl_data', 'adult', 'abalone', 'credit', 'blog'
|
'synthetic_vfl_data', 'adult', 'abalone', 'credit', 'blog'
|
||||||
]:
|
]:
|
||||||
from federatedscope.vertical_fl.dataloader import load_vertical_data
|
from federatedscope.vertical_fl.dataloader import load_vertical_data
|
||||||
generate = config.data.type.lower() == 'synthetic_vfl_data'
|
generate = config.data.type.lower() == 'synthetic_vfl_data'
|
||||||
|
|
@ -97,10 +98,17 @@ def load_dataset(config, client_cfgs=None):
|
||||||
elif '@' in config.data.type.lower():
|
elif '@' in config.data.type.lower():
|
||||||
from federatedscope.core.data.utils import load_external_data
|
from federatedscope.core.data.utils import load_external_data
|
||||||
dataset, modified_config = load_external_data(config)
|
dataset, modified_config = load_external_data(config)
|
||||||
|
elif 'cora' in config.data.type.lower():
|
||||||
|
from federatedscope.contrib.data.my_cora import call_my_data
|
||||||
|
dataset, modified_config = call_my_data(config, client_cfgs)
|
||||||
elif config.data.type is None or config.data.type == "":
|
elif config.data.type is None or config.data.type == "":
|
||||||
# The participant (only for server in this version) does not own data
|
# The participant (only for server in this version) does not own data
|
||||||
dataset = None
|
dataset = None
|
||||||
modified_config = config
|
modified_config = config
|
||||||
|
elif config.data.type.lower() in [
|
||||||
|
'trafficflow']:
|
||||||
|
from federatedscope.trafficflow.dataloader.traffic_dataloader import load_traffic_data
|
||||||
|
dataset, modified_config = load_traffic_data(config, client_cfgs)
|
||||||
else:
|
else:
|
||||||
raise ValueError('Dataset {} not found.'.format(config.data.type))
|
raise ValueError('Dataset {} not found.'.format(config.data.type))
|
||||||
return dataset, modified_config
|
return dataset, modified_config
|
||||||
|
|
@ -305,8 +313,8 @@ def load_external_data(config=None):
|
||||||
config.data.transform.append({})
|
config.data.transform.append({})
|
||||||
vocab = getattr(import_module('torchtext.vocab'),
|
vocab = getattr(import_module('torchtext.vocab'),
|
||||||
config.data.transform[0])(
|
config.data.transform[0])(
|
||||||
dim=config.model.in_channels,
|
dim=config.model.in_channels,
|
||||||
**config.data.transform[1])
|
**config.data.transform[1])
|
||||||
|
|
||||||
if 'classification' in config.model.task.lower():
|
if 'classification' in config.model.task.lower():
|
||||||
data = [
|
data = [
|
||||||
|
|
@ -360,7 +368,7 @@ def load_external_data(config=None):
|
||||||
lengths = [train_size, val_size]
|
lengths = [train_size, val_size]
|
||||||
data_split_dict['train'], data_split_dict[
|
data_split_dict['train'], data_split_dict[
|
||||||
'val'] = torch.utils.data.dataset.random_split(
|
'val'] = torch.utils.data.dataset.random_split(
|
||||||
data_split_dict['train'], lengths)
|
data_split_dict['train'], lengths)
|
||||||
else:
|
else:
|
||||||
# Use config.data.splits
|
# Use config.data.splits
|
||||||
data_split_dict = {}
|
data_split_dict = {}
|
||||||
|
|
@ -370,7 +378,7 @@ def load_external_data(config=None):
|
||||||
lengths = [train_size, val_size, test_size]
|
lengths = [train_size, val_size, test_size]
|
||||||
data_split_dict['train'], data_split_dict['val'], data_split_dict[
|
data_split_dict['train'], data_split_dict['val'], data_split_dict[
|
||||||
'test'] = torch.utils.data.dataset.random_split(
|
'test'] = torch.utils.data.dataset.random_split(
|
||||||
data_list[0], lengths)
|
data_list[0], lengths)
|
||||||
|
|
||||||
return data_split_dict
|
return data_split_dict
|
||||||
|
|
||||||
|
|
@ -458,7 +466,7 @@ def load_external_data(config=None):
|
||||||
original_train_size = len(data_split_dict["train"])
|
original_train_size = len(data_split_dict["train"])
|
||||||
|
|
||||||
if "half_val_dummy_test" in raw_args and raw_args[
|
if "half_val_dummy_test" in raw_args and raw_args[
|
||||||
"half_val_dummy_test"]:
|
"half_val_dummy_test"]:
|
||||||
# since the "test" set from GLUE dataset may be masked, we need to
|
# since the "test" set from GLUE dataset may be masked, we need to
|
||||||
# submit to get the ground-truth, for fast FL experiments,
|
# submit to get the ground-truth, for fast FL experiments,
|
||||||
# we split the validation set into two parts with the same size as
|
# we split the validation set into two parts with the same size as
|
||||||
|
|
@ -467,22 +475,22 @@ def load_external_data(config=None):
|
||||||
dataset['validation'][1])]
|
dataset['validation'][1])]
|
||||||
data_split_dict["val"], data_split_dict[
|
data_split_dict["val"], data_split_dict[
|
||||||
"test"] = original_val[:len(original_val) //
|
"test"] = original_val[:len(original_val) //
|
||||||
2], original_val[len(original_val) //
|
2], original_val[len(original_val) //
|
||||||
2:]
|
2:]
|
||||||
if "val_as_dummy_test" in raw_args and raw_args["val_as_dummy_test"]:
|
if "val_as_dummy_test" in raw_args and raw_args["val_as_dummy_test"]:
|
||||||
# use the validation set as tmp test set,
|
# use the validation set as tmp test set,
|
||||||
# and partial training set as validation set
|
# and partial training set as validation set
|
||||||
data_split_dict["test"] = data_split_dict["val"]
|
data_split_dict["test"] = data_split_dict["val"]
|
||||||
data_split_dict["val"] = []
|
data_split_dict["val"] = []
|
||||||
if "part_train_dummy_val" in raw_args and 1 > raw_args[
|
if "part_train_dummy_val" in raw_args and 1 > raw_args[
|
||||||
"part_train_dummy_val"] > 0:
|
"part_train_dummy_val"] > 0:
|
||||||
new_val_part = int(original_train_size *
|
new_val_part = int(original_train_size *
|
||||||
raw_args["part_train_dummy_val"])
|
raw_args["part_train_dummy_val"])
|
||||||
data_split_dict["val"].extend(
|
data_split_dict["val"].extend(
|
||||||
data_split_dict["train"][:new_val_part])
|
data_split_dict["train"][:new_val_part])
|
||||||
data_split_dict["train"] = data_split_dict["train"][new_val_part:]
|
data_split_dict["train"] = data_split_dict["train"][new_val_part:]
|
||||||
if "part_train_dummy_test" in raw_args and 1 > raw_args[
|
if "part_train_dummy_test" in raw_args and 1 > raw_args[
|
||||||
"part_train_dummy_test"] > 0:
|
"part_train_dummy_test"] > 0:
|
||||||
new_test_part = int(original_train_size *
|
new_test_part = int(original_train_size *
|
||||||
raw_args["part_train_dummy_test"])
|
raw_args["part_train_dummy_test"])
|
||||||
data_split_dict["test"] = data_split_dict["val"]
|
data_split_dict["test"] = data_split_dict["val"]
|
||||||
|
|
|
||||||
|
|
@ -140,8 +140,11 @@ class MetricCalculator(object):
|
||||||
if torch is not None and isinstance(y_prob, torch.Tensor):
|
if torch is not None and isinstance(y_prob, torch.Tensor):
|
||||||
y_prob = y_prob.detach().cpu().numpy()
|
y_prob = y_prob.detach().cpu().numpy()
|
||||||
|
|
||||||
|
# Add traffic flow metrics
|
||||||
if 'regression' in ctx.cfg.model.task.lower():
|
if 'regression' in ctx.cfg.model.task.lower():
|
||||||
y_pred = None
|
y_pred = None
|
||||||
|
elif 'trafficflowprediction' in ctx.cfg.model.task.lower():
|
||||||
|
y_pred = None
|
||||||
else:
|
else:
|
||||||
# classification task
|
# classification task
|
||||||
if y_true.ndim == 1:
|
if y_true.ndim == 1:
|
||||||
|
|
|
||||||
|
|
@ -25,6 +25,9 @@ from federatedscope.core.monitors.monitor import Monitor
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Modifications:
|
||||||
|
# 1. Fixed some errors in the flop_counter.
|
||||||
|
# (2024-10-06, Heng-Zhang)
|
||||||
|
|
||||||
class GeneralTorchTrainer(Trainer):
|
class GeneralTorchTrainer(Trainer):
|
||||||
def get_model_para(self):
|
def get_model_para(self):
|
||||||
|
|
@ -161,8 +164,7 @@ class GeneralTorchTrainer(Trainer):
|
||||||
================================== ===========================
|
================================== ===========================
|
||||||
Attribute Operation
|
Attribute Operation
|
||||||
================================== ===========================
|
================================== ===========================
|
||||||
``ctx.model`` Wrap ``nn.Module` to \
|
``ctx.model`` Wrap ``nn.Module` to `nn.DataParallel`
|
||||||
`nn.DataParallel`
|
|
||||||
================================== ===========================
|
================================== ===========================
|
||||||
"""
|
"""
|
||||||
if isinstance(ctx.model, torch.nn.DataParallel):
|
if isinstance(ctx.model, torch.nn.DataParallel):
|
||||||
|
|
@ -325,7 +327,10 @@ class GeneralTorchTrainer(Trainer):
|
||||||
try:
|
try:
|
||||||
x, y = [_.to(ctx.device) for _ in ctx.data_batch]
|
x, y = [_.to(ctx.device) for _ in ctx.data_batch]
|
||||||
from fvcore.nn import FlopCountAnalysis
|
from fvcore.nn import FlopCountAnalysis
|
||||||
flops_one_batch = FlopCountAnalysis(ctx.model, x).total()
|
# Something wrong!!
|
||||||
|
flop_counter = FlopCountAnalysis(ctx.model, x)
|
||||||
|
flop_counter.unsupported_ops_warnings(False)
|
||||||
|
flops_one_batch = flop_counter.total()
|
||||||
if self.model_nums > 1 and ctx.mirrored_models:
|
if self.model_nums > 1 and ctx.mirrored_models:
|
||||||
flops_one_batch *= self.model_nums
|
flops_one_batch *= self.model_nums
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
|
|
||||||
|
|
@ -10,6 +10,10 @@ from federatedscope.core.trainers.torch_trainer import GeneralTorchTrainer
|
||||||
from federatedscope.core.trainers.trainer_multi_model import \
|
from federatedscope.core.trainers.trainer_multi_model import \
|
||||||
GeneralMultiModelTrainer
|
GeneralMultiModelTrainer
|
||||||
|
|
||||||
|
# Modifications:
|
||||||
|
# 1. Fix issue where the trainer cannot access the monitor. Line 31
|
||||||
|
# (2024-10-6, czzhangheng)
|
||||||
|
|
||||||
|
|
||||||
class FedEMTrainer(GeneralMultiModelTrainer):
|
class FedEMTrainer(GeneralMultiModelTrainer):
|
||||||
"""
|
"""
|
||||||
|
|
@ -25,10 +29,11 @@ class FedEMTrainer(GeneralMultiModelTrainer):
|
||||||
data=None,
|
data=None,
|
||||||
device=None,
|
device=None,
|
||||||
config=None,
|
config=None,
|
||||||
|
monitor=None,
|
||||||
base_trainer: Type[GeneralTorchTrainer] = None):
|
base_trainer: Type[GeneralTorchTrainer] = None):
|
||||||
super(FedEMTrainer,
|
super(FedEMTrainer,
|
||||||
self).__init__(model_nums, models_interact_mode, model, data,
|
self).__init__(model_nums, models_interact_mode, model, data,
|
||||||
device, config, base_trainer)
|
device, config, monitor=monitor)
|
||||||
device = self.ctx.device
|
device = self.ctx.device
|
||||||
|
|
||||||
# --------------- attribute-level modifications ----------------------
|
# --------------- attribute-level modifications ----------------------
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,10 @@ from federatedscope.core.trainers.torch_trainer import GeneralTorchTrainer
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
# Modifications:
|
||||||
|
# 1. Fix issue where the trainer cannot access the monitor. Line 68
|
||||||
|
# 2. Fix issue where deepcopy cannot copy items Line 77
|
||||||
|
# (2024-10-6, czzhangheng)
|
||||||
|
|
||||||
class GeneralMultiModelTrainer(GeneralTorchTrainer):
|
class GeneralMultiModelTrainer(GeneralTorchTrainer):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
|
|
@ -16,6 +20,7 @@ class GeneralMultiModelTrainer(GeneralTorchTrainer):
|
||||||
data=None,
|
data=None,
|
||||||
device=None,
|
device=None,
|
||||||
config=None,
|
config=None,
|
||||||
|
monitor=None,
|
||||||
base_trainer: Type[GeneralTorchTrainer] = None):
|
base_trainer: Type[GeneralTorchTrainer] = None):
|
||||||
"""
|
"""
|
||||||
`GeneralMultiModelTrainer` supports train/eval via multiple
|
`GeneralMultiModelTrainer` supports train/eval via multiple
|
||||||
|
|
@ -65,7 +70,7 @@ class GeneralMultiModelTrainer(GeneralTorchTrainer):
|
||||||
"data, device, config) should not be " \
|
"data, device, config) should not be " \
|
||||||
"None"
|
"None"
|
||||||
super(GeneralMultiModelTrainer,
|
super(GeneralMultiModelTrainer,
|
||||||
self).__init__(model, data, device, config)
|
self).__init__(model, data, device, config, monitor=monitor)
|
||||||
else:
|
else:
|
||||||
assert isinstance(base_trainer, GeneralMultiModelTrainer) or \
|
assert isinstance(base_trainer, GeneralMultiModelTrainer) or \
|
||||||
issubclass(type(base_trainer), GeneralMultiModelTrainer) \
|
issubclass(type(base_trainer), GeneralMultiModelTrainer) \
|
||||||
|
|
@ -74,7 +79,13 @@ class GeneralMultiModelTrainer(GeneralTorchTrainer):
|
||||||
"can only copy instances of `GeneralMultiModelTrainer` " \
|
"can only copy instances of `GeneralMultiModelTrainer` " \
|
||||||
"and its subclasses, or " \
|
"and its subclasses, or " \
|
||||||
"`GeneralTorchTrainer` and its subclasses"
|
"`GeneralTorchTrainer` and its subclasses"
|
||||||
self.__dict__ = copy.deepcopy(base_trainer.__dict__)
|
# self.__dict__ = copy.deepcopy(base_trainer.__dict__)
|
||||||
|
# 逐个复制 base_trainer 的属性,跳过不可拷贝的对象
|
||||||
|
for key, value in base_trainer.__dict__.items():
|
||||||
|
try:
|
||||||
|
self.__dict__[key] = copy.deepcopy(value)
|
||||||
|
except TypeError:
|
||||||
|
self.__dict__[key] = value # 如果不能 deepcopy,使用浅拷贝
|
||||||
|
|
||||||
assert models_interact_mode in ["sequential", "parallel"], \
|
assert models_interact_mode in ["sequential", "parallel"], \
|
||||||
f"Invalid models_interact_mode, should be `sequential` or " \
|
f"Invalid models_interact_mode, should be `sequential` or " \
|
||||||
|
|
|
||||||
|
|
@ -14,6 +14,9 @@ import numpy as np
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Modifications:
|
||||||
|
# 1. Fix issue where iteritems() is deprecated, items() instead. Line 160
|
||||||
|
# (2024-10-8, czzhangheng)
|
||||||
|
|
||||||
class VMFDataset:
|
class VMFDataset:
|
||||||
"""Dataset of matrix factorization task in vertical federated learning.
|
"""Dataset of matrix factorization task in vertical federated learning.
|
||||||
|
|
@ -157,8 +160,8 @@ class MovieLensData(object):
|
||||||
}, {mid: idx
|
}, {mid: idx
|
||||||
for idx, mid in enumerate(unique_id_user)}
|
for idx, mid in enumerate(unique_id_user)}
|
||||||
|
|
||||||
row = [mapping_user[mid] for _, mid in data["userId"].iteritems()]
|
row = [mapping_user[mid] for _, mid in data["userId"].items()]
|
||||||
col = [mapping_item[mid] for _, mid in data["movieId"].iteritems()]
|
col = [mapping_item[mid] for _, mid in data["movieId"].items()]
|
||||||
|
|
||||||
ratings = coo_matrix((data["rating"], (row, col)),
|
ratings = coo_matrix((data["rating"], (row, col)),
|
||||||
shape=(n_user, n_item))
|
shape=(n_user, n_item))
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,11 @@ import torch
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Modifications:
|
||||||
|
#
|
||||||
|
# 1. Fix issue where embed_user.shape is deprecated, embed_user.weight.shape instead. Line 91
|
||||||
|
# It may be casued by torch.Embedding update.
|
||||||
|
# (2024-9-8, czzhangheng)
|
||||||
|
|
||||||
def wrap_MFTrainer(base_trainer: Type[MFTrainer]) -> Type[MFTrainer]:
|
def wrap_MFTrainer(base_trainer: Type[MFTrainer]) -> Type[MFTrainer]:
|
||||||
"""Build `SGDMFTrainer` with a plug-in manner, by registering new
|
"""Build `SGDMFTrainer` with a plug-in manner, by registering new
|
||||||
|
|
@ -74,26 +79,39 @@ def hook_on_batch_backward(ctx):
|
||||||
ctx.optimizer.zero_grad()
|
ctx.optimizer.zero_grad()
|
||||||
ctx.loss_task.backward()
|
ctx.loss_task.backward()
|
||||||
|
|
||||||
|
if ctx.model.embed_user.weight.grad.is_sparse:
|
||||||
|
dense_user_grad = ctx.model.embed_user.weight.grad.to_dense()
|
||||||
|
else:
|
||||||
|
dense_user_grad = ctx.model.embed_user.weight.grad
|
||||||
|
|
||||||
|
if ctx.model.embed_item.weight.grad.is_sparse:
|
||||||
|
dense_item_grad = ctx.model.embed_item.weight.grad.to_dense()
|
||||||
|
else:
|
||||||
|
dense_item_grad = ctx.model.embed_item.weight.grad
|
||||||
|
|
||||||
# Inject noise
|
# Inject noise
|
||||||
ctx.model.embed_user.grad.data += get_random(
|
dense_user_grad.data += get_random(
|
||||||
"Normal",
|
"Normal",
|
||||||
sample_shape=ctx.model.embed_user.shape,
|
sample_shape=ctx.model.embed_user.weight.shape,
|
||||||
params={
|
params={
|
||||||
"loc": 0,
|
"loc": 0,
|
||||||
"scale": ctx.scale
|
"scale": ctx.scale
|
||||||
},
|
},
|
||||||
device=ctx.model.embed_user.device)
|
device=ctx.model.embed_user.weight.device)
|
||||||
ctx.model.embed_item.grad.data += get_random(
|
dense_item_grad.data += get_random(
|
||||||
"Normal",
|
"Normal",
|
||||||
sample_shape=ctx.model.embed_item.shape,
|
sample_shape=ctx.model.embed_item.weight.shape,
|
||||||
params={
|
params={
|
||||||
"loc": 0,
|
"loc": 0,
|
||||||
"scale": ctx.scale
|
"scale": ctx.scale
|
||||||
},
|
},
|
||||||
device=ctx.model.embed_item.device)
|
device=ctx.model.embed_item.weight.device)
|
||||||
|
|
||||||
|
ctx.model.embed_user.weight.grad = dense_user_grad.to_sparse()
|
||||||
|
ctx.model.embed_item.weight.grad = dense_item_grad.to_sparse()
|
||||||
ctx.optimizer.step()
|
ctx.optimizer.step()
|
||||||
|
|
||||||
# Embedding clipping
|
# Embedding clipping
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
embedding_clip(ctx.model.embed_user, ctx.sgdmf_R)
|
embedding_clip(ctx.model.embed_user.weight, ctx.sgdmf_R)
|
||||||
embedding_clip(ctx.model.embed_item, ctx.sgdmf_R)
|
embedding_clip(ctx.model.embed_item.weight, ctx.sgdmf_R)
|
||||||
|
|
|
||||||
8
setup.py
8
setup.py
|
|
@ -2,13 +2,17 @@ from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
import setuptools
|
import setuptools
|
||||||
|
|
||||||
|
# Modifications:
|
||||||
|
# 1. Change the package requirement for suit latest package
|
||||||
|
# (2024-10-10, czzhangheng)
|
||||||
|
|
||||||
__name__ = 'federatedscope'
|
__name__ = 'federatedscope'
|
||||||
__version__ = '0.3.0'
|
__version__ = '0.3.0'
|
||||||
URL = 'https://github.com/alibaba/FederatedScope'
|
URL = 'https://github.com/alibaba/FederatedScope'
|
||||||
|
|
||||||
minimal_requires = [
|
minimal_requires = [
|
||||||
'numpy<1.23.0', 'scikit-learn==1.0.2', 'scipy==1.7.3', 'pandas',
|
'numpy', 'scikit-learn', 'scipy', 'pandas',
|
||||||
'grpcio>=1.45.0', 'grpcio-tools', 'pyyaml>=5.1', 'fvcore', 'iopath',
|
'grpcio', 'grpcio-tools', 'pyyaml>=5.1', 'fvcore', 'iopath',
|
||||||
'wandb', 'tensorboard', 'tensorboardX', 'pympler', 'protobuf==3.19.4',
|
'wandb', 'tensorboard', 'tensorboardX', 'pympler', 'protobuf==3.19.4',
|
||||||
'matplotlib'
|
'matplotlib'
|
||||||
]
|
]
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue