# FS-TFP/federatedscope/nlp/model/model_builder.py

# 46 lines
# 1.8 KiB
# Python

from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
def get_rnn(model_config, input_shape):
    """Build the recurrent model selected by ``model_config.type``.

    Args:
        model_config: Configuration object. This function reads its
            ``type``, ``in_channels``, ``hidden``, ``out_channels``,
            ``embed_size`` and ``dropout`` attributes.
        input_shape: Shape of the input data, laid out as
            ``(batch_size, seq_len, hidden)`` or ``(seq_len, hidden)``.
            It is only indexed when ``model_config.in_channels`` is falsy.

    Returns:
        An ``LSTM`` instance (from ``federatedscope.nlp.model.rnn``) built
        from the configuration values.

    Raises:
        ValueError: If ``model_config.type`` is not ``'lstm'``.
    """
    # Reject unsupported types first so the error is reported even before
    # the model module (and whatever it pulls in) has to be imported.
    if model_config.type != 'lstm':
        raise ValueError(f'No model named {model_config.type}!')

    # Imported lazily so that merely importing this builder module does not
    # load the model implementation.
    from federatedscope.nlp.model.rnn import LSTM

    # A truthy, explicitly configured in_channels wins; otherwise fall back
    # to the second-to-last entry of input_shape.
    # NOTE(review): with the shape layout documented above, index -2 is
    # seq_len -- confirm that this is the intended fallback.
    in_channels = model_config.in_channels or input_shape[-2]

    return LSTM(in_channels=in_channels,
                hidden=model_config.hidden,
                out_channels=model_config.out_channels,
                embed_size=model_config.embed_size,
                dropout=model_config.dropout)
def get_transformer(model_config, input_shape):
    """Load a pre-trained ``transformers`` Auto model for the given task.

    Args:
        model_config: Configuration object. This function reads:
            ``task`` -- case-insensitive task name, one of
            ``PreTraining``, ``QuestionAnswering``,
            ``SequenceClassification``, ``TokenClassification``,
            ``WithLMHead`` or ``Auto``;
            ``type`` -- a string of the form ``'<path>@<suffix>'``; only
            the part before ``'@'`` is used, as the argument to
            ``from_pretrained``;
            ``out_channels`` -- forwarded to ``from_pretrained`` as
            ``num_labels``.
        input_shape: Unused; kept so the signature matches the other
            builders in this module.

    Returns:
        The object returned by the selected Auto class's
        ``from_pretrained``.

    Raises:
        AssertionError: If ``model_config.task`` is not a supported task.
        ValueError: If ``model_config.type`` does not contain exactly one
            ``'@'`` separator.
    """
    # Task key -> name of the Auto class exported by ``transformers``.
    # Classes are looked up by name after validation, so an Auto class that
    # is missing from the installed release (``AutoModelWithLMHead`` is
    # deprecated upstream) only affects the task that needs it.
    auto_class_names = {
        'pretraining': 'AutoModelForPreTraining',
        'questionanswering': 'AutoModelForQuestionAnswering',
        'sequenceclassification': 'AutoModelForSequenceClassification',
        'tokenclassification': 'AutoModelForTokenClassification',
        'withlmhead': 'AutoModelWithLMHead',
        'auto': 'AutoModel',
    }

    task = model_config.task.lower()
    if task not in auto_class_names:
        # Raised explicitly (not via ``assert``) so the check is still
        # performed under ``python -O``; exception type and message are
        # the same as before.
        raise AssertionError(f'model_config.task should be in'
                             f' {auto_class_names.keys()} '
                             f'when using pre_trained transformer model ')

    # Exactly one '@' is required, as before, but a malformed value now
    # produces a descriptive ValueError instead of an unpacking error.
    parts = model_config.type.split('@')
    if len(parts) != 2:
        raise ValueError(
            f"model_config.type should look like '<path>@<suffix>', "
            f"got {model_config.type!r}")
    path = parts[0]

    # Imported lazily: the dependency is only needed once the configuration
    # has been validated.
    import transformers

    model_cls = getattr(transformers, auto_class_names[task])
    return model_cls.from_pretrained(path,
                                     num_labels=model_config.out_channels)