From 48b438ff2a9a6b828b083e6ee7e09c74004d7279 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Wed, 9 Oct 2019 17:06:30 +0200 Subject: [PATCH] doc and conversion --- docs/source/pretrained_models.rst | 3 +++ transformers/__init__.py | 5 +++++ transformers/convert_pytorch_checkpoint_to_tf2.py | 14 ++++++++++---- 3 files changed, 18 insertions(+), 4 deletions(-) diff --git a/docs/source/pretrained_models.rst b/docs/source/pretrained_models.rst index 4c17b35c84..2ff73184c9 100644 --- a/docs/source/pretrained_models.rst +++ b/docs/source/pretrained_models.rst @@ -119,5 +119,8 @@ Here is the full list of the currently provided pretrained models together with | | | | The DistilBERT model distilled from the BERT model `bert-base-uncased` checkpoint, with an additional linear layer. | | | | (see `details `__) | +-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| CTRL | ``ctrl`` | | 48-layer, 1280-hidden, 16-heads, 1.6B parameters | +| | | | Salesforce's Large-sized CTRL English model | ++-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ .. `__ \ No newline at end of file diff --git a/transformers/__init__.py b/transformers/__init__.py index 3a6d2aee2c..3d778a4941 100644 --- a/transformers/__init__.py +++ b/transformers/__init__.py @@ -155,6 +155,11 @@ if is_tf_available(): load_distilbert_pt_weights_in_tf2, TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP) + from .modeling_tf_ctrl import (TFCTRLPreTrainedModel, TFCTRLModel, + TFCTRLLMHeadModel, + load_ctrl_pt_weights_in_tf2, + TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP) + # TF 2.0 <=> PyTorch conversion utilities if is_tf_available() and is_torch_available(): from .modeling_tf_pytorch_utils import (convert_tf_weight_name_to_pt_weight_name, diff --git a/transformers/convert_pytorch_checkpoint_to_tf2.py b/transformers/convert_pytorch_checkpoint_to_tf2.py index c5f7650b50..c90f581bba 100644 --- a/transformers/convert_pytorch_checkpoint_to_tf2.py +++ b/transformers/convert_pytorch_checkpoint_to_tf2.py @@ -31,7 +31,8 @@ from transformers import (BertConfig, TFBertForPreTraining, TFBertForQuestionAns TransfoXLConfig, TFTransfoXLLMHeadModel, load_transfo_xl_pt_weights_in_tf2, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig, TFOpenAIGPTLMHeadModel, load_openai_gpt_pt_weights_in_tf2, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig, TFRobertaForMaskedLM, TFRobertaForSequenceClassification, load_roberta_pt_weights_in_tf2, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, - DistilBertConfig, TFDistilBertForMaskedLM, TFDistilBertForQuestionAnswering, load_distilbert_pt_weights_in_tf2, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP) + DistilBertConfig, TFDistilBertForMaskedLM, TFDistilBertForQuestionAnswering, load_distilbert_pt_weights_in_tf2, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, + CTRLConfig, TFCTRLLMHeadModel, load_ctrl_pt_weights_in_tf2, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP) if is_torch_available(): import torch @@ -43,7 +44,8 @@ if is_torch_available(): TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP, OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP, RobertaForMaskedLM, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, - DistilBertForMaskedLM, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP) + DistilBertForMaskedLM, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, + CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP) else: (BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, GPT2LMHeadModel, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP, @@ -52,7 +54,8 @@ else: TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP, OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP, RobertaForMaskedLM, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, - DistilBertForMaskedLM, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,) = ( + DistilBertForMaskedLM, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, + CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP) = ( None, None, None, None, None, None, None, None, @@ -60,7 +63,8 @@ else: None, None, None, None, None, None, None, - None, None, None,) + None, None, None, + None, None) import logging @@ -80,6 +84,7 @@ MODEL_CLASSES = { 'roberta-large-mnli': (RobertaConfig, TFRobertaForSequenceClassification, load_roberta_pt_weights_in_tf2, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP), 'distilbert': (DistilBertConfig, TFDistilBertForMaskedLM, load_distilbert_pt_weights_in_tf2, DistilBertForMaskedLM, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP), 'distilbert-base-uncased-distilled-squad': (DistilBertConfig, TFDistilBertForQuestionAnswering, load_distilbert_pt_weights_in_tf2, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP), + 'ctrl': (CTRLConfig, TFCTRLLMHeadModel, load_ctrl_pt_weights_in_tf2, CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP) } def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file, tf_dump_path, compare_with_pt_model=False, use_cached_models=True): @@ -228,6 +233,7 @@ if __name__ == "__main__": convert_all_pt_checkpoints_to_tf(args.model_type.lower() if args.model_type is not None else None, args.tf_dump_path, model_shortcut_names_or_path=[args.pytorch_checkpoint_path] if args.pytorch_checkpoint_path is not None else None, + config_shortcut_names_or_path=[args.config_file] if args.config_file is not None else None, compare_with_pt_model=args.compare_with_pt_model, use_cached_models=args.use_cached_models, only_convert_finetuned_models=args.only_convert_finetuned_models)