541 lines
15 KiB
Plaintext
541 lines
15 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"pycharm": {
|
|
"name": "#%% md\n"
|
|
}
|
|
},
|
|
"source": [
|
|
"# Start with your own case \n",
|
|
"\n",
|
|
"In addition to the rich collcetion of **datasets**, **models** and **evaluation metrics**, **FederatedScope** also allows to create your own or introduce more to our package.\n",
|
|
"\n",
|
|
"We provide `register` function to help build your own federated learning workflow. This introduction will help you to start with your own case:\n",
|
|
"\n",
|
|
"1. [Load a dataset](#data)\n",
|
|
"2. [Build a model](#model) \n",
|
|
"3. [Create a trainer](#trainer)\n",
|
|
"4. [Introduce more evaluation metrics](#metric)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"pycharm": {
|
|
"name": "#%% md\n"
|
|
}
|
|
},
|
|
"source": [
|
|
"## <span id=\"data\">1. Load a dataset</span>\n",
|
|
"\n",
|
|
"We provide a function federatedscope.register.register_data to make your dataset available with three steps:\n",
|
|
"\n",
|
|
"* Step1: set up your data in the following format (standalone):\n",
|
|
" \n",
|
|
" **Note**: This function returns a `dict`, where the `key` is the client's id, and the `value` is the data `dict` of each client with 'train', 'test' or 'val'. You can also modify the config here.\n",
|
|
"\n",
|
|
" We take `torchvision.datasets.MNIST`, which is split and assigned to two clients, as an example:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2022-03-31T10:13:12.308497Z",
|
|
"start_time": "2022-03-31T10:13:12.302160Z"
|
|
},
|
|
"pycharm": {
|
|
"name": "#%%\n"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"def load_my_data(config):\n",
|
|
" import numpy as np\n",
|
|
" from torchvision import transforms\n",
|
|
" from torchvision.datasets import MNIST\n",
|
|
" from torch.utils.data import DataLoader\n",
|
|
"\n",
|
|
" # Build data\n",
|
|
" transform = transforms.Compose([\n",
|
|
" transforms.ToTensor(),\n",
|
|
" transforms.Normalize(mean=[0.9637], std=[0.1592])\n",
|
|
" ])\n",
|
|
" data_train = MNIST(root=config.data.root, train=True, transform=transform, download=True)\n",
|
|
" data_test = MNIST(root=config.data.root, train=False, transform=transform, download=True)\n",
|
|
"\n",
|
|
" # Split data into dict\n",
|
|
" data_dict = dict()\n",
|
|
" train_per_client = len(data_train) // config.federate.client_num\n",
|
|
" test_per_client = len(data_test) // config.federate.client_num\n",
|
|
"\n",
|
|
" for client_idx in range(1, config.federate.client_num + 1):\n",
|
|
" dataloader_dict = {\n",
|
|
" 'train':\n",
|
|
" DataLoader([\n",
|
|
" data_train[i]\n",
|
|
" for i in range((client_idx - 1) *\n",
|
|
" train_per_client, client_idx * train_per_client)\n",
|
|
" ],\n",
|
|
" config.data.batch_size,\n",
|
|
" shuffle=config.data.shuffle),\n",
|
|
" 'test':\n",
|
|
" DataLoader([\n",
|
|
" data_test[i]\n",
|
|
" for i in range((client_idx - 1) * test_per_client, client_idx *\n",
|
|
" test_per_client)\n",
|
|
" ],\n",
|
|
" config.data.batch_size,\n",
|
|
" shuffle=False)\n",
|
|
" }\n",
|
|
" data_dict[client_idx] = dataloader_dict\n",
|
|
"\n",
|
|
" return data_dict, config"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"pycharm": {
|
|
"name": "#%% md\n"
|
|
}
|
|
},
|
|
"source": [
|
|
"* Step2: register your data with a keyword, such as `\"mydata\"`"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2022-03-31T10:13:12.313727Z",
|
|
"start_time": "2022-03-31T10:13:12.309767Z"
|
|
},
|
|
"pycharm": {
|
|
"name": "#%%\n"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"from federatedscope.register import register_data\n",
|
|
"\n",
|
|
"def call_my_data(config, client_cfgs=None):\n",
|
|
" if config.data.type == \"mycvdata\":\n",
|
|
" data, modified_config = load_my_data(config)\n",
|
|
" return data, modified_config\n",
|
|
"\n",
|
|
"register_data(\"mycvdata\", call_my_data)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2022-03-31T09:29:07.854271Z",
|
|
"start_time": "2022-03-31T09:29:07.851771Z"
|
|
},
|
|
"pycharm": {
|
|
"name": "#%% md\n"
|
|
}
|
|
},
|
|
"source": [
|
|
"## <span id=\"model\">2. Build a model</span>\n",
|
|
"We provide a function `federatedscope.register.register_model` to make your model available with three steps: (we take `ConvNet2` as an example)\n",
|
|
"\n",
|
|
"* Step1: build your model with Pytorch or Tensorflow and instantiate your model class with config and data."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2022-03-31T10:13:12.787164Z",
|
|
"start_time": "2022-03-31T10:13:12.315611Z"
|
|
},
|
|
"pycharm": {
|
|
"name": "#%%\n"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"import torch\n",
|
|
"\n",
|
|
"\n",
|
|
"class MyNet(torch.nn.Module):\n",
|
|
" def __init__(self,\n",
|
|
" in_channels,\n",
|
|
" h=32,\n",
|
|
" w=32,\n",
|
|
" hidden=2048,\n",
|
|
" class_num=10,\n",
|
|
" use_bn=True):\n",
|
|
" super(MyNet, self).__init__()\n",
|
|
" self.conv1 = torch.nn.Conv2d(in_channels, 32, 5, padding=2)\n",
|
|
" self.conv2 = torch.nn.Conv2d(32, 64, 5, padding=2)\n",
|
|
" self.fc1 = torch.nn.Linear((h // 2 // 2) * (w // 2 // 2) * 64, hidden)\n",
|
|
" self.fc2 = torch.nn.Linear(hidden, class_num)\n",
|
|
" self.relu = torch.nn.ReLU(inplace=True)\n",
|
|
" self.maxpool = torch.nn.MaxPool2d(2)\n",
|
|
"\n",
|
|
" def forward(self, x):\n",
|
|
" x = self.conv1(x)\n",
|
|
" x = self.maxpool(self.relu(x))\n",
|
|
" x = self.conv2(x)\n",
|
|
" x = self.maxpool(self.relu(x))\n",
|
|
" x = torch.nn.Flatten()(x)\n",
|
|
" x = self.relu(self.fc1(x))\n",
|
|
" x = self.fc2(x)\n",
|
|
" return x\n",
|
|
"\n",
|
|
"\n",
|
|
"def load_my_net(model_config, data_shape):\n",
|
|
" # You can also build models without local_data\n",
|
|
" model = MyNet(in_channels=data_shape[1],\n",
|
|
" h=data_shape[2],\n",
|
|
" w=data_shape[3],\n",
|
|
" hidden=model_config.hidden,\n",
|
|
" class_num=model_config.out_channels)\n",
|
|
" return model"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"pycharm": {
|
|
"name": "#%% md\n"
|
|
}
|
|
},
|
|
"source": [
|
|
"* Step2: register your model with a keyword, such as `\"mynet\"`"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2022-03-31T10:13:12.791526Z",
|
|
"start_time": "2022-03-31T10:13:12.788549Z"
|
|
},
|
|
"pycharm": {
|
|
"name": "#%%\n"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"from federatedscope.register import register_model\n",
|
|
"\n",
|
|
"def call_my_net(model_config, data_shape):\n",
|
|
" if model_config.type == \"mycnn\":\n",
|
|
" model = load_my_net(model_config, data_shape)\n",
|
|
" return model\n",
|
|
"\n",
|
|
"register_model(\"mycnn\", call_my_net)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2022-03-31T09:29:10.271414Z",
|
|
"start_time": "2022-03-31T09:29:10.269302Z"
|
|
},
|
|
"pycharm": {
|
|
"name": "#%% md\n"
|
|
}
|
|
},
|
|
"source": [
|
|
"## <span id=\"trainer\">3. Create a trainer</span>\n",
|
|
"\n",
|
|
"FederatedScope decouples the local learning process and details of FL communication and schedule, allowing users to freely customize the local learning algorithms via the `Trainer`. We recommend user build trainer by inheriting `federatedscope.core.trainers.trainer.GeneralTorchTrainer`, for more details, please see [Trainer](https://federatedscope.io/docs/trainer/). Similarly, we provide `federatedscope.register.register_trainer` to make your customized trainer available:\n",
|
|
"\n",
|
|
"* Step1: build your trainer by inheriting `GeneralTrainer`. Our `GeneralTrainer` already supports many different usages, for the advanced user, please see [federatedscope.core.trainers.trainer.GeneralTrainer]() for more details."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2022-03-31T10:13:13.181147Z",
|
|
"start_time": "2022-03-31T10:13:12.792631Z"
|
|
},
|
|
"pycharm": {
|
|
"name": "#%%\n"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"from federatedscope.core.trainers import GeneralTorchTrainer\n",
|
|
"\n",
|
|
"class MyTrainer(GeneralTorchTrainer):\n",
|
|
" pass"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"pycharm": {
|
|
"name": "#%% md\n"
|
|
}
|
|
},
|
|
"source": [
|
|
"* Step2: register your trainer with a keyword, such as `\"mytrainer\"`"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2022-03-31T10:13:13.185648Z",
|
|
"start_time": "2022-03-31T10:13:13.182604Z"
|
|
},
|
|
"pycharm": {
|
|
"name": "#%%\n"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"from federatedscope.register import register_trainer\n",
|
|
"\n",
|
|
"def call_my_trainer(trainer_type):\n",
|
|
" if trainer_type == 'mycvtrainer':\n",
|
|
" trainer_builder = MyTrainer\n",
|
|
" return trainer_builder\n",
|
|
"\n",
|
|
"register_trainer('mycvtrainer', call_my_trainer)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"pycharm": {
|
|
"name": "#%% md\n"
|
|
}
|
|
},
|
|
"source": [
|
|
"## <span id=\"metric\">4. Introduce more evaluation metrics</span>\n",
|
|
"We provide a number of metrics to monitor the entire federal learning process. You just need to list the name of the metric you want in `cfg.eval.metrics`. We currently support metrics such as loss, accuracy, etc. (See [federatedscope.core.evaluator](federatedscope/core/evaluator.py) for more details).\n",
|
|
"\n",
|
|
"We also provide a function `federatedscope.register.register_metric` to make your evaluation metrics available with three steps:\n",
|
|
"\n",
|
|
"* Step1: build your metric (see [federatedscope.core.context](federatedscope/core/context.py) for more about `ctx`)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2022-03-31T10:13:13.189195Z",
|
|
"start_time": "2022-03-31T10:13:13.187033Z"
|
|
},
|
|
"pycharm": {
|
|
"name": "#%%\n"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"def cal_my_metric(ctx, **kwargs):\n",
|
|
" return ctx[\"num_train_data\"]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"pycharm": {
|
|
"name": "#%% md\n"
|
|
}
|
|
},
|
|
"source": [
|
|
"* Step2: register your metric with a keyword, such as `\"mymetric\"`"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2022-03-31T10:13:13.193453Z",
|
|
"start_time": "2022-03-31T10:13:13.190519Z"
|
|
},
|
|
"pycharm": {
|
|
"name": "#%%\n"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"from federatedscope.register import register_metric\n",
|
|
"\n",
|
|
"def call_my_metric(types):\n",
|
|
" if \"mymetric\" in types:\n",
|
|
" metric_builder = cal_my_metric\n",
|
|
" return \"mymetric\", metric_builder\n",
|
|
"\n",
|
|
"register_metric(\"mymetric\", call_my_metric)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"pycharm": {
|
|
"name": "#%% md\n"
|
|
}
|
|
},
|
|
"source": [
|
|
"## Let's start!\n",
|
|
"* Set your data, model, trainer and metric first."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2022-03-31T10:13:13.219711Z",
|
|
"start_time": "2022-03-31T10:13:13.195532Z"
|
|
},
|
|
"pycharm": {
|
|
"name": "#%%\n"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"from federatedscope.core.configs.config import global_cfg\n",
|
|
"\n",
|
|
"cfg = global_cfg.clone()\n",
|
|
"\n",
|
|
"cfg.data.type = 'mycvdata'\n",
|
|
"cfg.data.root = 'data'\n",
|
|
"cfg.data.transform = [['ToTensor'], ['Normalize', {'mean': [0.1307], 'std': [0.3081]}]]\n",
|
|
"cfg.model.type = 'mycnn'\n",
|
|
"cfg.model.out_channels = 10\n",
|
|
"cfg.trainer.type = 'mycvtrainer'\n",
|
|
"cfg.eval.metric = ['mymetric']"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"pycharm": {
|
|
"name": "#%% md\n"
|
|
}
|
|
},
|
|
"source": [
|
|
"* Configure other options in `cfg`."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2022-03-31T10:13:13.225301Z",
|
|
"start_time": "2022-03-31T10:13:13.221148Z"
|
|
},
|
|
"pycharm": {
|
|
"name": "#%%\n"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"cfg.use_gpu = False\n",
|
|
"cfg.best_res_update_round_wise_key = \"test_loss\"\n",
|
|
"\n",
|
|
"cfg.federate.mode = 'standalone'\n",
|
|
"cfg.federate.local_update_steps = 5\n",
|
|
"cfg.federate.total_round_num = 20\n",
|
|
"cfg.federate.sample_client_num = 5\n",
|
|
"cfg.federate.client_num = 5\n",
|
|
"\n",
|
|
"cfg.train.optimizer.lr = 0.001\n",
|
|
"cfg.train.optimizer.weight_decay = 0.0\n",
|
|
"cfg.grad.grad_clip = 5.0\n",
|
|
"\n",
|
|
"cfg.criterion.type = 'CrossEntropyLoss'\n",
|
|
"cfg.seed = 123\n",
|
|
"cfg.eval.best_res_update_round_wise_key = \"test_loss\""
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"pycharm": {
|
|
"name": "#%% md\n"
|
|
}
|
|
},
|
|
"source": [
|
|
"* Start your FL process!"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"start_time": "2022-03-31T10:13:12.142Z"
|
|
},
|
|
"pycharm": {
|
|
"name": "#%%\n"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"from federatedscope.core.auxiliaries.data_builder import get_data\n",
|
|
"from federatedscope.core.auxiliaries.utils import setup_seed\n",
|
|
"from federatedscope.core.auxiliaries.logging import update_logger\n",
|
|
"from federatedscope.core.fed_runner import FedRunner\n",
|
|
"from federatedscope.core.auxiliaries.worker_builder import get_server_cls, get_client_cls\n",
|
|
"\n",
|
|
"setup_seed(cfg.seed)\n",
|
|
"update_logger(cfg)\n",
|
|
"data, modified_cfg = get_data(cfg)\n",
|
|
"cfg.merge_from_other_cfg(modified_cfg)\n",
|
|
"Fed_runner = FedRunner(data=data,\n",
|
|
" server_class=get_server_cls(cfg),\n",
|
|
" client_class=get_client_cls(cfg),\n",
|
|
" config=cfg.clone())\n",
|
|
"Fed_runner.run()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"pycharm": {
|
|
"name": "#%%\n"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.9.13"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 4
|
|
}
|