from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

from federatedscope.cv.model.cnn import ConvNet2, ConvNet5, VGG11


def get_cnn(model_config, input_shape):
    """Instantiate the CNN selected by ``model_config.type``.

    Args:
        model_config: Configuration object. Its ``type`` attribute picks the
            architecture (``'convnet2'``, ``'convnet5'`` or ``'vgg11'``);
            ``hidden``, ``out_channels`` and ``dropout`` are forwarded to the
            model constructor as ``hidden``, ``class_num`` and ``dropout``.
        input_shape: Indexable shape, either
            ``(batch_size, in_channels, h, w)`` or ``(in_channels, h, w)``.

    Returns:
        A ``ConvNet2``, ``ConvNet5`` or ``VGG11`` instance.

    Raises:
        ValueError: If ``model_config.type`` matches none of the supported
            names.
    """
    supported = (
        ('convnet2', ConvNet2),
        ('convnet5', ConvNet5),
        ('vgg11', VGG11),
    )

    # Scan the candidates in order; the ``else`` arm only runs when no
    # entry matched (i.e. the loop finished without ``break``).
    for type_name, model_cls in supported:
        if model_config.type == type_name:
            break
    else:
        raise ValueError(f'No model named {model_config.type}!')

    # Index from the end so the same code handles shapes with and without a
    # leading batch dimension.
    return model_cls(
        in_channels=input_shape[-3],
        h=input_shape[-2],
        w=input_shape[-1],
        hidden=model_config.hidden,
        class_num=model_config.out_channels,
        dropout=model_config.dropout,
    )