From a701c9b32126f1e6974d9fcb3a5c3700527d8559 Mon Sep 17 00:00:00 2001 From: Lysandre Date: Fri, 11 Oct 2019 16:05:30 -0400 Subject: [PATCH] CTRL to tf automodels --- transformers/modeling_tf_auto.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/transformers/modeling_tf_auto.py b/transformers/modeling_tf_auto.py index a8f1de047f..df0ad6e401 100644 --- a/transformers/modeling_tf_auto.py +++ b/transformers/modeling_tf_auto.py @@ -26,6 +26,7 @@ from .modeling_tf_xlnet import TFXLNetModel, TFXLNetLMHeadModel, TFXLNetForSeque from .modeling_tf_xlm import TFXLMModel, TFXLMWithLMHeadModel, TFXLMForSequenceClassification, TFXLMForQuestionAnsweringSimple from .modeling_tf_roberta import TFRobertaModel, TFRobertaForMaskedLM, TFRobertaForSequenceClassification from .modeling_tf_distilbert import TFDistilBertModel, TFDistilBertForQuestionAnswering, TFDistilBertForMaskedLM, TFDistilBertForSequenceClassification +from .modeling_tf_ctrl import TFCTRLModel, TFCTRLLMHeadModel from .file_utils import add_start_docstrings @@ -52,6 +53,7 @@ class TFAutoModel(object): - contains `transfo-xl`: TFTransfoXLModel (Transformer-XL model) - contains `xlnet`: TFXLNetModel (XLNet model) - contains `xlm`: TFXLMModel (XLM model) + - contains `ctrl`: TFCTRLModel (CTRL model) This class cannot be instantiated using `__init__()` (throws an error). """ @@ -73,7 +75,7 @@ class TFAutoModel(object): - contains `gpt2`: TFGPT2Model (OpenAI GPT-2 model) - contains `transfo-xl`: TFTransfoXLModel (Transformer-XL model) - contains `xlnet`: TFXLNetModel (XLNet model) - - contains `xlm`: TFXLMModel (XLM model) + - contains `ctrl`: TFCTRLModel (CTRL model) Params: pretrained_model_name_or_path: either: @@ -147,10 +149,12 @@ class TFAutoModel(object): return TFXLNetModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) elif 'xlm' in pretrained_model_name_or_path: return TFXLMModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) + elif 'ctrl' in pretrained_model_name_or_path: + return TFCTRLModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) raise ValueError("Unrecognized model identifier in {}. Should contains one of " "'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', " - "'xlm', 'roberta'".format(pretrained_model_name_or_path)) + "'xlm', 'roberta', 'ctrl'".format(pretrained_model_name_or_path)) class TFAutoModelWithLMHead(object): @@ -173,6 +177,7 @@ class TFAutoModelWithLMHead(object): - contains `transfo-xl`: TFTransfoXLLMHeadModel (Transformer-XL model) - contains `xlnet`: TFXLNetLMHeadModel (XLNet model) - contains `xlm`: TFXLMWithLMHeadModel (XLM model) + - contains `ctrl`: TFCTRLLMHeadModel (CTRL model) This class cannot be instantiated using `__init__()` (throws an error). """ @@ -198,6 +203,7 @@ class TFAutoModelWithLMHead(object): - contains `transfo-xl`: TFTransfoXLLMHeadModel (Transformer-XL model) - contains `xlnet`: TFXLNetLMHeadModel (XLNet model) - contains `xlm`: TFXLMWithLMHeadModel (XLM model) + - contains `ctrl`: TFCTRLLMHeadModel (CTRL model) Params: pretrained_model_name_or_path: either: @@ -271,10 +277,12 @@ class TFAutoModelWithLMHead(object): return TFXLNetLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) elif 'xlm' in pretrained_model_name_or_path: return TFXLMWithLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) + elif 'ctrl' in pretrained_model_name_or_path: + return TFCTRLLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) raise ValueError("Unrecognized model identifier in {}. Should contains one of " "'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', " - "'xlm', 'roberta'".format(pretrained_model_name_or_path)) + "'xlm', 'roberta', 'ctrl'".format(pretrained_model_name_or_path)) class TFAutoModelForSequenceClassification(object):