FS-TFP/federatedscope/cv/model/cnn.py

192 lines
5.3 KiB
Python

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Module
from torch.nn import Sequential
from torch.nn import Conv2d, BatchNorm2d
from torch.nn import Flatten
from torch.nn import Linear
from torch.nn import MaxPool2d
from torch.nn import ReLU
class ConvNet2(Module):
    """Two-stage convolutional classifier.

    Each stage is a 5x5 convolution (padding 2, so height/width are kept),
    an optional ``BatchNorm2d``, a ReLU and a 2x2 max pooling. The flattened
    feature map is then fed to a two-layer fully-connected head.

    Args:
        in_channels: number of channels of the input images.
        h: input height; only used to size ``fc1``.
        w: input width; only used to size ``fc1``.
        hidden: width of the hidden fully-connected layer.
        class_num: number of output logits.
        use_bn: when True, ``bn1``/``bn2`` are created and applied after
            each convolution; when False they do not exist at all.
        dropout: dropout probability applied before ``fc1`` and before
            ``fc2`` (only active in training mode).
    """

    def __init__(self,
                 in_channels,
                 h=32,
                 w=32,
                 hidden=2048,
                 class_num=10,
                 use_bn=True,
                 dropout=.0):
        super(ConvNet2, self).__init__()
        # Submodule registration order is intentionally unchanged: it fixes
        # the sequence of random initialisations and the parameter order.
        self.conv1 = Conv2d(in_channels, 32, 5, padding=2)
        self.conv2 = Conv2d(32, 64, 5, padding=2)
        self.use_bn = use_bn
        if use_bn:
            self.bn1 = BatchNorm2d(32)
            self.bn2 = BatchNorm2d(64)
        # Two 2x2 poolings floor-divide each spatial side by 4.
        flat_features = (h // 4) * (w // 4) * 64
        self.fc1 = Linear(flat_features, hidden)
        self.fc2 = Linear(hidden, class_num)
        self.relu = ReLU(inplace=True)
        self.maxpool = MaxPool2d(2)
        self.dropout = dropout

    def forward(self, x):
        """Return class logits for the image batch ``x``."""
        # Stage 1: conv -> [bn] -> relu -> pool
        x = self.conv1(x)
        if self.use_bn:
            x = self.bn1(x)
        x = self.maxpool(self.relu(x))

        # Stage 2: conv -> [bn] -> relu -> pool
        x = self.conv2(x)
        if self.use_bn:
            x = self.bn2(x)
        x = self.maxpool(self.relu(x))

        # Classifier head; same result as ``Flatten()(x)`` (start_dim=1).
        x = x.flatten(start_dim=1)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.relu(self.fc1(x))
        x = F.dropout(x, p=self.dropout, training=self.training)
        return self.fc2(x)
class ConvNet5(Module):
    """Five-stage convolutional classifier.

    Every stage applies a 5x5 convolution (padding 2), ``BatchNorm2d``,
    ReLU and 2x2 max pooling; channel widths go 32, 64, 64, 128, 128.
    A two-layer fully-connected head produces the logits.

    Args:
        in_channels: number of channels of the input images.
        h: input height; only used to size ``fc1``.
        w: input width; only used to size ``fc1``.
        hidden: width of the hidden fully-connected layer.
        class_num: number of output logits.
        dropout: dropout probability applied before ``fc1`` and before
            ``fc2`` (only active in training mode).
    """

    def __init__(self,
                 in_channels,
                 h=32,
                 w=32,
                 hidden=2048,
                 class_num=10,
                 dropout=.0):
        super(ConvNet5, self).__init__()
        # Same registration order as before (conv_i, bn_i pairs first), so
        # random initialisation and parameter ordering are unchanged.
        self.conv1 = Conv2d(in_channels, 32, 5, padding=2)
        self.bn1 = BatchNorm2d(32)
        self.conv2 = Conv2d(32, 64, 5, padding=2)
        self.bn2 = BatchNorm2d(64)
        self.conv3 = Conv2d(64, 64, 5, padding=2)
        self.bn3 = BatchNorm2d(64)
        self.conv4 = Conv2d(64, 128, 5, padding=2)
        self.bn4 = BatchNorm2d(128)
        self.conv5 = Conv2d(128, 128, 5, padding=2)
        self.bn5 = BatchNorm2d(128)
        self.relu = ReLU(inplace=True)
        self.maxpool = MaxPool2d(2)
        # Five 2x2 poolings floor-divide each spatial side by 2**5 = 32.
        flat_features = (h // 32) * (w // 32) * 128
        self.fc1 = Linear(flat_features, hidden)
        self.fc2 = Linear(hidden, class_num)
        self.dropout = dropout

    def forward(self, x):
        """Return class logits for the image batch ``x``."""
        feature_blocks = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
            (self.conv4, self.bn4),
            (self.conv5, self.bn5),
        )
        for conv, norm in feature_blocks:
            x = self.maxpool(self.relu(norm(conv(x))))

        # Same result as ``Flatten()(x)`` (start_dim=1).
        x = x.flatten(start_dim=1)
        x = F.dropout(x, p=self.dropout, training=self.training)
        hidden = self.relu(self.fc1(x))
        hidden = F.dropout(hidden, p=self.dropout, training=self.training)
        return self.fc2(hidden)
class VGG11(Module):
    """VGG-11 style classifier with batch normalisation.

    Eight 3x3 convolutions (padding 1) with widths
    64, 128, 256, 256, 512, 512, 512, 512, each followed by ``BatchNorm2d``
    and ReLU. A 2x2 max pooling follows conv1, conv2, conv4, conv6 and
    conv8, for five poolings in total (a 32x spatial down-sampling). Three
    fully-connected layers produce the logits.

    Args:
        in_channels: number of channels of the input images.
        h: input height; only used to size ``fc1``. Should be >= 32 so that
            the five poolings leave at least a 1x1 map.
        w: input width; same constraint as ``h``.
        hidden: width of the two hidden fully-connected layers.
        class_num: number of output logits.
        dropout: dropout probability applied before each fully-connected
            layer (only active in training mode).
    """

    def __init__(self,
                 in_channels,
                 h=32,
                 w=32,
                 hidden=128,
                 class_num=10,
                 dropout=.0):
        super(VGG11, self).__init__()
        self.conv1 = Conv2d(in_channels, 64, 3, padding=1)
        self.bn1 = BatchNorm2d(64)
        self.conv2 = Conv2d(64, 128, 3, padding=1)
        self.bn2 = BatchNorm2d(128)
        self.conv3 = Conv2d(128, 256, 3, padding=1)
        self.bn3 = BatchNorm2d(256)
        self.conv4 = Conv2d(256, 256, 3, padding=1)
        self.bn4 = BatchNorm2d(256)
        self.conv5 = Conv2d(256, 512, 3, padding=1)
        self.bn5 = BatchNorm2d(512)
        self.conv6 = Conv2d(512, 512, 3, padding=1)
        self.bn6 = BatchNorm2d(512)
        self.conv7 = Conv2d(512, 512, 3, padding=1)
        self.bn7 = BatchNorm2d(512)
        self.conv8 = Conv2d(512, 512, 3, padding=1)
        self.bn8 = BatchNorm2d(512)
        self.relu = ReLU(inplace=True)
        self.maxpool = MaxPool2d(2)
        # Sized for exactly five 2x2 poolings in ``forward`` (2**5 = 32).
        self.fc1 = Linear(
            (h // 2 // 2 // 2 // 2 // 2) * (w // 2 // 2 // 2 // 2 // 2) * 512,
            hidden)
        self.fc2 = Linear(hidden, hidden)
        self.fc3 = Linear(hidden, class_num)
        self.dropout = dropout

    def forward(self, x):
        """Return class logits for the image batch ``x``."""
        # Pooling only after conv1, conv2, conv4, conv6 and conv8 (VGG-11
        # layout). Pooling after all eight convolutions, as before, shrinks a
        # 32x32 input to 1x1 after five poolings and the sixth MaxPool2d
        # then fails; it also disagrees with how ``fc1`` is sized.
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.maxpool(x)

        x = self.relu(self.bn2(self.conv2(x)))
        x = self.maxpool(x)

        x = self.relu(self.bn3(self.conv3(x)))
        x = self.relu(self.bn4(self.conv4(x)))
        x = self.maxpool(x)

        x = self.relu(self.bn5(self.conv5(x)))
        x = self.relu(self.bn6(self.conv6(x)))
        x = self.maxpool(x)

        x = self.relu(self.bn7(self.conv7(x)))
        x = self.relu(self.bn8(self.conv8(x)))
        x = self.maxpool(x)

        # Equivalent to ``Flatten()(x)`` without building a module per call.
        x = x.flatten(1)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.relu(self.fc1(x))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.relu(self.fc2(x))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.fc3(x)
        return x