add base model classes to bart subclassed models (#9230)

* add base model classes to bart subclassed models

* add doc
This commit is contained in:
Suraj Patil
2020-12-21 19:56:46 +05:30
committed by GitHub
parent 08abdabda1
commit f4432b7e01
15 changed files with 134 additions and 17 deletions

View File

@@ -406,7 +406,11 @@ if is_torch_available():
BertGenerationEncoder,
load_tf_weights_in_bert_generation,
)
from .models.blenderbot import BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST, BlenderbotForConditionalGeneration
from .models.blenderbot import (
BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST,
BlenderbotForConditionalGeneration,
BlenderbotModel,
)
from .models.camembert import (
CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
CamembertForCausalLM,
@@ -522,7 +526,7 @@ if is_torch_available():
LxmertXLayer,
)
from .models.marian import MarianMTModel
from .models.mbart import MBartForConditionalGeneration
from .models.mbart import MBartForConditionalGeneration, MBartModel
from .models.mmbt import MMBTForClassification, MMBTModel, ModalEmbeddings
from .models.mobilebert import (
MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
@@ -559,7 +563,7 @@ if is_torch_available():
OpenAIGPTPreTrainedModel,
load_tf_weights_in_openai_gpt,
)
from .models.pegasus import PegasusForConditionalGeneration
from .models.pegasus import PegasusForConditionalGeneration, PegasusModel
from .models.prophetnet import (
PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST,
ProphetNetDecoder,

View File

@@ -50,7 +50,7 @@ from ..bert.modeling_bert import (
BertModel,
)
from ..bert_generation.modeling_bert_generation import BertGenerationDecoder, BertGenerationEncoder
from ..blenderbot.modeling_blenderbot import BlenderbotForConditionalGeneration
from ..blenderbot.modeling_blenderbot import BlenderbotForConditionalGeneration, BlenderbotModel
from ..camembert.modeling_camembert import (
CamembertForCausalLM,
CamembertForMaskedLM,
@@ -111,7 +111,7 @@ from ..longformer.modeling_longformer import (
)
from ..lxmert.modeling_lxmert import LxmertForPreTraining, LxmertForQuestionAnswering, LxmertModel
from ..marian.modeling_marian import MarianMTModel
from ..mbart.modeling_mbart import MBartForConditionalGeneration
from ..mbart.modeling_mbart import MBartForConditionalGeneration, MBartModel
from ..mobilebert.modeling_mobilebert import (
MobileBertForMaskedLM,
MobileBertForMultipleChoice,
@@ -132,7 +132,7 @@ from ..mpnet.modeling_mpnet import (
)
from ..mt5.modeling_mt5 import MT5ForConditionalGeneration, MT5Model
from ..openai.modeling_openai import OpenAIGPTForSequenceClassification, OpenAIGPTLMHeadModel, OpenAIGPTModel
from ..pegasus.modeling_pegasus import PegasusForConditionalGeneration
from ..pegasus.modeling_pegasus import PegasusForConditionalGeneration, PegasusModel
from ..prophetnet.modeling_prophetnet import ProphetNetForCausalLM, ProphetNetForConditionalGeneration, ProphetNetModel
from ..rag.modeling_rag import ( # noqa: F401 - need to import all RagModels to be in globals() function
RagModel,
@@ -255,6 +255,10 @@ MODEL_MAPPING = OrderedDict(
(RetriBertConfig, RetriBertModel),
(MT5Config, MT5Model),
(T5Config, T5Model),
(PegasusConfig, PegasusModel),
(MarianConfig, MarianMTModel),
(MBartConfig, MBartModel),
(BlenderbotConfig, BlenderbotModel),
(DistilBertConfig, DistilBertModel),
(AlbertConfig, AlbertModel),
(CamembertConfig, CamembertModel),

View File

@@ -22,7 +22,11 @@ from .tokenization_blenderbot import BlenderbotSmallTokenizer, BlenderbotTokeniz
if is_torch_available():
from .modeling_blenderbot import BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST, BlenderbotForConditionalGeneration
from .modeling_blenderbot import (
BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST,
BlenderbotForConditionalGeneration,
BlenderbotModel,
)
if is_tf_available():
from .modeling_tf_blenderbot import TFBlenderbotForConditionalGeneration

View File

@@ -19,7 +19,7 @@
import torch
from ...file_utils import add_start_docstrings
from ..bart.modeling_bart import BartForConditionalGeneration
from ..bart.modeling_bart import BartForConditionalGeneration, BartModel
from .configuration_blenderbot import BlenderbotConfig
@@ -39,7 +39,20 @@ BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST = ["facebook/blenderbot-3B", "facebook/
@add_start_docstrings(
"The BART Model with a language modeling head. Can be used for summarization.", BLENDER_START_DOCSTRING
"The bare BlenderBot Model transformer outputting raw hidden-states without any specific head on top.",
BLENDER_START_DOCSTRING,
)
class BlenderbotModel(BartModel):
    r"""
    BlenderBot flavour of :class:`~transformers.BartModel`.

    Only ``config_class`` is redefined in this subclass; for the forward signature, the returned outputs and usage
    examples, refer to the documentation of the superclass.
    """

    # Configuration class associated with this model.
    config_class = BlenderbotConfig
@add_start_docstrings(
"The BlenderBot Model with a language modeling head. Can be used for summarization.", BLENDER_START_DOCSTRING
)
class BlenderbotForConditionalGeneration(BartForConditionalGeneration):
"""

View File

@@ -27,7 +27,7 @@ if is_tokenizers_available():
from .tokenization_mbart_fast import MBartTokenizerFast
if is_torch_available():
from .modeling_mbart import MBartForConditionalGeneration
from .modeling_mbart import MBartForConditionalGeneration, MBartModel
if is_tf_available():
from .modeling_tf_mbart import TFMBartForConditionalGeneration

View File

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ..bart.modeling_bart import BartForConditionalGeneration
from ..bart.modeling_bart import BartForConditionalGeneration, BartModel
from .configuration_mbart import MBartConfig
@@ -26,6 +26,23 @@ MBART_PRETRAINED_MODEL_ARCHIVE_LIST = [
]
class MBartModel(BartModel):
    r"""
    mBART flavour of :class:`~transformers.BartModel`.

    The subclass only swaps in :class:`MBartConfig` and declares which positional-embedding keys are skipped when
    loading and saving; for the forward signature and usage examples, refer to the documentation of the superclass.
    """

    # Configuration class associated with this model.
    config_class = MBartConfig

    # NOTE(review): going by the attribute names, these keys are tolerated when missing from a loaded checkpoint
    # and left out of saved ones; the code consuming them is not visible here, so confirm against the base class.
    _keys_to_ignore_on_load_missing = ["encoder.embed_positions.weight", "decoder.embed_positions.weight"]
    _keys_to_ignore_on_save = ["encoder.embed_positions.weight", "decoder.embed_positions.weight"]
class MBartForConditionalGeneration(BartForConditionalGeneration):
r"""
This class overrides :class:`~transformers.BartForConditionalGeneration`. Please check the superclass for the

View File

@@ -27,7 +27,7 @@ if is_tokenizers_available():
from .tokenization_pegasus_fast import PegasusTokenizerFast
if is_torch_available():
from .modeling_pegasus import PegasusForConditionalGeneration
from .modeling_pegasus import PegasusForConditionalGeneration, PegasusModel
if is_tf_available():
from .modeling_tf_pegasus import TFPegasusForConditionalGeneration

View File

@@ -16,10 +16,34 @@
from ...file_utils import add_start_docstrings
from ..bart.modeling_bart import BART_START_DOCSTRING, BartForConditionalGeneration
from ..bart.modeling_bart import BART_START_DOCSTRING, BartForConditionalGeneration, BartModel
from .configuration_pegasus import PegasusConfig
@add_start_docstrings(
    "The bare Pegasus Model transformer outputting raw hidden-states without any specific head on top.",
    BART_START_DOCSTRING,
)
class PegasusModel(BartModel):
    r"""
    Pegasus flavour of :class:`~transformers.BartModel`.

    The subclass only swaps in :class:`PegasusConfig` and declares which checkpoint keys are skipped when loading
    and saving; for the forward signature and usage examples, refer to the documentation of the superclass.
    """

    # Configuration class associated with this model.
    config_class = PegasusConfig

    # NOTE(review): going by the attribute names, the first list holds key patterns tolerated when missing from a
    # loaded checkpoint and the second holds keys left out of saved ones; the code consuming them is not visible
    # here, so confirm against the base class.
    _keys_to_ignore_on_load_missing = [
        r"final_logits_bias",
        r"encoder\.version",
        r"decoder\.version",
        "encoder.embed_positions",
        "decoder.embed_positions",
    ]
    _keys_to_ignore_on_save = ["encoder.embed_positions.weight", "decoder.embed_positions.weight"]
@add_start_docstrings("The Pegasus Model for summarization ", BART_START_DOCSTRING)
class PegasusForConditionalGeneration(BartForConditionalGeneration):
r"""

View File

@@ -600,6 +600,15 @@ class BlenderbotForConditionalGeneration:
requires_pytorch(self)
class BlenderbotModel:
    """
    Placeholder for ``BlenderbotModel``: construction and ``from_pretrained`` do nothing but call
    ``requires_pytorch``.

    NOTE(review): ``requires_pytorch`` is defined outside this block; presumably it raises an error telling the user
    that PyTorch is needed - confirm against its definition.
    """

    def __init__(self, *args, **kwargs):
        # Arguments are accepted but never read; only the backend check runs.
        requires_pytorch(self)

    @classmethod
    def from_pretrained(self, *args, **kwargs):
        # Declared as a classmethod, so ``self`` is the class object here rather than an instance.
        requires_pytorch(self)
CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
@@ -1297,6 +1306,15 @@ class MBartForConditionalGeneration:
requires_pytorch(self)
class MBartModel:
    """
    Placeholder for ``MBartModel``: construction and ``from_pretrained`` do nothing but call ``requires_pytorch``.

    NOTE(review): ``requires_pytorch`` is defined outside this block; presumably it raises an error telling the user
    that PyTorch is needed - confirm against its definition.
    """

    def __init__(self, *args, **kwargs):
        # Arguments are accepted but never read; only the backend check runs.
        requires_pytorch(self)

    @classmethod
    def from_pretrained(self, *args, **kwargs):
        # Declared as a classmethod, so ``self`` is the class object here rather than an instance.
        requires_pytorch(self)
class MMBTForClassification:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
@@ -1560,6 +1578,15 @@ class PegasusForConditionalGeneration:
requires_pytorch(self)
class PegasusModel:
    """
    Placeholder for ``PegasusModel``: construction and ``from_pretrained`` do nothing but call ``requires_pytorch``.

    NOTE(review): ``requires_pytorch`` is defined outside this block; presumably it raises an error telling the user
    that PyTorch is needed - confirm against its definition.
    """

    def __init__(self, *args, **kwargs):
        # Arguments are accepted but never read; only the backend check runs.
        requires_pytorch(self)

    @classmethod
    def from_pretrained(self, *args, **kwargs):
        # Declared as a classmethod, so ``self`` is the class object here rather than an instance.
        requires_pytorch(self)
PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST = None