75 lines
2.2 KiB
Python
75 lines
2.2 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
from copy import deepcopy
|
|
|
|
|
|
class GeneratorFemnist(nn.Module):
    '''
    Generator network for the Femnist dataset.

    A fully-connected stem followed by three transposed-convolution stages
    maps a batch of noise vectors to single-channel 28x28 outputs whose
    values are squashed into [-1, 1] by the final Tanh.

    Args:
        noise_dim: length of each input noise vector (default: 100).
    '''

    @staticmethod
    def _upsample_stage(in_channels, out_channels, kernel, stride, activation):
        # One stage: bias-free transposed convolution -> batch norm -> activation.
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels,
                               out_channels,
                               kernel_size=kernel,
                               stride=stride,
                               bias=False),
            nn.BatchNorm2d(out_channels),
            activation,
        )

    def __init__(self, noise_dim=100):
        super().__init__()

        stem_features = 4 * 4 * 256

        # Stem: project the noise to 4 * 4 * 256 features. forward() reshapes
        # this flat output to (256, 4, 4) before the convolutional stages.
        self.body1 = nn.Sequential(
            nn.Linear(noise_dim, stem_features, bias=False),
            nn.BatchNorm1d(stem_features),
            nn.ReLU(),
        )

        # Spatial size per stage follows (in - 1) * stride + kernel:
        # 4 -> 6 -> 13 -> 28.
        self.body2 = self._upsample_stage(256, 128, 3, 1, nn.ReLU())
        self.body3 = self._upsample_stage(128, 64, 3, 2, nn.ReLU())
        self.body4 = self._upsample_stage(64, 1, 4, 2, nn.Tanh())

    def forward(self, x):
        '''
        Run the generator on a batch of noise vectors.

        Args:
            x: input passed to the linear stem; its last dimension must be
                ``noise_dim``.

        Returns:
            Tensor whose trailing dimensions are (1, 28, 28).
        '''
        # Unflatten the stem output into a 256-channel 4x4 feature map.
        feat_4x4 = self.body1(x).view(-1, 256, 4, 4)
        assert tuple(feat_4x4.shape[1:]) == (256, 4, 4)

        feat_6x6 = self.body2(feat_4x4)
        assert tuple(feat_6x6.shape[1:]) == (128, 6, 6)

        feat_13x13 = self.body3(feat_6x6)
        assert tuple(feat_13x13.shape[1:]) == (64, 13, 13)

        image = self.body4(feat_13x13)
        assert tuple(image.shape[1:]) == (1, 28, 28)

        return image
|