Fix model templates (#8595)

* First fixes

* Fix imports and add init

* Fix typo

* Move init to final dest

* Fix tokenization import

* More fixes

* Styling
This commit is contained in:
Sylvain Gugger
2020-11-17 10:35:38 -05:00
committed by GitHub
parent 042a6aa777
commit 36a19915ea
9 changed files with 91 additions and 40 deletions

View File

@@ -0,0 +1,43 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
{%- if cookiecutter.generate_tensorflow_and_pytorch == "PyTorch & TensorFlow" %}
from ...file_utils import is_tf_available, is_torch_available
{%- elif cookiecutter.generate_tensorflow_and_pytorch == "PyTorch" %}
from ...file_utils import is_torch_available
{%- elif cookiecutter.generate_tensorflow_and_pytorch == "TensorFlow" %}
from ...file_utils import is_tf_available
{% endif %}
from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP, {{cookiecutter.camelcase_modelname}}Config
from .tokenization_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Tokenizer
{%- if (cookiecutter.generate_tensorflow_and_pytorch == "PyTorch & TensorFlow" or cookiecutter.generate_tensorflow_and_pytorch == "PyTorch") %}
if is_torch_available():
from .modeling_{{cookiecutter.lowercase_modelname}} import (
{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST,
{{cookiecutter.camelcase_modelname}}ForMaskedLM,
{{cookiecutter.camelcase_modelname}}ForMultipleChoice,
{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
{{cookiecutter.camelcase_modelname}}ForSequenceClassification,
{{cookiecutter.camelcase_modelname}}ForTokenClassification,
{{cookiecutter.camelcase_modelname}}Layer,
{{cookiecutter.camelcase_modelname}}Model,
{{cookiecutter.camelcase_modelname}}PreTrainedModel,
load_tf_weights_in_{{cookiecutter.lowercase_modelname}},
)
{% endif %}
{%- if (cookiecutter.generate_tensorflow_and_pytorch == "PyTorch & TensorFlow" or cookiecutter.generate_tensorflow_and_pytorch == "TensorFlow") %}
if is_tf_available():
from .modeling_tf_{{cookiecutter.lowercase_modelname}} import (
TF_{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST,
TF{{cookiecutter.camelcase_modelname}}ForMaskedLM,
TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice,
TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification,
TF{{cookiecutter.camelcase_modelname}}ForTokenClassification,
TF{{cookiecutter.camelcase_modelname}}Layer,
TF{{cookiecutter.camelcase_modelname}}Model,
TF{{cookiecutter.camelcase_modelname}}PreTrainedModel,
)
{% endif %}

View File

@@ -14,8 +14,8 @@
# limitations under the License.
""" {{cookiecutter.modelname}} model configuration """
from .configuration_utils import PretrainedConfig
from .utils import logging
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)

View File

@@ -17,15 +17,14 @@
import tensorflow as tf
from .activations_tf import get_tf_activation
from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Config
from .file_utils import (
from ...activations_tf import get_tf_activation
from ...file_utils import (
MULTIPLE_CHOICE_DUMMY_INPUTS,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
)
from .modeling_tf_outputs import (
from ...modeling_tf_outputs import (
TFBaseModelOutput,
TFBaseModelOutputWithPooling,
TFMaskedLMOutput,
@@ -34,7 +33,7 @@ from .modeling_tf_outputs import (
TFSequenceClassifierOutput,
TFTokenClassifierOutput,
)
from .modeling_tf_utils import (
from ...modeling_tf_utils import (
TFMaskedLanguageModelingLoss,
TFMultipleChoiceLoss,
TFPreTrainedModel,
@@ -46,8 +45,9 @@ from .modeling_tf_utils import (
keras_serializable,
shape_list,
)
from .tokenization_utils import BatchEncoding
from .utils import logging
from ...tokenization_utils import BatchEncoding
from ...utils import logging
from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Config
logger = logging.get_logger(__name__)

View File

@@ -25,13 +25,12 @@ import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss
from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Config
from .file_utils import (
from ...file_utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
)
from .modeling_outputs import (
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPooling,
MaskedLMOutput,
@@ -40,15 +39,16 @@ from .modeling_outputs import (
SequenceClassifierOutput,
TokenClassifierOutput,
)
from .modeling_utils import (
from ...modeling_utils import (
PreTrainedModel,
SequenceSummary,
apply_chunking_to_forward,
find_pruneable_heads_and_indices,
prune_linear_layer,
)
from .utils import logging
from .activations import ACT2FN
from ...utils import logging
from ...activations import ACT2FN
from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Config
logger = logging.get_logger(__name__)

View File

@@ -14,7 +14,7 @@
# To replace in: "src/transformers/__init__.py"
# Below: "if is_torch_available():" if generating PyTorch
# Replace with:
from .modeling_{{cookiecutter.lowercase_modelname}} import (
from .models.{{cookiecutter.lowercase_modelname}} import (
{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST,
{{cookiecutter.camelcase_modelname}}ForMaskedLM,
{{cookiecutter.camelcase_modelname}}ForMultipleChoice,
@@ -30,7 +30,7 @@
# Below: "if is_tf_available():" if generating TensorFlow
# Replace with:
from .modeling_tf_{{cookiecutter.lowercase_modelname}} import (
from .models.{{cookiecutter.lowercase_modelname}} import (
TF_{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST,
TF{{cookiecutter.camelcase_modelname}}ForMaskedLM,
TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice,
@@ -44,14 +44,14 @@
# End.
# Below: "from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig"
# Below: "from .models.albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig"
# Replace with:
from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP, {{cookiecutter.camelcase_modelname}}Config
from .models.{{cookiecutter.lowercase_modelname}} import {{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP, {{cookiecutter.camelcase_modelname}}Config
# End.
# To replace in: "src/transformers/configuration_auto.py"
# To replace in: "src/transformers/models/auto/configuration_auto.py"
# Below: "# Add configs here"
# Replace with:
("{{cookiecutter.lowercase_modelname}}", {{cookiecutter.camelcase_modelname}}Config),
@@ -62,9 +62,9 @@ from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.u
{{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP,
# End.
# Below: "from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig",
# Below: "from ..albert.configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig",
# Replace with:
from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP, {{cookiecutter.camelcase_modelname}}Config
from ..{{cookiecutter.lowercase_modelname}}.configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP, {{cookiecutter.camelcase_modelname}}Config
# End.
# Below: "# Add full (and cased) model names here"
@@ -83,7 +83,7 @@ from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.u
# Below: "# Add modeling imports here"
# Replace with:
from .modeling_{{cookiecutter.lowercase_modelname}} import (
from ..{{cookiecutter.lowercase_modelname}}.modeling_{{cookiecutter.lowercase_modelname}} import (
{{cookiecutter.camelcase_modelname}}ForMaskedLM,
{{cookiecutter.camelcase_modelname}}ForMultipleChoice,
{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
@@ -138,7 +138,7 @@ from .modeling_{{cookiecutter.lowercase_modelname}} import (
# Below: "# Add modeling imports here"
# Replace with:
from .modeling_tf_{{cookiecutter.lowercase_modelname}} import (
from ..{{cookiecutter.lowercase_modelname}}.modeling_tf_{{cookiecutter.lowercase_modelname}} import (
TF{{cookiecutter.camelcase_modelname}}ForMaskedLM,
TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice,
TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,

View File

@@ -15,8 +15,9 @@
"""Tokenization classes for {{cookiecutter.modelname}}."""
{%- if cookiecutter.tokenizer_type == "Based on BERT" %}
from .tokenization_bert import BertTokenizer, BertTokenizerFast
from .utils import logging
from ...utils import logging
from ..bert.tokenization_bert import BertTokenizer
from ..bert.tokenization_bert_fast import BertTokenizerFast
logger = logging.get_logger(__name__)
@@ -73,14 +74,14 @@ class {{cookiecutter.camelcase_modelname}}TokenizerFast(BertTokenizerFast):
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
{%- elif cookiecutter.tokenizer_type == "Standalone" %}
import warnings
from typing import List, Optional
from tokenizers import ByteLevelBPETokenizer
from .tokenization_utils import AddedToken, PreTrainedTokenizer
from .tokenization_utils_base import BatchEncoding
from .tokenization_utils_fast import PreTrainedTokenizerFast
from typing import List, Optional
from .utils import logging
from ...tokenization_utils import AddedToken, PreTrainedTokenizer
from ...tokenization_utils_base import BatchEncoding
from ...tokenization_utils_fast import PreTrainedTokenizerFast
from ...utils import logging
logger = logging.get_logger(__name__)