FS-TFP/federatedscope/attack/models/gan_based_model.py

75 lines
2.2 KiB
Python

import torch
import torch.nn as nn
from copy import deepcopy
class GeneratorFemnist(nn.Module):
    '''
    The generator for Femnist dataset.

    Maps a batch of noise vectors to single-channel 28x28 images through
    one fully connected stage followed by three transposed-convolution
    stages. The last stage ends in ``Tanh``, so every output value lies
    in ``[-1, 1]``.

    Arguments:
        noise_dim (int): number of input features of the first linear
            layer, i.e. the size of each noise vector.
    '''
    def __init__(self, noise_dim=100):
        super(GeneratorFemnist, self).__init__()

        # Every learnable layer below is created with bias=False and is
        # immediately followed by a batch-normalization layer.

        # Stage 1: noise vector -> flat vector of 4 * 4 * 256 features.
        # Its output is reshaped to (256, 4, 4) in forward().
        self.body1 = nn.Sequential(
            nn.Linear(in_features=noise_dim,
                      out_features=4 * 4 * 256,
                      bias=False),
            nn.BatchNorm1d(num_features=4 * 4 * 256),
            nn.ReLU(),
        )

        # Stage 2: (256, 4, 4) -> (128, 6, 6).
        self.body2 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=256,
                               out_channels=128,
                               kernel_size=(3, 3),
                               stride=(1, 1),
                               bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(),
        )

        # Stage 3: (128, 6, 6) -> (64, 13, 13).
        self.body3 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=128,
                               out_channels=64,
                               kernel_size=(3, 3),
                               stride=(2, 2),
                               bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )

        # Stage 4: (64, 13, 13) -> (1, 28, 28), squashed by Tanh.
        self.body4 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=64,
                               out_channels=1,
                               kernel_size=(4, 4),
                               stride=(2, 2),
                               bias=False),
            nn.BatchNorm2d(1),
            nn.Tanh(),
        )

    def forward(self, x):
        '''
        Generate images from noise.

        Arguments:
            x (torch.Tensor): noise fed to the first linear layer; its
                last dimension must equal ``noise_dim``. Presumably of
                shape ``(batch, noise_dim)`` -- confirm against callers.

        Returns:
            torch.Tensor: tensor of shape ``(batch, 1, 28, 28)``.

        Raises:
            AssertionError: if an intermediate feature map does not have
                the expected shape (internal sanity check).
        '''
        # The linear stage yields flat features; fold them into a
        # 256-channel 4x4 feature map for the transposed convolutions.
        tmp1 = self.body1(x).view(-1, 256, 4, 4)
        assert tmp1.size()[1:4] == (256, 4, 4), \
            'unexpected shape after body1: {}'.format(tuple(tmp1.size()))

        tmp2 = self.body2(tmp1)
        assert tmp2.size()[1:4] == (128, 6, 6), \
            'unexpected shape after body2: {}'.format(tuple(tmp2.size()))

        tmp3 = self.body3(tmp2)
        assert tmp3.size()[1:4] == (64, 13, 13), \
            'unexpected shape after body3: {}'.format(tuple(tmp3.size()))

        tmp4 = self.body4(tmp3)
        assert tmp4.size()[1:4] == (1, 28, 28), \
            'unexpected shape after body4: {}'.format(tuple(tmp4.size()))

        return tmp4