Fix model templates (#9999)

This commit is contained in:
Lysandre Debut
2021-02-04 13:47:26 +01:00
committed by GitHub
parent 804cd185d8
commit e89c959af9
11 changed files with 32 additions and 17 deletions

View File

@@ -1656,7 +1656,7 @@ class BartForCausalLM(BartPretrainedModel):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
provide it. provide it.
Indices can be obtained using :class:`~transformers.ProphetNetTokenizer`. See Indices can be obtained using :class:`~transformers.BartTokenizer`. See
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
for details. for details.

View File

@@ -1425,7 +1425,7 @@ class BlenderbotForCausalLM(BlenderbotPreTrainedModel):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
provide it. provide it.
Indices can be obtained using :class:`~transformers.ProphetNetTokenizer`. See Indices can be obtained using :class:`~transformers.BlenderbotTokenizer`. See
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
for details. for details.

View File

@@ -1400,7 +1400,7 @@ class BlenderbotSmallForCausalLM(BlenderbotSmallPreTrainedModel):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
provide it. provide it.
Indices can be obtained using :class:`~transformers.ProphetNetTokenizer`. See Indices can be obtained using :class:`~transformers.BlenderbotSmallTokenizer`. See
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
for details. for details.

View File

@@ -1411,7 +1411,7 @@ class MarianForCausalLM(MarianPreTrainedModel):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
provide it. provide it.
Indices can be obtained using :class:`~transformers.ProphetNetTokenizer`. See Indices can be obtained using :class:`~transformers.MarianTokenizer`. See
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
for details. for details.

View File

@@ -1658,7 +1658,7 @@ class MBartForCausalLM(MBartPreTrainedModel):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
provide it. provide it.
Indices can be obtained using :class:`~transformers.ProphetNetTokenizer`. See Indices can be obtained using :class:`~transformers.MBartTokenizer`. See
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
for details. for details.

View File

@@ -1414,7 +1414,7 @@ class PegasusForCausalLM(PegasusPreTrainedModel):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
provide it. provide it.
Indices can be obtained using :class:`~transformers.ProphetNetTokenizer`. See Indices can be obtained using :class:`~transformers.PegasusTokenizer`. See
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
for details. for details.

View File

@@ -55,6 +55,7 @@ if is_torch_available():
"{{cookiecutter.camelcase_modelname}}ForConditionalGeneration", "{{cookiecutter.camelcase_modelname}}ForConditionalGeneration",
"{{cookiecutter.camelcase_modelname}}ForQuestionAnswering", "{{cookiecutter.camelcase_modelname}}ForQuestionAnswering",
"{{cookiecutter.camelcase_modelname}}ForSequenceClassification", "{{cookiecutter.camelcase_modelname}}ForSequenceClassification",
"{{cookiecutter.camelcase_modelname}}ForCausalLM",
"{{cookiecutter.camelcase_modelname}}Model", "{{cookiecutter.camelcase_modelname}}Model",
"{{cookiecutter.camelcase_modelname}}PreTrainedModel", "{{cookiecutter.camelcase_modelname}}PreTrainedModel",
] ]
@@ -114,6 +115,7 @@ if TYPE_CHECKING:
from .modeling_{{cookiecutter.lowercase_modelname}} import ( from .modeling_{{cookiecutter.lowercase_modelname}} import (
{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST, {{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST,
{{cookiecutter.camelcase_modelname}}ForConditionalGeneration, {{cookiecutter.camelcase_modelname}}ForConditionalGeneration,
{{cookiecutter.camelcase_modelname}}ForCausalLM,
{{cookiecutter.camelcase_modelname}}ForQuestionAnswering, {{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
{{cookiecutter.camelcase_modelname}}ForSequenceClassification, {{cookiecutter.camelcase_modelname}}ForSequenceClassification,
{{cookiecutter.camelcase_modelname}}Model, {{cookiecutter.camelcase_modelname}}Model,

View File

@@ -1546,6 +1546,7 @@ from ...modeling_outputs import (
Seq2SeqModelOutput, Seq2SeqModelOutput,
Seq2SeqQuestionAnsweringModelOutput, Seq2SeqQuestionAnsweringModelOutput,
Seq2SeqSequenceClassifierOutput, Seq2SeqSequenceClassifierOutput,
CausalLMOutputWithCrossAttentions
) )
from ...modeling_utils import PreTrainedModel from ...modeling_utils import PreTrainedModel
from ...utils import logging from ...utils import logging
@@ -1952,7 +1953,7 @@ class {{cookiecutter.camelcase_modelname}}DecoderLayer(nn.Module):
return outputs return outputs
# Copied from transformers.models.bart.modeling_bart.{{cookiecutter.camelcase_modelname}}ClassificationHead with {{cookiecutter.camelcase_modelname}}->{{cookiecutter.camelcase_modelname}} # Copied from transformers.models.bart.modeling_bart.BartClassificationHead with Bart->{{cookiecutter.camelcase_modelname}}
class {{cookiecutter.camelcase_modelname}}ClassificationHead(nn.Module): class {{cookiecutter.camelcase_modelname}}ClassificationHead(nn.Module):
"""Head for sentence-level classification tasks.""" """Head for sentence-level classification tasks."""
@@ -3066,8 +3067,8 @@ class {{cookiecutter.camelcase_modelname}}ForQuestionAnswering({{cookiecutter.ca
encoder_attentions=outputs.encoder_attentions, encoder_attentions=outputs.encoder_attentions,
) )
# Copied from transformers.models.bart.modeling_bart.{{cookiecutter.camelcase_modelname}}DecoderWrapper with {{cookiecutter.camelcase_modelname}}->{{cookiecutter.camelcase_modelname}} # Copied from transformers.models.bart.modeling_bart.BartDecoderWrapper with Bart->{{cookiecutter.camelcase_modelname}}
class {{cookiecutter.camelcase_modelname}}DecoderWrapper({{cookiecutter.camelcase_modelname}}PretrainedModel): class {{cookiecutter.camelcase_modelname}}DecoderWrapper({{cookiecutter.camelcase_modelname}}PreTrainedModel):
""" """
This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
used in combination with the :class:`~transformers.EncoderDecoderModel` framework. used in combination with the :class:`~transformers.EncoderDecoderModel` framework.
@@ -3081,8 +3082,8 @@ class {{cookiecutter.camelcase_modelname}}DecoderWrapper({{cookiecutter.camelcas
return self.decoder(*args, **kwargs) return self.decoder(*args, **kwargs)
# Copied from transformers.models.bart.modeling_bart.{{cookiecutter.camelcase_modelname}}ForCausalLM with {{cookiecutter.camelcase_modelname}}->{{cookiecutter.camelcase_modelname}} # Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->{{cookiecutter.camelcase_modelname}}
class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_modelname}}PretrainedModel): class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_modelname}}PreTrainedModel):
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
config = copy.deepcopy(config) config = copy.deepcopy(config)
@@ -3199,8 +3200,8 @@ class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_m
>>> from transformers import {{cookiecutter.camelcase_modelname}}Tokenizer, {{cookiecutter.camelcase_modelname}}ForCausalLM >>> from transformers import {{cookiecutter.camelcase_modelname}}Tokenizer, {{cookiecutter.camelcase_modelname}}ForCausalLM
>>> tokenizer = {{cookiecutter.camelcase_modelname}}Tokenizer.from_pretrained('{{cookiecutter.checkpoint_identifier}}') >>> tokenizer = {{cookiecutter.camelcase_modelname}}Tokenizer.from_pretrained('facebook/bart-large')
>>> model = {{cookiecutter.camelcase_modelname}}ForCausalLM.from_pretrained('{{cookiecutter.checkpoint_identifier}}', add_cross_attention=False) >>> model = {{cookiecutter.camelcase_modelname}}ForCausalLM.from_pretrained('facebook/bart-large', add_cross_attention=False)
>>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder." >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> outputs = model(**inputs) >>> outputs = model(**inputs)

View File

@@ -488,7 +488,7 @@ from transformers.testing_utils import require_sentencepiece, require_tokenizers
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
from .test_generation_utils import GenerationTesterMixin from .test_generation_utils import GenerationTesterMixin
from .test_modeling_common import ModelTesterMixin, ids_tensor, floats_tensor from .test_modeling_common import ModelTesterMixin, ids_tensor
if is_torch_available(): if is_torch_available():
@@ -498,6 +498,7 @@ if is_torch_available():
{{cookiecutter.camelcase_modelname}}Config, {{cookiecutter.camelcase_modelname}}Config,
{{cookiecutter.camelcase_modelname}}ForConditionalGeneration, {{cookiecutter.camelcase_modelname}}ForConditionalGeneration,
{{cookiecutter.camelcase_modelname}}ForQuestionAnswering, {{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
{{cookiecutter.camelcase_modelname}}ForCausalLM,
{{cookiecutter.camelcase_modelname}}ForSequenceClassification, {{cookiecutter.camelcase_modelname}}ForSequenceClassification,
{{cookiecutter.camelcase_modelname}}Model, {{cookiecutter.camelcase_modelname}}Model,
{{cookiecutter.camelcase_modelname}}Tokenizer, {{cookiecutter.camelcase_modelname}}Tokenizer,

View File

@@ -47,6 +47,7 @@
_import_structure["models.{{cookiecutter.lowercase_modelname}}"].extend( _import_structure["models.{{cookiecutter.lowercase_modelname}}"].extend(
[ [
"{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST", "{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST",
"{{cookiecutter.camelcase_modelname}}ForCausalLM",
"{{cookiecutter.camelcase_modelname}}ForConditionalGeneration", "{{cookiecutter.camelcase_modelname}}ForConditionalGeneration",
"{{cookiecutter.camelcase_modelname}}ForQuestionAnswering", "{{cookiecutter.camelcase_modelname}}ForQuestionAnswering",
"{{cookiecutter.camelcase_modelname}}ForSequenceClassification", "{{cookiecutter.camelcase_modelname}}ForSequenceClassification",
@@ -115,6 +116,7 @@
from .models.{{cookiecutter.lowercase_modelname}} import ( from .models.{{cookiecutter.lowercase_modelname}} import (
{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST, {{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST,
{{cookiecutter.camelcase_modelname}}ForConditionalGeneration, {{cookiecutter.camelcase_modelname}}ForConditionalGeneration,
{{cookiecutter.camelcase_modelname}}ForCausalLM,
{{cookiecutter.camelcase_modelname}}ForQuestionAnswering, {{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
{{cookiecutter.camelcase_modelname}}ForSequenceClassification, {{cookiecutter.camelcase_modelname}}ForSequenceClassification,
{{cookiecutter.camelcase_modelname}}Model, {{cookiecutter.camelcase_modelname}}Model,
@@ -209,6 +211,7 @@ from ..{{cookiecutter.lowercase_modelname}}.modeling_{{cookiecutter.lowercase_mo
{% else -%} {% else -%}
from ..{{cookiecutter.lowercase_modelname}}.modeling_{{cookiecutter.lowercase_modelname}} import ( from ..{{cookiecutter.lowercase_modelname}}.modeling_{{cookiecutter.lowercase_modelname}} import (
{{cookiecutter.camelcase_modelname}}ForConditionalGeneration, {{cookiecutter.camelcase_modelname}}ForConditionalGeneration,
{{cookiecutter.camelcase_modelname}}ForCausalLM,
{{cookiecutter.camelcase_modelname}}ForQuestionAnswering, {{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
{{cookiecutter.camelcase_modelname}}ForSequenceClassification, {{cookiecutter.camelcase_modelname}}ForSequenceClassification,
{{cookiecutter.camelcase_modelname}}Model, {{cookiecutter.camelcase_modelname}}Model,
@@ -232,10 +235,7 @@ from ..{{cookiecutter.lowercase_modelname}}.modeling_{{cookiecutter.lowercase_mo
# Below: "# Model for Causal LM mapping" # Below: "# Model for Causal LM mapping"
# Replace with: # Replace with:
{% if cookiecutter.is_encoder_decoder_model == "False" -%}
({{cookiecutter.camelcase_modelname}}Config, {{cookiecutter.camelcase_modelname}}ForCausalLM), ({{cookiecutter.camelcase_modelname}}Config, {{cookiecutter.camelcase_modelname}}ForCausalLM),
{% else -%}
{% endif -%}
# End. # End.
# Below: "# Model for Masked LM mapping" # Below: "# Model for Masked LM mapping"
@@ -384,6 +384,7 @@ from ..{{cookiecutter.lowercase_modelname}}.modeling_tf_{{cookiecutter.lowercase
{% else -%} {% else -%}
"{{cookiecutter.camelcase_modelname}}Encoder", "{{cookiecutter.camelcase_modelname}}Encoder",
"{{cookiecutter.camelcase_modelname}}Decoder", "{{cookiecutter.camelcase_modelname}}Decoder",
"{{cookiecutter.camelcase_modelname}}DecoderWrapper",
{% endif -%} {% endif -%}
# End. # End.
@@ -393,5 +394,6 @@ from ..{{cookiecutter.lowercase_modelname}}.modeling_tf_{{cookiecutter.lowercase
{% else -%} {% else -%}
"{{cookiecutter.camelcase_modelname}}Encoder", # Building part of bigger (tested) model. "{{cookiecutter.camelcase_modelname}}Encoder", # Building part of bigger (tested) model.
"{{cookiecutter.camelcase_modelname}}Decoder", # Building part of bigger (tested) model. "{{cookiecutter.camelcase_modelname}}Decoder", # Building part of bigger (tested) model.
"{{cookiecutter.camelcase_modelname}}DecoderWrapper", # Building part of bigger (tested) model.
{% endif -%} {% endif -%}
# End. # End.

View File

@@ -121,6 +121,13 @@ Tips:
:members: forward :members: forward
{{cookiecutter.camelcase_modelname}}ForCausalLM
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.{{cookiecutter.camelcase_modelname}}ForCausalLM
:members: forward
{% endif -%} {% endif -%}
{% endif -%} {% endif -%}
{% if "TensorFlow" in cookiecutter.generate_tensorflow_and_pytorch -%} {% if "TensorFlow" in cookiecutter.generate_tensorflow_and_pytorch -%}
@@ -180,5 +187,7 @@ TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration
.. autoclass:: transformers.TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration .. autoclass:: transformers.TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration
:members: call :members: call
{% endif -%} {% endif -%}
{% endif -%} {% endif -%}