Fix model templates (#9999)

This commit is contained in:
Lysandre Debut
2021-02-04 13:47:26 +01:00
committed by GitHub
parent 804cd185d8
commit e89c959af9
11 changed files with 32 additions and 17 deletions

View File

@@ -55,6 +55,7 @@ if is_torch_available():
"{{cookiecutter.camelcase_modelname}}ForConditionalGeneration",
"{{cookiecutter.camelcase_modelname}}ForQuestionAnswering",
"{{cookiecutter.camelcase_modelname}}ForSequenceClassification",
"{{cookiecutter.camelcase_modelname}}ForCausalLM",
"{{cookiecutter.camelcase_modelname}}Model",
"{{cookiecutter.camelcase_modelname}}PreTrainedModel",
]
@@ -114,6 +115,7 @@ if TYPE_CHECKING:
from .modeling_{{cookiecutter.lowercase_modelname}} import (
{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST,
{{cookiecutter.camelcase_modelname}}ForConditionalGeneration,
{{cookiecutter.camelcase_modelname}}ForCausalLM,
{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
{{cookiecutter.camelcase_modelname}}ForSequenceClassification,
{{cookiecutter.camelcase_modelname}}Model,

View File

@@ -1546,6 +1546,7 @@ from ...modeling_outputs import (
Seq2SeqModelOutput,
Seq2SeqQuestionAnsweringModelOutput,
Seq2SeqSequenceClassifierOutput,
CausalLMOutputWithCrossAttentions
)
from ...modeling_utils import PreTrainedModel
from ...utils import logging
@@ -1952,7 +1953,7 @@ class {{cookiecutter.camelcase_modelname}}DecoderLayer(nn.Module):
return outputs
# Copied from transformers.models.bart.modeling_bart.{{cookiecutter.camelcase_modelname}}ClassificationHead with {{cookiecutter.camelcase_modelname}}->{{cookiecutter.camelcase_modelname}}
# Copied from transformers.models.bart.modeling_bart.BartClassificationHead with Bart->{{cookiecutter.camelcase_modelname}}
class {{cookiecutter.camelcase_modelname}}ClassificationHead(nn.Module):
"""Head for sentence-level classification tasks."""
@@ -3066,8 +3067,8 @@ class {{cookiecutter.camelcase_modelname}}ForQuestionAnswering({{cookiecutter.ca
encoder_attentions=outputs.encoder_attentions,
)
# Copied from transformers.models.bart.modeling_bart.{{cookiecutter.camelcase_modelname}}DecoderWrapper with {{cookiecutter.camelcase_modelname}}->{{cookiecutter.camelcase_modelname}}
class {{cookiecutter.camelcase_modelname}}DecoderWrapper({{cookiecutter.camelcase_modelname}}PretrainedModel):
# Copied from transformers.models.bart.modeling_bart.BartDecoderWrapper with Bart->{{cookiecutter.camelcase_modelname}}
class {{cookiecutter.camelcase_modelname}}DecoderWrapper({{cookiecutter.camelcase_modelname}}PreTrainedModel):
"""
This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
used in combination with the :class:`~transformers.EncoderDecoderModel` framework.
@@ -3081,8 +3082,8 @@ class {{cookiecutter.camelcase_modelname}}DecoderWrapper({{cookiecutter.camelcas
return self.decoder(*args, **kwargs)
# Copied from transformers.models.bart.modeling_bart.{{cookiecutter.camelcase_modelname}}ForCausalLM with {{cookiecutter.camelcase_modelname}}->{{cookiecutter.camelcase_modelname}}
class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_modelname}}PretrainedModel):
# Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->{{cookiecutter.camelcase_modelname}}
class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_modelname}}PreTrainedModel):
def __init__(self, config):
super().__init__(config)
config = copy.deepcopy(config)
@@ -3199,8 +3200,8 @@ class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_m
>>> from transformers import {{cookiecutter.camelcase_modelname}}Tokenizer, {{cookiecutter.camelcase_modelname}}ForCausalLM
>>> tokenizer = {{cookiecutter.camelcase_modelname}}Tokenizer.from_pretrained('{{cookiecutter.checkpoint_identifier}}')
>>> model = {{cookiecutter.camelcase_modelname}}ForCausalLM.from_pretrained('{{cookiecutter.checkpoint_identifier}}', add_cross_attention=False)
>>> tokenizer = {{cookiecutter.camelcase_modelname}}Tokenizer.from_pretrained('facebook/bart-large')
>>> model = {{cookiecutter.camelcase_modelname}}ForCausalLM.from_pretrained('facebook/bart-large', add_cross_attention=False)
>>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> outputs = model(**inputs)

View File

@@ -488,7 +488,7 @@ from transformers.testing_utils import require_sentencepiece, require_tokenizers
from .test_configuration_common import ConfigTester
from .test_generation_utils import GenerationTesterMixin
from .test_modeling_common import ModelTesterMixin, ids_tensor, floats_tensor
from .test_modeling_common import ModelTesterMixin, ids_tensor
if is_torch_available():
@@ -498,6 +498,7 @@ if is_torch_available():
{{cookiecutter.camelcase_modelname}}Config,
{{cookiecutter.camelcase_modelname}}ForConditionalGeneration,
{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
{{cookiecutter.camelcase_modelname}}ForCausalLM,
{{cookiecutter.camelcase_modelname}}ForSequenceClassification,
{{cookiecutter.camelcase_modelname}}Model,
{{cookiecutter.camelcase_modelname}}Tokenizer,

View File

@@ -47,6 +47,7 @@
_import_structure["models.{{cookiecutter.lowercase_modelname}}"].extend(
[
"{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST",
"{{cookiecutter.camelcase_modelname}}ForCausalLM",
"{{cookiecutter.camelcase_modelname}}ForConditionalGeneration",
"{{cookiecutter.camelcase_modelname}}ForQuestionAnswering",
"{{cookiecutter.camelcase_modelname}}ForSequenceClassification",
@@ -115,6 +116,7 @@
from .models.{{cookiecutter.lowercase_modelname}} import (
{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST,
{{cookiecutter.camelcase_modelname}}ForConditionalGeneration,
{{cookiecutter.camelcase_modelname}}ForCausalLM,
{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
{{cookiecutter.camelcase_modelname}}ForSequenceClassification,
{{cookiecutter.camelcase_modelname}}Model,
@@ -209,6 +211,7 @@ from ..{{cookiecutter.lowercase_modelname}}.modeling_{{cookiecutter.lowercase_mo
{% else -%}
from ..{{cookiecutter.lowercase_modelname}}.modeling_{{cookiecutter.lowercase_modelname}} import (
{{cookiecutter.camelcase_modelname}}ForConditionalGeneration,
{{cookiecutter.camelcase_modelname}}ForCausalLM,
{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
{{cookiecutter.camelcase_modelname}}ForSequenceClassification,
{{cookiecutter.camelcase_modelname}}Model,
@@ -232,10 +235,7 @@ from ..{{cookiecutter.lowercase_modelname}}.modeling_{{cookiecutter.lowercase_mo
# Below: "# Model for Causal LM mapping"
# Replace with:
{% if cookiecutter.is_encoder_decoder_model == "False" -%}
({{cookiecutter.camelcase_modelname}}Config, {{cookiecutter.camelcase_modelname}}ForCausalLM),
{% else -%}
{% endif -%}
# End.
# Below: "# Model for Masked LM mapping"
@@ -384,6 +384,7 @@ from ..{{cookiecutter.lowercase_modelname}}.modeling_tf_{{cookiecutter.lowercase
{% else -%}
"{{cookiecutter.camelcase_modelname}}Encoder",
"{{cookiecutter.camelcase_modelname}}Decoder",
"{{cookiecutter.camelcase_modelname}}DecoderWrapper",
{% endif -%}
# End.
@@ -393,5 +394,6 @@ from ..{{cookiecutter.lowercase_modelname}}.modeling_tf_{{cookiecutter.lowercase
{% else -%}
"{{cookiecutter.camelcase_modelname}}Encoder", # Building part of bigger (tested) model.
"{{cookiecutter.camelcase_modelname}}Decoder", # Building part of bigger (tested) model.
"{{cookiecutter.camelcase_modelname}}DecoderWrapper", # Building part of bigger (tested) model.
{% endif -%}
# End.

View File

@@ -121,6 +121,13 @@ Tips:
:members: forward
{{cookiecutter.camelcase_modelname}}ForCausalLM
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.{{cookiecutter.camelcase_modelname}}ForCausalLM
:members: forward
{% endif -%}
{% endif -%}
{% if "TensorFlow" in cookiecutter.generate_tensorflow_and_pytorch -%}
@@ -180,5 +187,7 @@ TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration
.. autoclass:: transformers.TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration
:members: call
{% endif -%}
{% endif -%}