Fixes in the templates (#10951)
* Fixes in the templates * Define in all cases * Dimensionality -> Dimension Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
This commit is contained in:
@@ -44,16 +44,14 @@ class {{cookiecutter.camelcase_modelname}}Config(PretrainedConfig):
|
||||
Vocabulary size of the {{cookiecutter.modelname}} model. Defines the number of different tokens that can be represented by the
|
||||
:obj:`inputs_ids` passed when calling :class:`~transformers.{{cookiecutter.camelcase_modelname}}Model` or
|
||||
:class:`~transformers.TF{{cookiecutter.camelcase_modelname}}Model`.
|
||||
Vocabulary size of the model. Defines the different tokens that
|
||||
can be represented by the `inputs_ids` passed to the forward method of :class:`~transformers.{{cookiecutter.camelcase_modelname}}Model`.
|
||||
hidden_size (:obj:`int`, `optional`, defaults to 768):
|
||||
Dimensionality of the encoder layers and the pooler layer.
|
||||
Dimension of the encoder layers and the pooler layer.
|
||||
num_hidden_layers (:obj:`int`, `optional`, defaults to 12):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads (:obj:`int`, `optional`, defaults to 12):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
intermediate_size (:obj:`int`, `optional`, defaults to 3072):
|
||||
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
||||
Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
||||
hidden_act (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"gelu"`):
|
||||
The non-linear activation function (function or string) in the encoder and pooler.
|
||||
If string, :obj:`"gelu"`, :obj:`"relu"`, :obj:`"selu"` and :obj:`"gelu_new"` are supported.
|
||||
@@ -75,14 +73,14 @@ class {{cookiecutter.camelcase_modelname}}Config(PretrainedConfig):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||
relevant if ``config.is_decoder=True``.
|
||||
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
If True, use gradient checkpointing to save memory at the expense of slower backward pass.
|
||||
If :obj:`True`, use gradient checkpointing to save memory at the expense of slower backward pass.
|
||||
{% else -%}
|
||||
vocab_size (:obj:`int`, `optional`, defaults to 50265):
|
||||
Vocabulary size of the {{cookiecutter.modelname}} model. Defines the number of different tokens that can be represented by the
|
||||
:obj:`inputs_ids` passed when calling :class:`~transformers.{{cookiecutter.camelcase_modelname}}Model` or
|
||||
:class:`~transformers.TF{{cookiecutter.camelcase_modelname}}Model`.
|
||||
d_model (:obj:`int`, `optional`, defaults to 1024):
|
||||
Dimensionality of the layers and the pooler layer.
|
||||
Dimension of the layers and the pooler layer.
|
||||
encoder_layers (:obj:`int`, `optional`, defaults to 12):
|
||||
Number of encoder layers.
|
||||
decoder_layers (:obj:`int`, `optional`, defaults to 12):
|
||||
@@ -92,9 +90,9 @@ class {{cookiecutter.camelcase_modelname}}Config(PretrainedConfig):
|
||||
decoder_attention_heads (:obj:`int`, `optional`, defaults to 16):
|
||||
Number of attention heads for each attention layer in the Transformer decoder.
|
||||
decoder_ffn_dim (:obj:`int`, `optional`, defaults to 4096):
|
||||
Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
|
||||
Dimension of the "intermediate" (often named feed-forward) layer in decoder.
|
||||
encoder_ffn_dim (:obj:`int`, `optional`, defaults to 4096):
|
||||
Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
|
||||
Dimension of the "intermediate" (often named feed-forward) layer in decoder.
|
||||
activation_function (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"gelu"`):
|
||||
The non-linear activation function (function or string) in the encoder and pooler. If string,
|
||||
:obj:`"gelu"`, :obj:`"relu"`, :obj:`"silu"` and :obj:`"gelu_new"` are supported.
|
||||
|
||||
@@ -60,6 +60,7 @@ from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.c
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CHECKPOINT_FOR_DOC = "{{cookiecutter.checkpoint_identifier}}"
|
||||
_CONFIG_FOR_DOC = "{{cookiecutter.camelcase_modelname}}Config"
|
||||
_TOKENIZER_FOR_DOC = "{{cookiecutter.camelcase_modelname}}Tokenizer"
|
||||
|
||||
@@ -730,7 +731,7 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
|
||||
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint="{{cookiecutter.checkpoint_identifier}}",
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
output_type=TFBaseModelOutputWithPooling,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
@@ -807,7 +808,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForMaskedLM(TF{{cookiecutter.camelca
|
||||
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint="{{cookiecutter.checkpoint_identifier}}",
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
output_type=TFMaskedLMOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
@@ -903,7 +904,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForCausalLM(TF{{cookiecutter.camelca
|
||||
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint="{{cookiecutter.checkpoint_identifier}}",
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
output_type=TFCausalLMOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
@@ -1031,7 +1032,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification(TF{{cookie
|
||||
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint="{{cookiecutter.checkpoint_identifier}}",
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
output_type=TFSequenceClassifierOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
@@ -1137,7 +1138,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice(TF{{cookiecutter.c
|
||||
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint="{{cookiecutter.checkpoint_identifier}}",
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
output_type=TFMultipleChoiceModelOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
@@ -1280,7 +1281,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForTokenClassification(TF{{cookiecut
|
||||
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint="{{cookiecutter.checkpoint_identifier}}",
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
output_type=TFTokenClassifierOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
@@ -1376,7 +1377,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering(TF{{cookiecutte
|
||||
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint="{{cookiecutter.checkpoint_identifier}}",
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
output_type=TFQuestionAnsweringModelOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
@@ -1504,6 +1505,7 @@ from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.c
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CHECKPOINT_FOR_DOC = "{{cookiecutter.checkpoint_identifier}}"
|
||||
_CONFIG_FOR_DOC = "{{cookiecutter.camelcase_modelname}}Config"
|
||||
_TOKENIZER_FOR_DOC = "{{cookiecutter.camelcase_modelname}}Tokenizer"
|
||||
|
||||
@@ -2512,7 +2514,7 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
|
||||
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint="{{cookiecutter.checkpoint_identifier}}",
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
output_type=TFSeq2SeqModelOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
|
||||
@@ -54,6 +54,7 @@ from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.c
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CHECKPOINT_FOR_DOC = "{{cookiecutter.checkpoint_identifier}}"
|
||||
_CONFIG_FOR_DOC = "{{cookiecutter.camelcase_modelname}}Config"
|
||||
_TOKENIZER_FOR_DOC = "{{cookiecutter.camelcase_modelname}}Tokenizer"
|
||||
|
||||
@@ -779,7 +780,7 @@ class {{cookiecutter.camelcase_modelname}}Model({{cookiecutter.camelcase_modelna
|
||||
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint="{{cookiecutter.checkpoint_identifier}}",
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
output_type=BaseModelOutputWithPastAndCrossAttentions,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
@@ -932,7 +933,7 @@ class {{cookiecutter.camelcase_modelname}}ForMaskedLM({{cookiecutter.camelcase_m
|
||||
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint="{{cookiecutter.checkpoint_identifier}}",
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
output_type=MaskedLMOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
@@ -1190,7 +1191,7 @@ class {{cookiecutter.camelcase_modelname}}ForSequenceClassification({{cookiecutt
|
||||
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint="{{cookiecutter.checkpoint_identifier}}",
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
output_type=SequenceClassifierOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
@@ -1270,7 +1271,7 @@ class {{cookiecutter.camelcase_modelname}}ForMultipleChoice({{cookiecutter.camel
|
||||
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint="{{cookiecutter.checkpoint_identifier}}",
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
output_type=MultipleChoiceModelOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
@@ -1360,7 +1361,7 @@ class {{cookiecutter.camelcase_modelname}}ForTokenClassification({{cookiecutter.
|
||||
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint="{{cookiecutter.checkpoint_identifier}}",
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
output_type=TokenClassifierOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
@@ -1447,7 +1448,7 @@ class {{cookiecutter.camelcase_modelname}}ForQuestionAnswering({{cookiecutter.ca
|
||||
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint="{{cookiecutter.checkpoint_identifier}}",
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
output_type=QuestionAnsweringModelOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
@@ -1559,6 +1560,7 @@ from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.c
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CHECKPOINT_FOR_DOC = "{{cookiecutter.checkpoint_identifier}}"
|
||||
_CONFIG_FOR_DOC = "{{cookiecutter.camelcase_modelname}}Config"
|
||||
_TOKENIZER_FOR_DOC = "{{cookiecutter.camelcase_modelname}}Tokenizer"
|
||||
|
||||
@@ -2607,7 +2609,7 @@ class {{cookiecutter.camelcase_modelname}}Model({{cookiecutter.camelcase_modelna
|
||||
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING)
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint="{{cookiecutter.checkpoint_identifier}}",
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
output_type=Seq2SeqModelOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
@@ -2875,7 +2877,7 @@ class {{cookiecutter.camelcase_modelname}}ForSequenceClassification({{cookiecutt
|
||||
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING)
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint="{{cookiecutter.checkpoint_identifier}}",
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
output_type=Seq2SeqSequenceClassifierOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
@@ -2976,7 +2978,7 @@ class {{cookiecutter.camelcase_modelname}}ForQuestionAnswering({{cookiecutter.ca
|
||||
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING)
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint="{{cookiecutter.checkpoint_identifier}}",
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
output_type=Seq2SeqQuestionAnsweringModelOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user