FS-TFP/federatedscope/cv/model/model_builder.py

36 lines
1.4 KiB
Python

from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
from federatedscope.cv.model.cnn import ConvNet2, ConvNet5, VGG11
def get_cnn(model_config, input_shape):
    """Build the CNN selected by ``model_config.type``.

    Args:
        model_config: config object providing ``type`` (one of
            ``'convnet2'``, ``'convnet5'``, ``'vgg11'``), ``hidden``,
            ``out_channels`` and ``dropout``.
        input_shape: either ``(batch_size, in_channels, h, w)`` or
            ``(in_channels, h, w)``. Only the last three entries are read.

    Returns:
        The instantiated model.

    Raises:
        ValueError: if ``model_config.type`` names no supported model.
    """
    model_type = model_config.type

    # Resolve the model class first so the constructor call below is
    # written only once. Each class name is only looked up on its own
    # matching branch.
    if model_type == 'convnet2':
        model_cls = ConvNet2
    elif model_type == 'convnet5':
        model_cls = ConvNet5
    elif model_type == 'vgg11':
        model_cls = VGG11
    else:
        raise ValueError(f'No model named {model_config.type}!')

    # Negative indices let both the batched and unbatched layouts work.
    return model_cls(in_channels=input_shape[-3],
                     h=input_shape[-2],
                     w=input_shape[-1],
                     hidden=model_config.hidden,
                     class_num=model_config.out_channels,
                     dropout=model_config.dropout)