import torch
import torch.nn as nn
import torch.nn.functional as F

from federatedscope.contrib.model.resnet import BasicBlock, Bottleneck


# Model class
class ResNet(nn.Module):
    # CIFAR-style ResNet. When num_classes <= 0, the classifier head is
    # omitted and forward() returns the pooled features instead of logits.
    def __init__(self, block, num_blocks, num_classes=10, cfg=None):
        super(ResNet, self).__init__()
        self.train_sup = (num_classes > 0)
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64, affine=True)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.output_dim = 512 * block.expansion
        if self.train_sup:
            self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        # The first block of each stage applies the downsampling stride;
        # the remaining blocks keep stride 1.
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.adaptive_avg_pool2d(out, (1, 1))
        out = out.view(out.size(0), -1)
        if self.train_sup:
            out = self.linear(out)
        return out


class ResNet_basic(nn.Module):
    # Slim CIFAR ResNet (16/32/64 channels over three stages).
    def __init__(self, block, num_blocks, num_classes=10, cfg=None):
        super(ResNet_basic, self).__init__()
        self.train_sup = (num_classes > 0)

        self.in_planes = 16
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(16, affine=True)
        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
        # The final stage has 64 channels, so the pooled feature width is
        # 64 * block.expansion, matching the linear head below.
        self.output_dim = 64 * block.expansion
        if self.train_sup:
            self.linear = nn.Linear(64 * block.expansion,
                                    num_classes,
                                    bias=True)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = F.adaptive_avg_pool2d(out, (1, 1))
        out = out.view(out.size(0), -1)
        if self.train_sup:
            out = self.linear(out)
        return out


def get_block(block):
    if block == "BasicBlock":
        return BasicBlock
    elif block == "Bottleneck":
        return Bottleneck
    raise ValueError("Unknown block type: {}".format(block))


def ResNet18(num_classes=10, block="BasicBlock"):
    return ResNet(get_block(block), [2, 2, 2, 2], num_classes=num_classes)


def ResNet34(num_classes=10, block="BasicBlock"):
    return ResNet(get_block(block), [3, 4, 6, 3], num_classes=num_classes)


def create_backbone(name, num_classes=10, block='BasicBlock'):
    if name == 'res18':
        net = ResNet18(num_classes=num_classes, block=block)
    elif name == 'res34':
        net = ResNet34(num_classes=num_classes, block=block)
    else:
        raise ValueError("Unknown backbone: {}".format(name))
    return net


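# A minimal sanity-check sketch (not from the original file): with
# num_classes=0 the backbone skips the classifier and returns pooled
# features of width `output_dim`. The helper below is hypothetical and is
# never called at import time.
def _demo_backbone_shapes():
    net = create_backbone('res18', num_classes=0)
    feats = net(torch.randn(2, 3, 32, 32))      # CIFAR-sized input
    assert feats.shape == (2, net.output_dim)   # torch.Size([2, 512])

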
# Projector
class projection_MLP_simclr(nn.Module):
    # Two-layer SimCLR projection head; the final BatchNorm is non-affine,
    # so the output is normalized but not rescaled.
    def __init__(self, in_dim, hidden_dim=512, out_dim=512):
        super(projection_MLP_simclr, self).__init__()
        self.layer1 = nn.Linear(in_dim, hidden_dim, bias=False)
        self.layer1_bn = nn.BatchNorm1d(hidden_dim, affine=True)
        self.layer2 = nn.Linear(hidden_dim, out_dim)
        self.layer2_bn = nn.BatchNorm1d(out_dim, affine=False)

    def forward(self, x):
        x = F.relu(self.layer1_bn(self.layer1(x)))
        x = self.layer2_bn(self.layer2(x))
        return x


# SimCLR
class simclr(nn.Module):
    def __init__(self, bbone_arch):
        super(simclr, self).__init__()
        # Buffer so the round counter travels with the state_dict.
        self.register_buffer("rounds_done", torch.zeros(1))

        self.backbone = create_backbone(bbone_arch, num_classes=0)
        self.projector = projection_MLP_simclr(self.backbone.output_dim,
                                               hidden_dim=512,
                                               out_dim=512)

    def forward(self, x1, x2, x3=None, deg_labels=None):
        # x3 and deg_labels are unused here; they keep the signature
        # compatible with trainers that pass extra inputs.
        z1 = self.projector(self.backbone(x1))
        z2 = self.projector(self.backbone(x2))
        return z1, z2


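# A minimal sketch (not part of the original file) of the NT-Xent/InfoNCE
# objective that SimCLR projections z1, z2 are typically trained with;
# `nt_xent` is a hypothetical helper that assumes z1[i] and z2[i] come from
# the same image. The actual training loss would live in the trainer, not
# in this model file.
def nt_xent(z1, z2, temperature=0.5):
    n = z1.size(0)
    z = F.normalize(torch.cat([z1, z2], dim=0), dim=1)
    sim = z @ z.t() / temperature        # pairwise cosine similarities
    sim.fill_diagonal_(float('-inf'))    # mask self-pairs
    # Row i's positive is its other view: i + n for the first half,
    # i - n for the second half.
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(0, n)])
    return F.cross_entropy(sim, targets.to(sim.device))

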
class simclr_linearprob(nn.Module):
    # Linear probe: frozen SimCLR backbone with a trainable linear head.
    def __init__(self, bbone_arch, num_classes=10):
        super(simclr_linearprob, self).__init__()
        self.register_buffer("rounds_done", torch.zeros(1))

        self.backbone = create_backbone(bbone_arch, num_classes=0)
        self.linear = nn.Linear(self.backbone.output_dim,
                                num_classes,
                                bias=True)

    def forward(self, x):
        # no_grad keeps the backbone frozen; only the linear head
        # receives gradients during probing.
        with torch.no_grad():
            out = self.backbone(x)
        out = self.linear(out)
        return out


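# Hedged usage note: linear probing typically starts from a pretrained
# SimCLR backbone. One plain-PyTorch way (assuming a trained `simclr`
# instance named `pretrained`) would be:
#   probe = simclr_linearprob(bbone_arch='res18', num_classes=10)
#   probe.backbone.load_state_dict(pretrained.backbone.state_dict())

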
class simclr_supervised(nn.Module):
    # Supervised baseline: the same backbone plus a linear classifier,
    # trained end to end.
    def __init__(self, bbone_arch, num_classes=10):
        super(simclr_supervised, self).__init__()
        self.register_buffer("rounds_done", torch.zeros(1))

        self.backbone = create_backbone(bbone_arch, num_classes=0)
        self.linear = nn.Linear(self.backbone.output_dim,
                                num_classes,
                                bias=True)

    def forward(self, x):
        out = self.backbone(x)
        out = self.linear(out)
        return out


def ModelBuilder(model_config, local_data):
    # You can also build models without local_data
    if model_config.type == "SimCLR":
        return simclr(bbone_arch='res18')
    if model_config.type == "SimCLR_linear":
        return simclr_linearprob(bbone_arch='res18', num_classes=10)
    if model_config.type in ["supervised_local", "supervised_fedavg"]:
        return simclr_supervised(bbone_arch='res18', num_classes=10)
    # Unmatched types return None so other registered builders can try.
    return None


from federatedscope.register import register_model


def get_simclr(model_config, local_data):
    return ModelBuilder(model_config, local_data)


register_model("SimCLR", get_simclr)
register_model("SimCLR_linear", get_simclr)
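
# With these registrations, a FederatedScope config setting `model.type` to
# "SimCLR" or "SimCLR_linear" is routed through get_simclr. Note that the
# "supervised_local"/"supervised_fedavg" branches of ModelBuilder are not
# registered here, so they presumably rely on registration elsewhere.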