36 lines
1.4 KiB
Python
36 lines
1.4 KiB
Python
from __future__ import absolute_import
|
|
from __future__ import print_function
|
|
from __future__ import division
|
|
|
|
from federatedscope.cv.model.cnn import ConvNet2, ConvNet5, VGG11
|
|
|
|
|
|
def get_cnn(model_config, input_shape):
    """Build the CNN image model selected by ``model_config.type``.

    Args:
        model_config: model config object; this function reads its
            ``type``, ``hidden``, ``out_channels`` and ``dropout``
            attributes.
        input_shape: input shape, either ``(batch_size, in_channels, h, w)``
            or ``(in_channels, h, w)``. Only the last three entries are used.

    Returns:
        An instance of ``ConvNet2``, ``ConvNet5`` or ``VGG11``, constructed
        with the channel/height/width taken from ``input_shape`` and the
        hidden size, class count and dropout taken from ``model_config``.

    Raises:
        ValueError: if ``model_config.type`` is not one of ``'convnet2'``,
            ``'convnet5'`` or ``'vgg11'``.
    """
    # Every supported architecture is built with the same keyword arguments,
    # so a lookup table replaces three copies of the identical call.
    model_classes = {
        'convnet2': ConvNet2,
        'convnet5': ConvNet5,
        'vgg11': VGG11,
    }
    model_cls = model_classes.get(model_config.type)
    if model_cls is None:
        raise ValueError(f'No model named {model_config.type}!')

    # Index from the end so an optional leading batch dimension is ignored.
    return model_cls(in_channels=input_shape[-3],
                     h=input_shape[-2],
                     w=input_shape[-1],
                     hidden=model_config.hidden,
                     class_num=model_config.out_channels,
                     dropout=model_config.dropout)
|