From f4432b7e01dc46008fb823096d884bdc2861b49c Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Mon, 21 Dec 2020 19:56:46 +0530 Subject: [PATCH] add base model classes to bart subclassed models (#9230) * add base model classes to bart subclassed models * add doc --- docs/source/model_doc/blenderbot.rst | 9 +++++++ docs/source/model_doc/mbart.rst | 7 +++++ docs/source/model_doc/pegasus.rst | 6 +++++ src/transformers/__init__.py | 10 ++++--- src/transformers/models/auto/modeling_auto.py | 10 ++++--- .../models/blenderbot/__init__.py | 6 ++++- .../models/blenderbot/modeling_blenderbot.py | 17 ++++++++++-- src/transformers/models/mbart/__init__.py | 2 +- .../models/mbart/modeling_mbart.py | 19 ++++++++++++- src/transformers/models/pegasus/__init__.py | 2 +- .../models/pegasus/modeling_pegasus.py | 26 +++++++++++++++++- src/transformers/utils/dummy_pt_objects.py | 27 +++++++++++++++++++ tests/test_modeling_blenderbot.py | 3 ++- tests/test_modeling_mbart.py | 3 ++- tests/test_modeling_pegasus.py | 4 +-- 15 files changed, 134 insertions(+), 17 deletions(-) diff --git a/docs/source/model_doc/blenderbot.rst b/docs/source/model_doc/blenderbot.rst index ddceeb81c1..df43c90ef0 100644 --- a/docs/source/model_doc/blenderbot.rst +++ b/docs/source/model_doc/blenderbot.rst @@ -100,6 +100,15 @@ BlenderbotSmallTokenizer :members: +BlenderbotModel +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +See :obj:`transformers.BartModel` for arguments to `forward` and `generate` + +.. autoclass:: transformers.BlenderbotModel + :members: + + BlenderbotForConditionalGeneration ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/model_doc/mbart.rst b/docs/source/model_doc/mbart.rst index eb9b979802..4ac391255e 100644 --- a/docs/source/model_doc/mbart.rst +++ b/docs/source/model_doc/mbart.rst @@ -97,6 +97,13 @@ MBartTokenizerFast :members: +MBartModel +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.MBartModel + :members: + + MBartForConditionalGeneration ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/model_doc/pegasus.rst b/docs/source/model_doc/pegasus.rst index 42b3e5ea57..3fab320ebc 100644 --- a/docs/source/model_doc/pegasus.rst +++ b/docs/source/model_doc/pegasus.rst @@ -119,6 +119,12 @@ PegasusTokenizerFast :members: +PegasusModel +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.PegasusModel + + PegasusForConditionalGeneration ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 4586fe5363..580318abaa 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -406,7 +406,11 @@ if is_torch_available(): BertGenerationEncoder, load_tf_weights_in_bert_generation, ) - from .models.blenderbot import BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST, BlenderbotForConditionalGeneration + from .models.blenderbot import ( + BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST, + BlenderbotForConditionalGeneration, + BlenderbotModel, + ) from .models.camembert import ( CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST, CamembertForCausalLM, @@ -522,7 +526,7 @@ if is_torch_available(): LxmertXLayer, ) from .models.marian import MarianMTModel - from .models.mbart import MBartForConditionalGeneration + from .models.mbart import MBartForConditionalGeneration, MBartModel from .models.mmbt import MMBTForClassification, MMBTModel, ModalEmbeddings from .models.mobilebert import ( MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST, @@ -559,7 +563,7 @@ if is_torch_available(): OpenAIGPTPreTrainedModel, load_tf_weights_in_openai_gpt, ) - from .models.pegasus import PegasusForConditionalGeneration + from .models.pegasus import PegasusForConditionalGeneration, PegasusModel from .models.prophetnet import ( PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST, ProphetNetDecoder, diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 4b9141d024..3fc5c702e7 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -50,7 +50,7 @@ from ..bert.modeling_bert import ( BertModel, ) from ..bert_generation.modeling_bert_generation import BertGenerationDecoder, BertGenerationEncoder -from ..blenderbot.modeling_blenderbot import BlenderbotForConditionalGeneration +from ..blenderbot.modeling_blenderbot import BlenderbotForConditionalGeneration, BlenderbotModel from ..camembert.modeling_camembert import ( CamembertForCausalLM, CamembertForMaskedLM, @@ -111,7 +111,7 @@ from ..longformer.modeling_longformer import ( ) from ..lxmert.modeling_lxmert import LxmertForPreTraining, LxmertForQuestionAnswering, LxmertModel from ..marian.modeling_marian import MarianMTModel -from ..mbart.modeling_mbart import MBartForConditionalGeneration +from ..mbart.modeling_mbart import MBartForConditionalGeneration, MBartModel from ..mobilebert.modeling_mobilebert import ( MobileBertForMaskedLM, MobileBertForMultipleChoice, @@ -132,7 +132,7 @@ from ..mpnet.modeling_mpnet import ( ) from ..mt5.modeling_mt5 import MT5ForConditionalGeneration, MT5Model from ..openai.modeling_openai import OpenAIGPTForSequenceClassification, OpenAIGPTLMHeadModel, OpenAIGPTModel -from ..pegasus.modeling_pegasus import PegasusForConditionalGeneration +from ..pegasus.modeling_pegasus import PegasusForConditionalGeneration, PegasusModel from ..prophetnet.modeling_prophetnet import ProphetNetForCausalLM, ProphetNetForConditionalGeneration, ProphetNetModel from ..rag.modeling_rag import ( # noqa: F401 - need to import all RagModels to be in globals() function RagModel, @@ -255,6 +255,10 @@ MODEL_MAPPING = OrderedDict( (RetriBertConfig, RetriBertModel), (MT5Config, MT5Model), (T5Config, T5Model), + (PegasusConfig, PegasusModel), + (MarianConfig, MarianMTModel), + (MBartConfig, MBartModel), + (BlenderbotConfig, BlenderbotModel), (DistilBertConfig, DistilBertModel), (AlbertConfig, AlbertModel), (CamembertConfig, CamembertModel), diff --git a/src/transformers/models/blenderbot/__init__.py b/src/transformers/models/blenderbot/__init__.py index fdcd990ff9..fccb38f80a 100644 --- a/src/transformers/models/blenderbot/__init__.py +++ b/src/transformers/models/blenderbot/__init__.py @@ -22,7 +22,11 @@ from .tokenization_blenderbot import BlenderbotSmallTokenizer, BlenderbotTokeniz if is_torch_available(): - from .modeling_blenderbot import BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST, BlenderbotForConditionalGeneration + from .modeling_blenderbot import ( + BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST, + BlenderbotForConditionalGeneration, + BlenderbotModel, + ) if is_tf_available(): from .modeling_tf_blenderbot import TFBlenderbotForConditionalGeneration diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index 1421a87ca9..2a370fbabf 100644 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -19,7 +19,7 @@ import torch from ...file_utils import add_start_docstrings -from ..bart.modeling_bart import BartForConditionalGeneration +from ..bart.modeling_bart import BartForConditionalGeneration, BartModel from .configuration_blenderbot import BlenderbotConfig @@ -39,7 +39,20 @@ BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST = ["facebook/blenderbot-3B", "facebook/ @add_start_docstrings( - "The BART Model with a language modeling head. Can be used for summarization.", BLENDER_START_DOCSTRING + "The bare BlenderBot Model transformer outputting raw hidden-states without any specific head on top.", + BLENDER_START_DOCSTRING, +) +class BlenderbotModel(BartModel): + r""" + This class overrides :class:`~transformers.BartModel`. Please check the superclass for the appropriate + documentation alongside usage examples. + """ + + config_class = BlenderbotConfig + + +@add_start_docstrings( + "The BlenderBot Model with a language modeling head. Can be used for summarization.", BLENDER_START_DOCSTRING ) class BlenderbotForConditionalGeneration(BartForConditionalGeneration): """ diff --git a/src/transformers/models/mbart/__init__.py b/src/transformers/models/mbart/__init__.py index b98d226625..2fa8876085 100644 --- a/src/transformers/models/mbart/__init__.py +++ b/src/transformers/models/mbart/__init__.py @@ -27,7 +27,7 @@ if is_tokenizers_available(): from .tokenization_mbart_fast import MBartTokenizerFast if is_torch_available(): - from .modeling_mbart import MBartForConditionalGeneration + from .modeling_mbart import MBartForConditionalGeneration, MBartModel if is_tf_available(): from .modeling_tf_mbart import TFMBartForConditionalGeneration diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 9fca52c549..f4aa39b075 100644 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ..bart.modeling_bart import BartForConditionalGeneration +from ..bart.modeling_bart import BartForConditionalGeneration, BartModel from .configuration_mbart import MBartConfig @@ -26,6 +26,23 @@ MBART_PRETRAINED_MODEL_ARCHIVE_LIST = [ ] +class MBartModel(BartModel): + r""" + This class overrides :class:`~transformers.BartModel`. Please check the superclass for the appropriate + documentation alongside usage examples. + """ + + config_class = MBartConfig + _keys_to_ignore_on_load_missing = [ + "encoder.embed_positions.weight", + "decoder.embed_positions.weight", + ] + _keys_to_ignore_on_save = [ + "encoder.embed_positions.weight", + "decoder.embed_positions.weight", + ] + + class MBartForConditionalGeneration(BartForConditionalGeneration): r""" This class overrides :class:`~transformers.BartForConditionalGeneration`. Please check the superclass for the diff --git a/src/transformers/models/pegasus/__init__.py b/src/transformers/models/pegasus/__init__.py index e7cc0ce71b..20d1c3872d 100644 --- a/src/transformers/models/pegasus/__init__.py +++ b/src/transformers/models/pegasus/__init__.py @@ -27,7 +27,7 @@ if is_tokenizers_available(): from .tokenization_pegasus_fast import PegasusTokenizerFast if is_torch_available(): - from .modeling_pegasus import PegasusForConditionalGeneration + from .modeling_pegasus import PegasusForConditionalGeneration, PegasusModel if is_tf_available(): from .modeling_tf_pegasus import TFPegasusForConditionalGeneration diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index 3e623a7704..c7fde41643 100644 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -16,10 +16,34 @@ from ...file_utils import add_start_docstrings -from ..bart.modeling_bart import BART_START_DOCSTRING, BartForConditionalGeneration +from ..bart.modeling_bart import BART_START_DOCSTRING, BartForConditionalGeneration, BartModel from .configuration_pegasus import PegasusConfig +@add_start_docstrings( + "The bare Pegasus Model transformer outputting raw hidden-states without any specific head on top.", + BART_START_DOCSTRING, +) +class PegasusModel(BartModel): + r""" + This class overrides :class:`~transformers.BartModel`. Please check the superclass for the appropriate + documentation alongside usage examples. + """ + + config_class = PegasusConfig + _keys_to_ignore_on_load_missing = [ + r"final_logits_bias", + r"encoder\.version", + r"decoder\.version", + "encoder.embed_positions", + "decoder.embed_positions", + ] + _keys_to_ignore_on_save = [ + "encoder.embed_positions.weight", + "decoder.embed_positions.weight", + ] + + @add_start_docstrings("The Pegasus Model for summarization ", BART_START_DOCSTRING) class PegasusForConditionalGeneration(BartForConditionalGeneration): r""" diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 050c7ba4f9..97669eff74 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -600,6 +600,15 @@ class BlenderbotForConditionalGeneration: requires_pytorch(self) +class BlenderbotModel: + def __init__(self, *args, **kwargs): + requires_pytorch(self) + + @classmethod + def from_pretrained(self, *args, **kwargs): + requires_pytorch(self) + + CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None @@ -1297,6 +1306,15 @@ class MBartForConditionalGeneration: requires_pytorch(self) +class MBartModel: + def __init__(self, *args, **kwargs): + requires_pytorch(self) + + @classmethod + def from_pretrained(self, *args, **kwargs): + requires_pytorch(self) + + class MMBTForClassification: def __init__(self, *args, **kwargs): requires_pytorch(self) @@ -1560,6 +1578,15 @@ class PegasusForConditionalGeneration: requires_pytorch(self) +class PegasusModel: + def __init__(self, *args, **kwargs): + requires_pytorch(self) + + @classmethod + def from_pretrained(self, *args, **kwargs): + requires_pytorch(self) + + PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST = None diff --git a/tests/test_modeling_blenderbot.py b/tests/test_modeling_blenderbot.py index b069ba6089..668569a595 100644 --- a/tests/test_modeling_blenderbot.py +++ b/tests/test_modeling_blenderbot.py @@ -32,6 +32,7 @@ if is_torch_available(): AutoTokenizer, BlenderbotConfig, BlenderbotForConditionalGeneration, + BlenderbotModel, BlenderbotSmallTokenizer, BlenderbotTokenizer, ) @@ -90,7 +91,7 @@ class BlenderbotModelTester: class BlenderbotTesterMixin(ModelTesterMixin, unittest.TestCase): if is_torch_available(): all_generative_model_classes = (BlenderbotForConditionalGeneration,) - all_model_classes = (BlenderbotForConditionalGeneration,) + all_model_classes = (BlenderbotForConditionalGeneration, BlenderbotModel) else: all_generative_model_classes = () all_model_classes = () diff --git a/tests/test_modeling_mbart.py b/tests/test_modeling_mbart.py index 1a4094ed2c..2a43650feb 100644 --- a/tests/test_modeling_mbart.py +++ b/tests/test_modeling_mbart.py @@ -30,6 +30,7 @@ if is_torch_available(): BatchEncoding, MBartConfig, MBartForConditionalGeneration, + MBartModel, ) @@ -59,7 +60,7 @@ class ModelTester: @require_torch class SelectiveCommonTest(unittest.TestCase): - all_model_classes = (MBartForConditionalGeneration,) if is_torch_available() else () + all_model_classes = (MBartForConditionalGeneration, MBartModel) if is_torch_available() else () test_save_load__keys_to_ignore_on_save = ModelTesterMixin.test_save_load__keys_to_ignore_on_save diff --git a/tests/test_modeling_pegasus.py b/tests/test_modeling_pegasus.py index 42173ebccf..dc9fdf5225 100644 --- a/tests/test_modeling_pegasus.py +++ b/tests/test_modeling_pegasus.py @@ -26,7 +26,7 @@ from .test_modeling_mbart import AbstractSeq2SeqIntegrationTest if is_torch_available(): - from transformers import AutoModelForSeq2SeqLM, PegasusConfig, PegasusForConditionalGeneration + from transformers import AutoModelForSeq2SeqLM, PegasusConfig, PegasusForConditionalGeneration, PegasusModel XSUM_ENTRY_LONGER = """ The London trio are up for best UK act and best album, as well as getting two nominations in the best song category."We got told like this morning 'Oh I think you're nominated'", said Dappy."And I was like 'Oh yeah, which one?' And now we've got nominated for four awards. I mean, wow!"Bandmate Fazer added: "We thought it's best of us to come down and mingle with everyone and say hello to the cameras. And now we find we've got four nominations."The band have two shots at the best song prize, getting the nod for their Tynchy Stryder collaboration Number One, and single Strong Again.Their album Uncle B will also go up against records by the likes of Beyonce and Kanye West.N-Dubz picked up the best newcomer Mobo in 2007, but female member Tulisa said they wouldn't be too disappointed if they didn't win this time around."At the end of the day we're grateful to be where we are in our careers."If it don't happen then it don't happen - live to fight another day and keep on making albums and hits for the fans."Dappy also revealed they could be performing live several times on the night.The group will be doing Number One and also a possible rendition of the War Child single, I Got Soul.The charity song is a re-working of The Killers' All These Things That I've Done and is set to feature artists like Chipmunk, Ironik and Pixie Lott.This year's Mobos will be held outside of London for the first time, in Glasgow on 30 September.N-Dubz said they were looking forward to performing for their Scottish fans and boasted about their recent shows north of the border."We just done Edinburgh the other day," said Dappy."We smashed up an N-Dubz show over there. We done Aberdeen about three or four months ago - we smashed up that show over there! Everywhere we go we smash it up!" """ @@ -55,7 +55,7 @@ class ModelTester: @require_torch class SelectiveCommonTest(unittest.TestCase): - all_model_classes = (PegasusForConditionalGeneration,) if is_torch_available() else () + all_model_classes = (PegasusForConditionalGeneration, PegasusModel) if is_torch_available() else () test_save_load__keys_to_ignore_on_save = ModelTesterMixin.test_save_load__keys_to_ignore_on_save