from __future__ import absolute_import
from __future__ import print_function
from __future__ import division


def get_rnn(model_config, input_shape):
    from federatedscope.nlp.model.rnn import LSTM

    # check the task
    # input_shape: (batch_size, seq_len, hidden) or (seq_len, hidden)
    if model_config.type == 'lstm':
        model = LSTM(
            # Derive in_channels from input_shape when it is not set
            # explicitly in the config.
            in_channels=input_shape[-2]
            if not model_config.in_channels else model_config.in_channels,
            hidden=model_config.hidden,
            out_channels=model_config.out_channels,
            embed_size=model_config.embed_size,
            dropout=model_config.dropout)
    else:
        raise ValueError(f'No model named {model_config.type}!')

    return model


def get_transformer(model_config, input_shape):
    from transformers import AutoModelForPreTraining, \
        AutoModelForQuestionAnswering, AutoModelForSequenceClassification, \
        AutoModelForTokenClassification, AutoModelWithLMHead, AutoModel

    # Map each supported task name (lower-cased) to the matching
    # Hugging Face auto-model class.
    model_func_dict = {
        'PreTraining'.lower(): AutoModelForPreTraining,
        'QuestionAnswering'.lower(): AutoModelForQuestionAnswering,
        'SequenceClassification'.lower(): AutoModelForSequenceClassification,
        'TokenClassification'.lower(): AutoModelForTokenClassification,
        'WithLMHead'.lower(): AutoModelWithLMHead,
        'Auto'.lower(): AutoModel
    }
    assert model_config.task.lower() in model_func_dict, \
        f'model_config.task should be in {list(model_func_dict.keys())} ' \
        f'when using a pre-trained transformer model'
    # model_config.type is of the form '<pretrained model path>@<suffix>';
    # only the path before '@' is passed to from_pretrained.
    path, _ = model_config.type.split('@')
    model = model_func_dict[model_config.task.lower()].from_pretrained(
        path, num_labels=model_config.out_channels)

    return model
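

if __name__ == '__main__':
    # Minimal usage sketch for get_transformer. The SimpleNamespace below is
    # a stand-in for the real model_config object, and the checkpoint name
    # 'bert-base-uncased' is an example, not a project default; both are
    # assumptions for illustration only. Running this downloads weights from
    # the Hugging Face hub.
    from types import SimpleNamespace

    demo_config = SimpleNamespace(
        type='bert-base-uncased@transformers',  # '<hub path>@<suffix>'
        task='SequenceClassification',
        out_channels=2)
    demo_model = get_transformer(demo_config, input_shape=None)
    print(type(demo_model).__name__)  # e.g. BertForSequenceClassification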