add base model classes to bart subclassed models (#9230)
* add base model classes to bart subclassed models * add doc
This commit is contained in:
@@ -406,7 +406,11 @@ if is_torch_available():
|
||||
BertGenerationEncoder,
|
||||
load_tf_weights_in_bert_generation,
|
||||
)
|
||||
from .models.blenderbot import BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST, BlenderbotForConditionalGeneration
|
||||
from .models.blenderbot import (
|
||||
BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
BlenderbotForConditionalGeneration,
|
||||
BlenderbotModel,
|
||||
)
|
||||
from .models.camembert import (
|
||||
CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
CamembertForCausalLM,
|
||||
@@ -522,7 +526,7 @@ if is_torch_available():
|
||||
LxmertXLayer,
|
||||
)
|
||||
from .models.marian import MarianMTModel
|
||||
from .models.mbart import MBartForConditionalGeneration
|
||||
from .models.mbart import MBartForConditionalGeneration, MBartModel
|
||||
from .models.mmbt import MMBTForClassification, MMBTModel, ModalEmbeddings
|
||||
from .models.mobilebert import (
|
||||
MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
@@ -559,7 +563,7 @@ if is_torch_available():
|
||||
OpenAIGPTPreTrainedModel,
|
||||
load_tf_weights_in_openai_gpt,
|
||||
)
|
||||
from .models.pegasus import PegasusForConditionalGeneration
|
||||
from .models.pegasus import PegasusForConditionalGeneration, PegasusModel
|
||||
from .models.prophetnet import (
|
||||
PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
ProphetNetDecoder,
|
||||
|
||||
@@ -50,7 +50,7 @@ from ..bert.modeling_bert import (
|
||||
BertModel,
|
||||
)
|
||||
from ..bert_generation.modeling_bert_generation import BertGenerationDecoder, BertGenerationEncoder
|
||||
from ..blenderbot.modeling_blenderbot import BlenderbotForConditionalGeneration
|
||||
from ..blenderbot.modeling_blenderbot import BlenderbotForConditionalGeneration, BlenderbotModel
|
||||
from ..camembert.modeling_camembert import (
|
||||
CamembertForCausalLM,
|
||||
CamembertForMaskedLM,
|
||||
@@ -111,7 +111,7 @@ from ..longformer.modeling_longformer import (
|
||||
)
|
||||
from ..lxmert.modeling_lxmert import LxmertForPreTraining, LxmertForQuestionAnswering, LxmertModel
|
||||
from ..marian.modeling_marian import MarianMTModel
|
||||
from ..mbart.modeling_mbart import MBartForConditionalGeneration
|
||||
from ..mbart.modeling_mbart import MBartForConditionalGeneration, MBartModel
|
||||
from ..mobilebert.modeling_mobilebert import (
|
||||
MobileBertForMaskedLM,
|
||||
MobileBertForMultipleChoice,
|
||||
@@ -132,7 +132,7 @@ from ..mpnet.modeling_mpnet import (
|
||||
)
|
||||
from ..mt5.modeling_mt5 import MT5ForConditionalGeneration, MT5Model
|
||||
from ..openai.modeling_openai import OpenAIGPTForSequenceClassification, OpenAIGPTLMHeadModel, OpenAIGPTModel
|
||||
from ..pegasus.modeling_pegasus import PegasusForConditionalGeneration
|
||||
from ..pegasus.modeling_pegasus import PegasusForConditionalGeneration, PegasusModel
|
||||
from ..prophetnet.modeling_prophetnet import ProphetNetForCausalLM, ProphetNetForConditionalGeneration, ProphetNetModel
|
||||
from ..rag.modeling_rag import ( # noqa: F401 - need to import all RagModels to be in globals() function
|
||||
RagModel,
|
||||
@@ -255,6 +255,10 @@ MODEL_MAPPING = OrderedDict(
|
||||
(RetriBertConfig, RetriBertModel),
|
||||
(MT5Config, MT5Model),
|
||||
(T5Config, T5Model),
|
||||
(PegasusConfig, PegasusModel),
|
||||
(MarianConfig, MarianMTModel),
|
||||
(MBartConfig, MBartModel),
|
||||
(BlenderbotConfig, BlenderbotModel),
|
||||
(DistilBertConfig, DistilBertModel),
|
||||
(AlbertConfig, AlbertModel),
|
||||
(CamembertConfig, CamembertModel),
|
||||
|
||||
@@ -22,7 +22,11 @@ from .tokenization_blenderbot import BlenderbotSmallTokenizer, BlenderbotTokeniz
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from .modeling_blenderbot import BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST, BlenderbotForConditionalGeneration
|
||||
from .modeling_blenderbot import (
|
||||
BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
BlenderbotForConditionalGeneration,
|
||||
BlenderbotModel,
|
||||
)
|
||||
|
||||
if is_tf_available():
|
||||
from .modeling_tf_blenderbot import TFBlenderbotForConditionalGeneration
|
||||
|
||||
@@ -19,7 +19,7 @@
|
||||
import torch
|
||||
|
||||
from ...file_utils import add_start_docstrings
|
||||
from ..bart.modeling_bart import BartForConditionalGeneration
|
||||
from ..bart.modeling_bart import BartForConditionalGeneration, BartModel
|
||||
from .configuration_blenderbot import BlenderbotConfig
|
||||
|
||||
|
||||
@@ -39,7 +39,20 @@ BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST = ["facebook/blenderbot-3B", "facebook/
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The BART Model with a language modeling head. Can be used for summarization.", BLENDER_START_DOCSTRING
|
||||
"The bare BlenderBot Model transformer outputting raw hidden-states without any specific head on top.",
|
||||
BLENDER_START_DOCSTRING,
|
||||
)
|
||||
class BlenderbotModel(BartModel):
|
||||
r"""
|
||||
This class overrides :class:`~transformers.BartModel`. Please check the superclass for the appropriate
|
||||
documentation alongside usage examples.
|
||||
"""
|
||||
|
||||
config_class = BlenderbotConfig
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The BlenderBot Model with a language modeling head. Can be used for summarization.", BLENDER_START_DOCSTRING
|
||||
)
|
||||
class BlenderbotForConditionalGeneration(BartForConditionalGeneration):
|
||||
"""
|
||||
|
||||
@@ -27,7 +27,7 @@ if is_tokenizers_available():
|
||||
from .tokenization_mbart_fast import MBartTokenizerFast
|
||||
|
||||
if is_torch_available():
|
||||
from .modeling_mbart import MBartForConditionalGeneration
|
||||
from .modeling_mbart import MBartForConditionalGeneration, MBartModel
|
||||
|
||||
if is_tf_available():
|
||||
from .modeling_tf_mbart import TFMBartForConditionalGeneration
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ..bart.modeling_bart import BartForConditionalGeneration
|
||||
from ..bart.modeling_bart import BartForConditionalGeneration, BartModel
|
||||
from .configuration_mbart import MBartConfig
|
||||
|
||||
|
||||
@@ -26,6 +26,23 @@ MBART_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||
]
|
||||
|
||||
|
||||
class MBartModel(BartModel):
|
||||
r"""
|
||||
This class overrides :class:`~transformers.BartModel`. Please check the superclass for the appropriate
|
||||
documentation alongside usage examples.
|
||||
"""
|
||||
|
||||
config_class = MBartConfig
|
||||
_keys_to_ignore_on_load_missing = [
|
||||
"encoder.embed_positions.weight",
|
||||
"decoder.embed_positions.weight",
|
||||
]
|
||||
_keys_to_ignore_on_save = [
|
||||
"encoder.embed_positions.weight",
|
||||
"decoder.embed_positions.weight",
|
||||
]
|
||||
|
||||
|
||||
class MBartForConditionalGeneration(BartForConditionalGeneration):
|
||||
r"""
|
||||
This class overrides :class:`~transformers.BartForConditionalGeneration`. Please check the superclass for the
|
||||
|
||||
@@ -27,7 +27,7 @@ if is_tokenizers_available():
|
||||
from .tokenization_pegasus_fast import PegasusTokenizerFast
|
||||
|
||||
if is_torch_available():
|
||||
from .modeling_pegasus import PegasusForConditionalGeneration
|
||||
from .modeling_pegasus import PegasusForConditionalGeneration, PegasusModel
|
||||
|
||||
if is_tf_available():
|
||||
from .modeling_tf_pegasus import TFPegasusForConditionalGeneration
|
||||
|
||||
@@ -16,10 +16,34 @@
|
||||
|
||||
|
||||
from ...file_utils import add_start_docstrings
|
||||
from ..bart.modeling_bart import BART_START_DOCSTRING, BartForConditionalGeneration
|
||||
from ..bart.modeling_bart import BART_START_DOCSTRING, BartForConditionalGeneration, BartModel
|
||||
from .configuration_pegasus import PegasusConfig
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare Pegasus Model transformer outputting raw hidden-states without any specific head on top.",
|
||||
BART_START_DOCSTRING,
|
||||
)
|
||||
class PegasusModel(BartModel):
|
||||
r"""
|
||||
This class overrides :class:`~transformers.BartModel`. Please check the superclass for the appropriate
|
||||
documentation alongside usage examples.
|
||||
"""
|
||||
|
||||
config_class = PegasusConfig
|
||||
_keys_to_ignore_on_load_missing = [
|
||||
r"final_logits_bias",
|
||||
r"encoder\.version",
|
||||
r"decoder\.version",
|
||||
"encoder.embed_positions",
|
||||
"decoder.embed_positions",
|
||||
]
|
||||
_keys_to_ignore_on_save = [
|
||||
"encoder.embed_positions.weight",
|
||||
"decoder.embed_positions.weight",
|
||||
]
|
||||
|
||||
|
||||
@add_start_docstrings("The Pegasus Model for summarization ", BART_START_DOCSTRING)
|
||||
class PegasusForConditionalGeneration(BartForConditionalGeneration):
|
||||
r"""
|
||||
|
||||
@@ -600,6 +600,15 @@ class BlenderbotForConditionalGeneration:
|
||||
requires_pytorch(self)
|
||||
|
||||
|
||||
class BlenderbotModel:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
||||
|
||||
CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||
|
||||
|
||||
@@ -1297,6 +1306,15 @@ class MBartForConditionalGeneration:
|
||||
requires_pytorch(self)
|
||||
|
||||
|
||||
class MBartModel:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
||||
|
||||
class MMBTForClassification:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
@@ -1560,6 +1578,15 @@ class PegasusForConditionalGeneration:
|
||||
requires_pytorch(self)
|
||||
|
||||
|
||||
class PegasusModel:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
||||
|
||||
PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user