{ "cells": [ { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "# Start with your own case\n", "\n", "In addition to the rich collection of **datasets**, **models** and **evaluation metrics** it provides, **FederatedScope** also lets you create your own components and plug them into the package.\n", "\n", "We provide a set of `register` functions to help you build your own federated learning workflow. This tutorial walks you through your own case:\n", "\n", "1. [Load a dataset](#data)\n", "2. [Build a model](#model)\n", "3. [Create a trainer](#trainer)\n", "4. [Introduce more evaluation metrics](#metric)" ] },
{ "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "## 1. Load a dataset\n", "\n", "We provide the function `federatedscope.register.register_data` to make your dataset available in two steps:\n", "\n", "* Step1: set up your data in the following format (standalone mode):\n", "\n", "  **Note**: Your loading function returns a tuple `(data_dict, config)`. `data_dict` is a `dict` whose `key` is a client's id (starting from 1) and whose `value` is that client's data `dict` with the keys `'train'`, `'test'` and, optionally, `'val'`. You can also modify the config here and return it.\n", "\n", "  We take `torchvision.datasets.MNIST` as an example and split it evenly across `config.federate.client_num` clients:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2022-03-31T10:13:12.308497Z", "start_time": "2022-03-31T10:13:12.302160Z" }, "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "def load_my_data(config):\n", "    from torchvision import transforms\n", "    from torchvision.datasets import MNIST\n", "    from torch.utils.data import DataLoader, Subset\n", "\n", "    # Build data; 0.1307 and 0.3081 are the commonly used MNIST mean and std\n", "    transform = transforms.Compose([\n", "        transforms.ToTensor(),\n", "        transforms.Normalize(mean=[0.1307], std=[0.3081])\n", "    ])\n", "    data_train = MNIST(root=config.data.root, train=True, transform=transform, download=True)\n", "    data_test = MNIST(root=config.data.root, train=False, transform=transform, download=True)\n", "\n", "    # Split the data evenly and contiguously across clients; client ids start from 1\n", "    client_num = config.federate.client_num\n", "    train_per_client = len(data_train) // client_num\n", "    test_per_client = len(data_test) // client_num\n", "\n", "    data_dict = dict()\n", "    for client_idx in range(1, client_num + 1):\n", "        train_idx = range((client_idx - 1) * train_per_client, client_idx * train_per_client)\n", "        test_idx = range((client_idx - 1) * test_per_client, client_idx * test_per_client)\n", "        # Subset indexes lazily, so samples are transformed when loaded rather than all up front\n", "        data_dict[client_idx] = {\n", "            'train': DataLoader(Subset(data_train, train_idx),\n", "                                config.data.batch_size,\n", "                                shuffle=config.data.shuffle),\n", "            'test': DataLoader(Subset(data_test, test_idx),\n", "                               config.data.batch_size,\n", "                               shuffle=False)\n", "        }\n", "\n", "    return data_dict, config" ] },
{ "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "* Step2: register your data with a keyword, such as `\"mycvdata\"`" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2022-03-31T10:13:12.313727Z", "start_time": "2022-03-31T10:13:12.309767Z" }, "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "from federatedscope.register import register_data\n", "\n", "\n", "def call_my_data(config, client_cfgs=None):\n", "    # Build the data only when this keyword is requested; otherwise return None implicitly\n", "    if config.data.type == \"mycvdata\":\n", "        data, modified_config = load_my_data(config)\n", "        return data, modified_config\n", "\n", "\n", "register_data(\"mycvdata\", call_my_data)" ] },
{ "cell_type": "markdown", "metadata": { "ExecuteTime": { "end_time": "2022-03-31T09:29:07.854271Z", "start_time": "2022-03-31T09:29:07.851771Z" }, "pycharm": { "name": "#%% md\n" } }, "source": [ "## 2. Build a model\n", "\n", "We provide the function `federatedscope.register.register_model` to make your model available in two steps (we take a `ConvNet2`-style network as an example):\n", "\n", "* Step1: build your model with PyTorch or TensorFlow, and write a function that instantiates your model class from the model config and the shape of the input data." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2022-03-31T10:13:12.787164Z", "start_time": "2022-03-31T10:13:12.315611Z" }, "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "import torch\n", "\n", "\n", "class MyNet(torch.nn.Module):\n", "    def __init__(self,\n", "                 in_channels,\n", "                 h=32,\n", "                 w=32,\n", "                 hidden=2048,\n", "                 class_num=10):\n", "        super().__init__()\n", "        # padding=2 keeps the spatial size of each 5x5 conv; each max-pool halves it\n", "        self.conv1 = torch.nn.Conv2d(in_channels, 32, 5, padding=2)\n", "        self.conv2 = torch.nn.Conv2d(32, 64, 5, padding=2)\n", "        self.fc1 = torch.nn.Linear((h // 2 // 2) * (w // 2 // 2) * 64, hidden)\n", "        self.fc2 = torch.nn.Linear(hidden, class_num)\n", "        self.relu = torch.nn.ReLU(inplace=True)\n", "        self.maxpool = torch.nn.MaxPool2d(2)\n", "\n", "    def forward(self, x):\n", "        x = self.maxpool(self.relu(self.conv1(x)))\n", "        x = self.maxpool(self.relu(self.conv2(x)))\n", "        x = torch.flatten(x, start_dim=1)\n", "        x = self.relu(self.fc1(x))\n", "        return self.fc2(x)\n", "\n", "\n", "def load_my_net(model_config, data_shape):\n", "    # data_shape is expected to be the shape of one input batch: (batch_size, channels, height, width)\n", "    model = MyNet(in_channels=data_shape[1],\n", "                  h=data_shape[2],\n", "                  w=data_shape[3],\n", "                  hidden=model_config.hidden,\n", "                  class_num=model_config.out_channels)\n", "    return model" ] },
{ "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "* Step2: register your model with a keyword, such as `\"mycnn\"`" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2022-03-31T10:13:12.791526Z", "start_time": "2022-03-31T10:13:12.788549Z" }, "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "from federatedscope.register import register_model\n", "\n", "\n", "def call_my_net(model_config, data_shape):\n", "    # Build the model only when this keyword is requested; otherwise return None implicitly\n", "    if model_config.type == \"mycnn\":\n", "        model = load_my_net(model_config, data_shape)\n", "        return model\n", "\n", "\n", "register_model(\"mycnn\", call_my_net)" ] },
{ "cell_type": "markdown", "metadata": { "ExecuteTime": { "end_time": "2022-03-31T09:29:10.271414Z", "start_time": "2022-03-31T09:29:10.269302Z" }, "pycharm": { "name": "#%% md\n" } }, "source": [ "## 3. Create a trainer\n", "\n", "FederatedScope decouples the local learning process from the details of FL communication and scheduling, so you can freely customize the local learning algorithm via the `Trainer`. We recommend building your trainer by inheriting from `federatedscope.core.trainers.GeneralTorchTrainer`; for more details, please see [Trainer](https://federatedscope.io/docs/trainer/). Similarly, we provide `federatedscope.register.register_trainer` to make your customized trainer available in two steps:\n", "\n", "* Step1: build your trainer by inheriting from `GeneralTorchTrainer`. `GeneralTorchTrainer` already supports many different usages; advanced users can find more details in the `federatedscope.core.trainers` module."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2022-03-31T10:13:13.181147Z", "start_time": "2022-03-31T10:13:12.792631Z" }, "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "from federatedscope.core.trainers import GeneralTorchTrainer\n", "\n", "\n", "class MyTrainer(GeneralTorchTrainer):\n", "    # Inherits the default training and evaluation routines; override methods here to customize them\n", "    pass" ] },
{ "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "* Step2: register your trainer with a keyword, such as `\"mycvtrainer\"`" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2022-03-31T10:13:13.185648Z", "start_time": "2022-03-31T10:13:13.182604Z" }, "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "from federatedscope.register import register_trainer\n", "\n", "\n", "def call_my_trainer(trainer_type):\n", "    # Return the trainer class (not an instance) when this keyword is requested\n", "    if trainer_type == 'mycvtrainer':\n", "        trainer_builder = MyTrainer\n", "        return trainer_builder\n", "\n", "\n", "register_trainer('mycvtrainer', call_my_trainer)" ] },
{ "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "## 4. Introduce more evaluation metrics\n", "\n", "We provide a number of metrics to monitor the entire federated learning process. You just need to list the names of the metrics you want in `cfg.eval.metrics`. We currently support metrics such as loss and accuracy (see [federatedscope.core.evaluator](federatedscope/core/evaluator.py) for more details).\n", "\n", "We also provide the function `federatedscope.register.register_metric` to make your own evaluation metrics available in two steps:\n", "\n", "* Step1: build your metric as a function of the trainer context `ctx` (see [federatedscope.core.context](federatedscope/core/context.py) for more about `ctx`). The example below simply reports the number of training samples." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2022-03-31T10:13:13.189195Z", "start_time": "2022-03-31T10:13:13.187033Z" }, "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "def cal_my_metric(ctx, **kwargs):\n", "    # Report the number of training samples recorded in the trainer context\n", "    return ctx[\"num_train_data\"]" ] },
{ "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "* Step2: register your metric with a keyword, such as `\"mymetric\"`" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2022-03-31T10:13:13.193453Z", "start_time": "2022-03-31T10:13:13.190519Z" }, "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "from federatedscope.register import register_metric\n", "\n", "\n", "def call_my_metric(types):\n", "    # types is expected to hold the metric names listed in cfg.eval.metrics\n", "    if \"mymetric\" in types:\n", "        metric_builder = cal_my_metric\n", "        return \"mymetric\", metric_builder\n", "\n", "\n", "register_metric(\"mymetric\", call_my_metric)" ] },
{ "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "## Let's start!\n", "* Set your data, model, trainer and metric first."
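, "\n", "\n", "The keywords set below (`cfg.data.type`, `cfg.model.type`, `cfg.trainer.type` and `cfg.eval.metrics`) select the components registered above. We assume that each FederatedScope builder calls the registered functions in turn and uses the first result that is not `None`, which is why every `call_my_*` function checks its keyword before building anything. The optional cell below is a toy, pure-Python sketch of this dispatch pattern under that assumption; `DATA_REGISTRY`, `toy_register_data`, `toy_get_data` and `toy_call_my_data` are hypothetical names and are not part of FederatedScope." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Toy sketch of keyword-based dispatch (hypothetical names; NOT FederatedScope's implementation)\n", "from typing import Any, Callable, Dict, Optional\n", "\n", "DATA_REGISTRY: Dict[str, Callable[[str], Optional[Any]]] = {}\n", "\n", "\n", "def toy_register_data(key: str, func: Callable[[str], Optional[Any]]) -> None:\n", "    # Store the builder under its keyword, analogous to register_data('mycvdata', ...)\n", "    DATA_REGISTRY[key] = func\n", "\n", "\n", "def toy_get_data(data_type: str) -> Any:\n", "    # Try every registered builder; the first non-None result wins\n", "    for func in DATA_REGISTRY.values():\n", "        result = func(data_type)\n", "        if result is not None:\n", "            return result\n", "    raise ValueError(f'No registered builder handles data type {data_type!r}')\n", "\n", "\n", "def toy_call_my_data(data_type: str) -> Optional[str]:\n", "    # Mirrors call_my_data: build only for a matching keyword, otherwise return None\n", "    if data_type == 'mycvdata':\n", "        return 'my data'\n", "    return None\n", "\n", "\n", "toy_register_data('mycvdata', toy_call_my_data)\n", "print(toy_get_data('mycvdata'))  # prints: my data"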
] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2022-03-31T10:13:13.219711Z", "start_time": "2022-03-31T10:13:13.195532Z" }, "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "from federatedscope.core.configs.config import global_cfg\n", "\n", "cfg = global_cfg.clone()\n", "\n", "# Select the components registered above by their keywords\n", "cfg.data.type = 'mycvdata'\n", "cfg.data.root = 'data'\n", "# load_my_data builds its own transform, so this setting does not affect it\n", "cfg.data.transform = [['ToTensor'], ['Normalize', {'mean': [0.1307], 'std': [0.3081]}]]\n", "cfg.model.type = 'mycnn'\n", "cfg.model.out_channels = 10\n", "cfg.trainer.type = 'mycvtrainer'\n", "cfg.eval.metrics = ['mymetric']" ] },
{ "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "* Configure other options in `cfg`." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2022-03-31T10:13:13.225301Z", "start_time": "2022-03-31T10:13:13.221148Z" }, "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "cfg.use_gpu = False\n", "\n", "cfg.federate.mode = 'standalone'\n", "cfg.federate.total_round_num = 20\n", "cfg.federate.sample_client_num = 5\n", "cfg.federate.client_num = 5\n", "\n", "cfg.train.local_update_steps = 5\n", "cfg.train.optimizer.lr = 0.001\n", "cfg.train.optimizer.weight_decay = 0.0\n", "cfg.grad.grad_clip = 5.0\n", "\n", "cfg.criterion.type = 'CrossEntropyLoss'\n", "cfg.seed = 123\n", "cfg.eval.best_res_update_round_wise_key = \"test_loss\"" ] },
{ "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "* Start your FL process!"
] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "start_time": "2022-03-31T10:13:12.142Z" }, "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "from federatedscope.core.auxiliaries.data_builder import get_data\n", "from federatedscope.core.auxiliaries.utils import setup_seed\n", "from federatedscope.core.auxiliaries.logging import update_logger\n", "from federatedscope.core.fed_runner import FedRunner\n", "from federatedscope.core.auxiliaries.worker_builder import get_server_cls, get_client_cls\n", "\n", "setup_seed(cfg.seed)\n", "update_logger(cfg)\n", "\n", "# Build the data with the registered 'mycvdata' loader and merge any config changes it returns\n", "data, modified_cfg = get_data(cfg)\n", "cfg.merge_from_other_cfg(modified_cfg)\n", "\n", "runner = FedRunner(data=data,\n", "                   server_class=get_server_cls(cfg),\n", "                   client_class=get_client_cls(cfg),\n", "                   config=cfg.clone())\n", "runner.run()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.13" } }, "nbformat": 4, "nbformat_minor": 4 }