From 8ba44ced95a4b2d9e3680ed8ea84bb13d06b2dd9 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Tue, 24 Sep 2019 09:48:23 +0200 Subject: [PATCH] fix roberta conversion script --- pytorch_transformers/__init__.py | 2 +- pytorch_transformers/convert_pytorch_checkpoint_to_tf2.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_transformers/__init__.py b/pytorch_transformers/__init__.py index befe5710e8..e12f4b2650 100644 --- a/pytorch_transformers/__init__.py +++ b/pytorch_transformers/__init__.py @@ -138,7 +138,7 @@ if _tf_available: TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP) from .modeling_tf_roberta import (TFRobertaPreTrainedModel, TFRobertaMainLayer, - TFRobertaModel, TFRobertaLMHead, + TFRobertaModel, TFRobertaForMaskedLM, TFRobertaForSequenceClassification, load_roberta_pt_weights_in_tf2, TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP) diff --git a/pytorch_transformers/convert_pytorch_checkpoint_to_tf2.py b/pytorch_transformers/convert_pytorch_checkpoint_to_tf2.py index 52a65ab607..567f1d0b5b 100644 --- a/pytorch_transformers/convert_pytorch_checkpoint_to_tf2.py +++ b/pytorch_transformers/convert_pytorch_checkpoint_to_tf2.py @@ -30,7 +30,7 @@ from pytorch_transformers import (BertConfig, TFBertForPreTraining, load_bert_pt XLMConfig, TFXLMWithLMHeadModel, load_xlm_pt_weights_in_tf2, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, TransfoXLConfig, TFTransfoXLLMHeadModel, load_transfo_xl_pt_weights_in_tf2, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig, TFOpenAIGPTLMHeadModel, load_openai_gpt_pt_weights_in_tf2, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, - RobertaConfig, TFRobertaLMHead, load_roberta_pt_weights_in_tf2, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, + RobertaConfig, TFRobertaForMaskedLM, load_roberta_pt_weights_in_tf2, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, DistilBertConfig, TFDistilBertForMaskedLM, load_distilbert_pt_weights_in_tf2, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP) if is_torch_available(): @@ -73,7 +73,7 @@ MODEL_CLASSES = { 'xlm': (XLMConfig, TFXLMWithLMHeadModel, load_xlm_pt_weights_in_tf2, XLMWithLMHeadModel, XLM_PRETRAINED_MODEL_ARCHIVE_MAP, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP), 'transfo-xl': (TransfoXLConfig, TFTransfoXLLMHeadModel, load_transfo_xl_pt_weights_in_tf2, TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP), 'openai-gpt': (OpenAIGPTConfig, TFOpenAIGPTLMHeadModel, load_openai_gpt_pt_weights_in_tf2, OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP), - 'roberta': (RobertaConfig, TFRobertaLMHead, load_roberta_pt_weights_in_tf2, RobertaForMaskedLM, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP), + 'roberta': (RobertaConfig, TFRobertaForMaskedLM, load_roberta_pt_weights_in_tf2, RobertaForMaskedLM, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP), 'distilbert': (DistilBertConfig, TFDistilBertForMaskedLM, load_distilbert_pt_weights_in_tf2, DistilBertForMaskedLM, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP), }