Fix model templates (#9999)
This commit is contained in:
@@ -1656,7 +1656,7 @@ class BartForCausalLM(BartPretrainedModel):
|
|||||||
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
|
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
|
||||||
provide it.
|
provide it.
|
||||||
|
|
||||||
Indices can be obtained using :class:`~transformers.ProphetNetTokenizer`. See
|
Indices can be obtained using :class:`~transformers.BartTokenizer`. See
|
||||||
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
|
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
|
||||||
for details.
|
for details.
|
||||||
|
|
||||||
|
|||||||
@@ -1425,7 +1425,7 @@ class BlenderbotForCausalLM(BlenderbotPreTrainedModel):
|
|||||||
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
|
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
|
||||||
provide it.
|
provide it.
|
||||||
|
|
||||||
Indices can be obtained using :class:`~transformers.ProphetNetTokenizer`. See
|
Indices can be obtained using :class:`~transformers.BlenderbotTokenizer`. See
|
||||||
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
|
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
|
||||||
for details.
|
for details.
|
||||||
|
|
||||||
|
|||||||
@@ -1400,7 +1400,7 @@ class BlenderbotSmallForCausalLM(BlenderbotSmallPreTrainedModel):
|
|||||||
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
|
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
|
||||||
provide it.
|
provide it.
|
||||||
|
|
||||||
Indices can be obtained using :class:`~transformers.ProphetNetTokenizer`. See
|
Indices can be obtained using :class:`~transformers.BlenderbotSmallTokenizer`. See
|
||||||
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
|
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
|
||||||
for details.
|
for details.
|
||||||
|
|
||||||
|
|||||||
@@ -1411,7 +1411,7 @@ class MarianForCausalLM(MarianPreTrainedModel):
|
|||||||
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
|
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
|
||||||
provide it.
|
provide it.
|
||||||
|
|
||||||
Indices can be obtained using :class:`~transformers.ProphetNetTokenizer`. See
|
Indices can be obtained using :class:`~transformers.MarianTokenizer`. See
|
||||||
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
|
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
|
||||||
for details.
|
for details.
|
||||||
|
|
||||||
|
|||||||
@@ -1658,7 +1658,7 @@ class MBartForCausalLM(MBartPreTrainedModel):
|
|||||||
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
|
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
|
||||||
provide it.
|
provide it.
|
||||||
|
|
||||||
Indices can be obtained using :class:`~transformers.ProphetNetTokenizer`. See
|
Indices can be obtained using :class:`~transformers.MBartTokenizer`. See
|
||||||
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
|
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
|
||||||
for details.
|
for details.
|
||||||
|
|
||||||
|
|||||||
@@ -1414,7 +1414,7 @@ class PegasusForCausalLM(PegasusPreTrainedModel):
|
|||||||
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
|
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
|
||||||
provide it.
|
provide it.
|
||||||
|
|
||||||
Indices can be obtained using :class:`~transformers.ProphetNetTokenizer`. See
|
Indices can be obtained using :class:`~transformers.PegasusTokenizer`. See
|
||||||
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
|
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
|
||||||
for details.
|
for details.
|
||||||
|
|
||||||
|
|||||||
@@ -55,6 +55,7 @@ if is_torch_available():
|
|||||||
"{{cookiecutter.camelcase_modelname}}ForConditionalGeneration",
|
"{{cookiecutter.camelcase_modelname}}ForConditionalGeneration",
|
||||||
"{{cookiecutter.camelcase_modelname}}ForQuestionAnswering",
|
"{{cookiecutter.camelcase_modelname}}ForQuestionAnswering",
|
||||||
"{{cookiecutter.camelcase_modelname}}ForSequenceClassification",
|
"{{cookiecutter.camelcase_modelname}}ForSequenceClassification",
|
||||||
|
"{{cookiecutter.camelcase_modelname}}ForCausalLM",
|
||||||
"{{cookiecutter.camelcase_modelname}}Model",
|
"{{cookiecutter.camelcase_modelname}}Model",
|
||||||
"{{cookiecutter.camelcase_modelname}}PreTrainedModel",
|
"{{cookiecutter.camelcase_modelname}}PreTrainedModel",
|
||||||
]
|
]
|
||||||
@@ -114,6 +115,7 @@ if TYPE_CHECKING:
|
|||||||
from .modeling_{{cookiecutter.lowercase_modelname}} import (
|
from .modeling_{{cookiecutter.lowercase_modelname}} import (
|
||||||
{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST,
|
{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
{{cookiecutter.camelcase_modelname}}ForConditionalGeneration,
|
{{cookiecutter.camelcase_modelname}}ForConditionalGeneration,
|
||||||
|
{{cookiecutter.camelcase_modelname}}ForCausalLM,
|
||||||
{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
|
{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
|
||||||
{{cookiecutter.camelcase_modelname}}ForSequenceClassification,
|
{{cookiecutter.camelcase_modelname}}ForSequenceClassification,
|
||||||
{{cookiecutter.camelcase_modelname}}Model,
|
{{cookiecutter.camelcase_modelname}}Model,
|
||||||
|
|||||||
@@ -1546,6 +1546,7 @@ from ...modeling_outputs import (
|
|||||||
Seq2SeqModelOutput,
|
Seq2SeqModelOutput,
|
||||||
Seq2SeqQuestionAnsweringModelOutput,
|
Seq2SeqQuestionAnsweringModelOutput,
|
||||||
Seq2SeqSequenceClassifierOutput,
|
Seq2SeqSequenceClassifierOutput,
|
||||||
|
CausalLMOutputWithCrossAttentions
|
||||||
)
|
)
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_utils import PreTrainedModel
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
@@ -1952,7 +1953,7 @@ class {{cookiecutter.camelcase_modelname}}DecoderLayer(nn.Module):
|
|||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.bart.modeling_bart.{{cookiecutter.camelcase_modelname}}ClassificationHead with {{cookiecutter.camelcase_modelname}}->{{cookiecutter.camelcase_modelname}}
|
# Copied from transformers.models.bart.modeling_bart.BartClassificationHead with Bart->{{cookiecutter.camelcase_modelname}}
|
||||||
class {{cookiecutter.camelcase_modelname}}ClassificationHead(nn.Module):
|
class {{cookiecutter.camelcase_modelname}}ClassificationHead(nn.Module):
|
||||||
"""Head for sentence-level classification tasks."""
|
"""Head for sentence-level classification tasks."""
|
||||||
|
|
||||||
@@ -3066,8 +3067,8 @@ class {{cookiecutter.camelcase_modelname}}ForQuestionAnswering({{cookiecutter.ca
|
|||||||
encoder_attentions=outputs.encoder_attentions,
|
encoder_attentions=outputs.encoder_attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Copied from transformers.models.bart.modeling_bart.{{cookiecutter.camelcase_modelname}}DecoderWrapper with {{cookiecutter.camelcase_modelname}}->{{cookiecutter.camelcase_modelname}}
|
# Copied from transformers.models.bart.modeling_bart.BartDecoderWrapper with Bart->{{cookiecutter.camelcase_modelname}}
|
||||||
class {{cookiecutter.camelcase_modelname}}DecoderWrapper({{cookiecutter.camelcase_modelname}}PretrainedModel):
|
class {{cookiecutter.camelcase_modelname}}DecoderWrapper({{cookiecutter.camelcase_modelname}}PreTrainedModel):
|
||||||
"""
|
"""
|
||||||
This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
|
This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
|
||||||
used in combination with the :class:`~transformers.EncoderDecoderModel` framework.
|
used in combination with the :class:`~transformers.EncoderDecoderModel` framework.
|
||||||
@@ -3081,8 +3082,8 @@ class {{cookiecutter.camelcase_modelname}}DecoderWrapper({{cookiecutter.camelcas
|
|||||||
return self.decoder(*args, **kwargs)
|
return self.decoder(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.bart.modeling_bart.{{cookiecutter.camelcase_modelname}}ForCausalLM with {{cookiecutter.camelcase_modelname}}->{{cookiecutter.camelcase_modelname}}
|
# Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->{{cookiecutter.camelcase_modelname}}
|
||||||
class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_modelname}}PretrainedModel):
|
class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_modelname}}PreTrainedModel):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
config = copy.deepcopy(config)
|
config = copy.deepcopy(config)
|
||||||
@@ -3199,8 +3200,8 @@ class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_m
|
|||||||
|
|
||||||
>>> from transformers import {{cookiecutter.camelcase_modelname}}Tokenizer, {{cookiecutter.camelcase_modelname}}ForCausalLM
|
>>> from transformers import {{cookiecutter.camelcase_modelname}}Tokenizer, {{cookiecutter.camelcase_modelname}}ForCausalLM
|
||||||
|
|
||||||
>>> tokenizer = {{cookiecutter.camelcase_modelname}}Tokenizer.from_pretrained('{{cookiecutter.checkpoint_identifier}}')
|
>>> tokenizer = {{cookiecutter.camelcase_modelname}}Tokenizer.from_pretrained('facebook/bart-large')
|
||||||
>>> model = {{cookiecutter.camelcase_modelname}}ForCausalLM.from_pretrained('{{cookiecutter.checkpoint_identifier}}', add_cross_attention=False)
|
>>> model = {{cookiecutter.camelcase_modelname}}ForCausalLM.from_pretrained('facebook/bart-large', add_cross_attention=False)
|
||||||
>>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
|
>>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
|
||||||
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
|
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
|
||||||
>>> outputs = model(**inputs)
|
>>> outputs = model(**inputs)
|
||||||
|
|||||||
@@ -488,7 +488,7 @@ from transformers.testing_utils import require_sentencepiece, require_tokenizers
|
|||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
from .test_generation_utils import GenerationTesterMixin
|
from .test_generation_utils import GenerationTesterMixin
|
||||||
from .test_modeling_common import ModelTesterMixin, ids_tensor, floats_tensor
|
from .test_modeling_common import ModelTesterMixin, ids_tensor
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
@@ -498,6 +498,7 @@ if is_torch_available():
|
|||||||
{{cookiecutter.camelcase_modelname}}Config,
|
{{cookiecutter.camelcase_modelname}}Config,
|
||||||
{{cookiecutter.camelcase_modelname}}ForConditionalGeneration,
|
{{cookiecutter.camelcase_modelname}}ForConditionalGeneration,
|
||||||
{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
|
{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
|
||||||
|
{{cookiecutter.camelcase_modelname}}ForCausalLM,
|
||||||
{{cookiecutter.camelcase_modelname}}ForSequenceClassification,
|
{{cookiecutter.camelcase_modelname}}ForSequenceClassification,
|
||||||
{{cookiecutter.camelcase_modelname}}Model,
|
{{cookiecutter.camelcase_modelname}}Model,
|
||||||
{{cookiecutter.camelcase_modelname}}Tokenizer,
|
{{cookiecutter.camelcase_modelname}}Tokenizer,
|
||||||
|
|||||||
@@ -47,6 +47,7 @@
|
|||||||
_import_structure["models.{{cookiecutter.lowercase_modelname}}"].extend(
|
_import_structure["models.{{cookiecutter.lowercase_modelname}}"].extend(
|
||||||
[
|
[
|
||||||
"{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST",
|
"{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||||
|
"{{cookiecutter.camelcase_modelname}}ForCausalLM",
|
||||||
"{{cookiecutter.camelcase_modelname}}ForConditionalGeneration",
|
"{{cookiecutter.camelcase_modelname}}ForConditionalGeneration",
|
||||||
"{{cookiecutter.camelcase_modelname}}ForQuestionAnswering",
|
"{{cookiecutter.camelcase_modelname}}ForQuestionAnswering",
|
||||||
"{{cookiecutter.camelcase_modelname}}ForSequenceClassification",
|
"{{cookiecutter.camelcase_modelname}}ForSequenceClassification",
|
||||||
@@ -115,6 +116,7 @@
|
|||||||
from .models.{{cookiecutter.lowercase_modelname}} import (
|
from .models.{{cookiecutter.lowercase_modelname}} import (
|
||||||
{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST,
|
{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
{{cookiecutter.camelcase_modelname}}ForConditionalGeneration,
|
{{cookiecutter.camelcase_modelname}}ForConditionalGeneration,
|
||||||
|
{{cookiecutter.camelcase_modelname}}ForCausalLM,
|
||||||
{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
|
{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
|
||||||
{{cookiecutter.camelcase_modelname}}ForSequenceClassification,
|
{{cookiecutter.camelcase_modelname}}ForSequenceClassification,
|
||||||
{{cookiecutter.camelcase_modelname}}Model,
|
{{cookiecutter.camelcase_modelname}}Model,
|
||||||
@@ -209,6 +211,7 @@ from ..{{cookiecutter.lowercase_modelname}}.modeling_{{cookiecutter.lowercase_mo
|
|||||||
{% else -%}
|
{% else -%}
|
||||||
from ..{{cookiecutter.lowercase_modelname}}.modeling_{{cookiecutter.lowercase_modelname}} import (
|
from ..{{cookiecutter.lowercase_modelname}}.modeling_{{cookiecutter.lowercase_modelname}} import (
|
||||||
{{cookiecutter.camelcase_modelname}}ForConditionalGeneration,
|
{{cookiecutter.camelcase_modelname}}ForConditionalGeneration,
|
||||||
|
{{cookiecutter.camelcase_modelname}}ForCausalLM,
|
||||||
{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
|
{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
|
||||||
{{cookiecutter.camelcase_modelname}}ForSequenceClassification,
|
{{cookiecutter.camelcase_modelname}}ForSequenceClassification,
|
||||||
{{cookiecutter.camelcase_modelname}}Model,
|
{{cookiecutter.camelcase_modelname}}Model,
|
||||||
@@ -232,10 +235,7 @@ from ..{{cookiecutter.lowercase_modelname}}.modeling_{{cookiecutter.lowercase_mo
|
|||||||
|
|
||||||
# Below: "# Model for Causal LM mapping"
|
# Below: "# Model for Causal LM mapping"
|
||||||
# Replace with:
|
# Replace with:
|
||||||
{% if cookiecutter.is_encoder_decoder_model == "False" -%}
|
|
||||||
({{cookiecutter.camelcase_modelname}}Config, {{cookiecutter.camelcase_modelname}}ForCausalLM),
|
({{cookiecutter.camelcase_modelname}}Config, {{cookiecutter.camelcase_modelname}}ForCausalLM),
|
||||||
{% else -%}
|
|
||||||
{% endif -%}
|
|
||||||
# End.
|
# End.
|
||||||
|
|
||||||
# Below: "# Model for Masked LM mapping"
|
# Below: "# Model for Masked LM mapping"
|
||||||
@@ -384,6 +384,7 @@ from ..{{cookiecutter.lowercase_modelname}}.modeling_tf_{{cookiecutter.lowercase
|
|||||||
{% else -%}
|
{% else -%}
|
||||||
"{{cookiecutter.camelcase_modelname}}Encoder",
|
"{{cookiecutter.camelcase_modelname}}Encoder",
|
||||||
"{{cookiecutter.camelcase_modelname}}Decoder",
|
"{{cookiecutter.camelcase_modelname}}Decoder",
|
||||||
|
"{{cookiecutter.camelcase_modelname}}DecoderWrapper",
|
||||||
{% endif -%}
|
{% endif -%}
|
||||||
# End.
|
# End.
|
||||||
|
|
||||||
@@ -393,5 +394,6 @@ from ..{{cookiecutter.lowercase_modelname}}.modeling_tf_{{cookiecutter.lowercase
|
|||||||
{% else -%}
|
{% else -%}
|
||||||
"{{cookiecutter.camelcase_modelname}}Encoder", # Building part of bigger (tested) model.
|
"{{cookiecutter.camelcase_modelname}}Encoder", # Building part of bigger (tested) model.
|
||||||
"{{cookiecutter.camelcase_modelname}}Decoder", # Building part of bigger (tested) model.
|
"{{cookiecutter.camelcase_modelname}}Decoder", # Building part of bigger (tested) model.
|
||||||
|
"{{cookiecutter.camelcase_modelname}}DecoderWrapper", # Building part of bigger (tested) model.
|
||||||
{% endif -%}
|
{% endif -%}
|
||||||
# End.
|
# End.
|
||||||
|
|||||||
@@ -121,6 +121,13 @@ Tips:
|
|||||||
:members: forward
|
:members: forward
|
||||||
|
|
||||||
|
|
||||||
|
{{cookiecutter.camelcase_modelname}}ForCausalLM
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.{{cookiecutter.camelcase_modelname}}ForCausalLM
|
||||||
|
:members: forward
|
||||||
|
|
||||||
|
|
||||||
{% endif -%}
|
{% endif -%}
|
||||||
{% endif -%}
|
{% endif -%}
|
||||||
{% if "TensorFlow" in cookiecutter.generate_tensorflow_and_pytorch -%}
|
{% if "TensorFlow" in cookiecutter.generate_tensorflow_and_pytorch -%}
|
||||||
@@ -180,5 +187,7 @@ TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration
|
|||||||
|
|
||||||
.. autoclass:: transformers.TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration
|
.. autoclass:: transformers.TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration
|
||||||
:members: call
|
:members: call
|
||||||
|
|
||||||
|
|
||||||
{% endif -%}
|
{% endif -%}
|
||||||
{% endif -%}
|
{% endif -%}
|
||||||
|
|||||||
Reference in New Issue
Block a user