{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"# Start with your own case \n",
"\n",
"In addition to the rich collcetion of **datasets**, **models** and **evaluation metrics**, **FederatedScope** also allows to create your own or introduce more to our package.\n",
"\n",
"We provide `register` function to help build your own federated learning workflow. This introduction will help you to start with your own case:\n",
"\n",
"1. [Load a dataset](#data)\n",
"2. [Build a model](#model) \n",
"3. [Create a trainer](#trainer)\n",
"4. [Introduce more evaluation metrics](#metric)"
]
},
{
"cell_type": "markdown",
"metadata": {
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"## <span id=\"data\">1. Load a dataset</span>\n",
"\n",
"We provide a function federatedscope.register.register_data to make your dataset available with three steps:\n",
"\n",
"* Step1: set up your data in the following format (standalone):\n",
" \n",
" **Note**: This function returns a `dict`, where the `key` is the client's id, and the `value` is the data `dict` of each client with 'train', 'test' or 'val'. You can also modify the config here.\n",
"\n",
" We take `torchvision.datasets.MNIST`, which is split and assigned to two clients, as an example:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2022-03-31T10:13:12.308497Z",
"start_time": "2022-03-31T10:13:12.302160Z"
},
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"def load_my_data(config):\n",
" import numpy as np\n",
" from torchvision import transforms\n",
" from torchvision.datasets import MNIST\n",
" from torch.utils.data import DataLoader\n",
"\n",
" # Build data\n",
" transform = transforms.Compose([\n",
" transforms.ToTensor(),\n",
" transforms.Normalize(mean=[0.9637], std=[0.1592])\n",
" ])\n",
" data_train = MNIST(root=config.data.root, train=True, transform=transform, download=True)\n",
" data_test = MNIST(root=config.data.root, train=False, transform=transform, download=True)\n",
"\n",
" # Split data into dict\n",
" data_dict = dict()\n",
" train_per_client = len(data_train) // config.federate.client_num\n",
" test_per_client = len(data_test) // config.federate.client_num\n",
"\n",
" for client_idx in range(1, config.federate.client_num + 1):\n",
" dataloader_dict = {\n",
" 'train':\n",
" DataLoader([\n",
" data_train[i]\n",
" for i in range((client_idx - 1) *\n",
" train_per_client, client_idx * train_per_client)\n",
" ],\n",
" config.data.batch_size,\n",
" shuffle=config.data.shuffle),\n",
" 'test':\n",
" DataLoader([\n",
" data_test[i]\n",
" for i in range((client_idx - 1) * test_per_client, client_idx *\n",
" test_per_client)\n",
" ],\n",
" config.data.batch_size,\n",
" shuffle=False)\n",
" }\n",
" data_dict[client_idx] = dataloader_dict\n",
"\n",
" return data_dict, config"
]
},
{
"cell_type": "markdown",
"metadata": {
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"* Step2: register your data with a keyword, such as `\"mydata\"`"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2022-03-31T10:13:12.313727Z",
"start_time": "2022-03-31T10:13:12.309767Z"
},
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"from federatedscope.register import register_data\n",
"\n",
"def call_my_data(config, client_cfgs=None):\n",
" if config.data.type == \"mycvdata\":\n",
" data, modified_config = load_my_data(config)\n",
" return data, modified_config\n",
"\n",
"register_data(\"mycvdata\", call_my_data)"
]
},
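{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the format returned in Step1 concrete, the cell below is a minimal sketch of the same contiguous, equal-sized split, written without any dataset or FederatedScope dependency. The helper `split_indices` is hypothetical (it is not part of the package). It assumes, as `load_my_data` does, that client ids start at 1 and that samples left over by the integer division are assigned to no client."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def split_indices(num_samples, client_num):\n",
"    # Hypothetical helper for illustration only, not a FederatedScope API\n",
"    per_client = num_samples // client_num\n",
"    return {\n",
"        client_idx: range((client_idx - 1) * per_client,\n",
"                          client_idx * per_client)\n",
"        for client_idx in range(1, client_num + 1)\n",
"    }\n",
"\n",
"\n",
"# The MNIST training set has 60000 samples; split them over 5 clients\n",
"mnist_train_split = split_indices(60000, 5)\n",
"print({cid: (r.start, r.stop) for cid, r in mnist_train_split.items()})"
]
},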
{
"cell_type": "markdown",
"metadata": {
"ExecuteTime": {
"end_time": "2022-03-31T09:29:07.854271Z",
"start_time": "2022-03-31T09:29:07.851771Z"
},
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"## <span id=\"model\">2. Build a model</span>\n",
"We provide a function `federatedscope.register.register_model` to make your model available with three steps: (we take `ConvNet2` as an example)\n",
"\n",
"* Step1: build your model with Pytorch or Tensorflow and instantiate your model class with config and data."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2022-03-31T10:13:12.787164Z",
"start_time": "2022-03-31T10:13:12.315611Z"
},
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"import torch\n",
"\n",
"\n",
"class MyNet(torch.nn.Module):\n",
" def __init__(self,\n",
" in_channels,\n",
" h=32,\n",
" w=32,\n",
" hidden=2048,\n",
" class_num=10,\n",
" use_bn=True):\n",
" super(MyNet, self).__init__()\n",
" self.conv1 = torch.nn.Conv2d(in_channels, 32, 5, padding=2)\n",
" self.conv2 = torch.nn.Conv2d(32, 64, 5, padding=2)\n",
" self.fc1 = torch.nn.Linear((h // 2 // 2) * (w // 2 // 2) * 64, hidden)\n",
" self.fc2 = torch.nn.Linear(hidden, class_num)\n",
" self.relu = torch.nn.ReLU(inplace=True)\n",
" self.maxpool = torch.nn.MaxPool2d(2)\n",
"\n",
" def forward(self, x):\n",
" x = self.conv1(x)\n",
" x = self.maxpool(self.relu(x))\n",
" x = self.conv2(x)\n",
" x = self.maxpool(self.relu(x))\n",
" x = torch.nn.Flatten()(x)\n",
" x = self.relu(self.fc1(x))\n",
" x = self.fc2(x)\n",
" return x\n",
"\n",
"\n",
"def load_my_net(model_config, data_shape):\n",
" # You can also build models without local_data\n",
" model = MyNet(in_channels=data_shape[1],\n",
" h=data_shape[2],\n",
" w=data_shape[3],\n",
" hidden=model_config.hidden,\n",
" class_num=model_config.out_channels)\n",
" return model"
]
},
{
"cell_type": "markdown",
"metadata": {
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"* Step2: register your model with a keyword, such as `\"mynet\"`"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2022-03-31T10:13:12.791526Z",
"start_time": "2022-03-31T10:13:12.788549Z"
},
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"from federatedscope.register import register_model\n",
"\n",
"def call_my_net(model_config, data_shape):\n",
" if model_config.type == \"mycnn\":\n",
" model = load_my_net(model_config, data_shape)\n",
" return model\n",
"\n",
"register_model(\"mycnn\", call_my_net)"
]
},
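{
"cell_type": "markdown",
"metadata": {},
"source": [
"The input size of `fc1` in `MyNet` follows from the layer shapes: each 5x5 convolution with `padding=2` keeps the height and width unchanged, and each `MaxPool2d(2)` halves them, rounding down. The cell below is a small sketch that recomputes this size in plain Python. The helper `flattened_features` is hypothetical and only mirrors the arithmetic in `MyNet.__init__`; the 28x28 input is the MNIST image size."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def flattened_features(h, w, out_channels=64, num_pools=2):\n",
"    # Hypothetical helper for illustration only, not a FederatedScope API\n",
"    for _ in range(num_pools):\n",
"        # A 2x2 max-pool halves each spatial dimension (floor division)\n",
"        h, w = h // 2, w // 2\n",
"    return h * w * out_channels\n",
"\n",
"\n",
"print(flattened_features(28, 28))  # 7 * 7 * 64 = 3136"
]
},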
{
"cell_type": "markdown",
"metadata": {
"ExecuteTime": {
"end_time": "2022-03-31T09:29:10.271414Z",
"start_time": "2022-03-31T09:29:10.269302Z"
},
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"## <span id=\"trainer\">3. Create a trainer</span>\n",
"\n",
"FederatedScope decouples the local learning process and details of FL communication and schedule, allowing users to freely customize the local learning algorithms via the `Trainer`. We recommend user build trainer by inheriting `federatedscope.core.trainers.trainer.GeneralTorchTrainer`, for more details, please see [Trainer](https://federatedscope.io/docs/trainer/). Similarly, we provide `federatedscope.register.register_trainer` to make your customized trainer available:\n",
"\n",
"* Step1: build your trainer by inheriting `GeneralTrainer`. Our `GeneralTrainer` already supports many different usages, for the advanced user, please see [federatedscope.core.trainers.trainer.GeneralTrainer]() for more details."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2022-03-31T10:13:13.181147Z",
"start_time": "2022-03-31T10:13:12.792631Z"
},
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"from federatedscope.core.trainers import GeneralTorchTrainer\n",
"\n",
"class MyTrainer(GeneralTorchTrainer):\n",
" pass"
]
},
{
"cell_type": "markdown",
"metadata": {
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"* Step2: register your trainer with a keyword, such as `\"mytrainer\"`"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2022-03-31T10:13:13.185648Z",
"start_time": "2022-03-31T10:13:13.182604Z"
},
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"from federatedscope.register import register_trainer\n",
"\n",
"def call_my_trainer(trainer_type):\n",
" if trainer_type == 'mycvtrainer':\n",
" trainer_builder = MyTrainer\n",
" return trainer_builder\n",
"\n",
"register_trainer('mycvtrainer', call_my_trainer)"
]
},
{
"cell_type": "markdown",
"metadata": {
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"## <span id=\"metric\">4. Introduce more evaluation metrics</span>\n",
"We provide a number of metrics to monitor the entire federal learning process. You just need to list the name of the metric you want in `cfg.eval.metrics`. We currently support metrics such as loss, accuracy, etc. (See [federatedscope.core.evaluator](federatedscope/core/evaluator.py) for more details).\n",
"\n",
"We also provide a function `federatedscope.register.register_metric` to make your evaluation metrics available with three steps:\n",
"\n",
"* Step1: build your metric (see [federatedscope.core.context](federatedscope/core/context.py) for more about `ctx`)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2022-03-31T10:13:13.189195Z",
"start_time": "2022-03-31T10:13:13.187033Z"
},
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"def cal_my_metric(ctx, **kwargs):\n",
" return ctx[\"num_train_data\"]"
]
},
{
"cell_type": "markdown",
"metadata": {
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"* Step2: register your metric with a keyword, such as `\"mymetric\"`"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2022-03-31T10:13:13.193453Z",
"start_time": "2022-03-31T10:13:13.190519Z"
},
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"from federatedscope.register import register_metric\n",
"\n",
"def call_my_metric(types):\n",
" if \"mymetric\" in types:\n",
" metric_builder = cal_my_metric\n",
" return \"mymetric\", metric_builder\n",
"\n",
"register_metric(\"mymetric\", call_my_metric)"
]
},
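{
"cell_type": "markdown",
"metadata": {},
"source": [
"All four `call_my_*` functions above share one pattern: check the requested keyword, return a result only when it matches, and otherwise fall through and return `None`. The cell below sketches why that pattern is convenient, using a toy registry. The names `toy_registry`, `toy_register` and `toy_build` are hypothetical and only illustrate the idea; they are not how FederatedScope implements registration internally."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Toy registry for illustration only, not a FederatedScope API\n",
"toy_registry = {}\n",
"\n",
"\n",
"def toy_register(key, builder):\n",
"    toy_registry[key] = builder\n",
"\n",
"\n",
"def toy_build(requested_type):\n",
"    # Ask every registered builder; the first one that does not return None wins\n",
"    for builder in toy_registry.values():\n",
"        result = builder(requested_type)\n",
"        if result is not None:\n",
"            return result\n",
"    raise ValueError(f'No builder registered for {requested_type!r}')\n",
"\n",
"\n",
"def call_toy_metric(requested_type):\n",
"    # Returns None implicitly when the keyword does not match\n",
"    if requested_type == 'toymetric':\n",
"        return lambda values: sum(values) / len(values)\n",
"\n",
"\n",
"toy_register('toymetric', call_toy_metric)\n",
"print(toy_build('toymetric')([1, 2, 3]))  # 2.0"
]
},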
{
"cell_type": "markdown",
"metadata": {
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"## Let's start!\n",
"* Set your data, model, trainer and metric first."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2022-03-31T10:13:13.219711Z",
"start_time": "2022-03-31T10:13:13.195532Z"
},
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"from federatedscope.core.configs.config import global_cfg\n",
"\n",
"cfg = global_cfg.clone()\n",
"\n",
"cfg.data.type = 'mycvdata'\n",
"cfg.data.root = 'data'\n",
"cfg.data.transform = [['ToTensor'], ['Normalize', {'mean': [0.1307], 'std': [0.3081]}]]\n",
"cfg.model.type = 'mycnn'\n",
"cfg.model.out_channels = 10\n",
"cfg.trainer.type = 'mycvtrainer'\n",
"cfg.eval.metric = ['mymetric']"
]
},
{
"cell_type": "markdown",
"metadata": {
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"* Configure other options in `cfg`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2022-03-31T10:13:13.225301Z",
"start_time": "2022-03-31T10:13:13.221148Z"
},
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"cfg.use_gpu = False\n",
"cfg.best_res_update_round_wise_key = \"test_loss\"\n",
"\n",
"cfg.federate.mode = 'standalone'\n",
"cfg.federate.local_update_steps = 5\n",
"cfg.federate.total_round_num = 20\n",
"cfg.federate.sample_client_num = 5\n",
"cfg.federate.client_num = 5\n",
"\n",
"cfg.train.optimizer.lr = 0.001\n",
"cfg.train.optimizer.weight_decay = 0.0\n",
"cfg.grad.grad_clip = 5.0\n",
"\n",
"cfg.criterion.type = 'CrossEntropyLoss'\n",
"cfg.seed = 123\n",
"cfg.eval.best_res_update_round_wise_key = \"test_loss\""
]
},
{
"cell_type": "markdown",
"metadata": {
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"* Start your FL process!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"start_time": "2022-03-31T10:13:12.142Z"
},
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"from federatedscope.core.auxiliaries.data_builder import get_data\n",
"from federatedscope.core.auxiliaries.utils import setup_seed\n",
"from federatedscope.core.auxiliaries.logging import update_logger\n",
"from federatedscope.core.fed_runner import FedRunner\n",
"from federatedscope.core.auxiliaries.worker_builder import get_server_cls, get_client_cls\n",
"\n",
"setup_seed(cfg.seed)\n",
"update_logger(cfg)\n",
"data, modified_cfg = get_data(cfg)\n",
"cfg.merge_from_other_cfg(modified_cfg)\n",
"Fed_runner = FedRunner(data=data,\n",
" server_class=get_server_cls(cfg),\n",
" client_class=get_client_cls(cfg),\n",
" config=cfg.clone())\n",
"Fed_runner.run()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
}
},
"nbformat": 4,
"nbformat_minor": 4
}