[WIP] Fix Pyright static type checking by replacing if-else imports with try-except (#16578)
* rebase and isort * modify cookiecutter init * fix cookiecutter auto imports * fix clean_frameworks_in_init * fix add_model_to_main_init * blackify * replace unnecessary f-strings * update yolos imports * fix roberta import bug * fix yolos missing dependency * fix add_model_like and cookiecutter bug * fix repository consistency error * modify cookiecutter, fix add_new_model_like * remove stale line Co-authored-by: Dom Miketa <dmiketa@exscientia.co.uk>
This commit is contained in:
@@ -18,15 +18,23 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
# rely on isort to merge the imports
|
||||
from ...utils import _LazyModule, is_tokenizers_available
|
||||
from ...utils import _LazyModule, OptionalDependencyNotAvailable, is_tokenizers_available
|
||||
|
||||
|
||||
{%- if "TensorFlow" in cookiecutter.generate_tensorflow_pytorch_and_flax %}
|
||||
from ...utils import is_tf_available
|
||||
|
||||
|
||||
{% endif %}
|
||||
{%- if "PyTorch" in cookiecutter.generate_tensorflow_pytorch_and_flax %}
|
||||
from ...utils import is_torch_available
|
||||
|
||||
|
||||
{% endif %}
|
||||
{%- if "Flax" in cookiecutter.generate_tensorflow_pytorch_and_flax %}
|
||||
from ...utils import is_flax_available
|
||||
|
||||
|
||||
{% endif %}
|
||||
|
||||
_import_structure = {
|
||||
@@ -34,12 +42,22 @@ _import_structure = {
|
||||
"tokenization_{{cookiecutter.lowercase_modelname}}": ["{{cookiecutter.camelcase_modelname}}Tokenizer"],
|
||||
}
|
||||
|
||||
if is_tokenizers_available():
|
||||
try:
|
||||
if not is_tokenizers_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
_import_structure["tokenization_{{cookiecutter.lowercase_modelname}}_fast"] = ["{{cookiecutter.camelcase_modelname}}TokenizerFast"]
|
||||
|
||||
{%- if "PyTorch" in cookiecutter.generate_tensorflow_pytorch_and_flax %}
|
||||
{% if cookiecutter.is_encoder_decoder_model == "False" %}
|
||||
if is_torch_available():
|
||||
try:
|
||||
if not is_torch_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
_import_structure["modeling_{{cookiecutter.lowercase_modelname}}"] = [
|
||||
"{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
"{{cookiecutter.camelcase_modelname}}ForMaskedLM",
|
||||
@@ -54,7 +72,12 @@ if is_torch_available():
|
||||
"load_tf_weights_in_{{cookiecutter.lowercase_modelname}}",
|
||||
]
|
||||
{% else %}
|
||||
if is_torch_available():
|
||||
try:
|
||||
if not is_torch_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
_import_structure["modeling_{{cookiecutter.lowercase_modelname}}"] = [
|
||||
"{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
"{{cookiecutter.camelcase_modelname}}ForConditionalGeneration",
|
||||
@@ -70,7 +93,12 @@ if is_torch_available():
|
||||
|
||||
{%- if "TensorFlow" in cookiecutter.generate_tensorflow_pytorch_and_flax %}
|
||||
{% if cookiecutter.is_encoder_decoder_model == "False" %}
|
||||
if is_tf_available():
|
||||
try:
|
||||
if not is_tf_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
_import_structure["modeling_tf_{{cookiecutter.lowercase_modelname}}"] = [
|
||||
"TF_{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
"TF{{cookiecutter.camelcase_modelname}}ForMaskedLM",
|
||||
@@ -84,7 +112,12 @@ if is_tf_available():
|
||||
"TF{{cookiecutter.camelcase_modelname}}PreTrainedModel",
|
||||
]
|
||||
{% else %}
|
||||
if is_tf_available():
|
||||
try:
|
||||
if not is_tf_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
_import_structure["modeling_tf_{{cookiecutter.lowercase_modelname}}"] = [
|
||||
"TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration",
|
||||
"TF{{cookiecutter.camelcase_modelname}}Model",
|
||||
@@ -96,7 +129,12 @@ if is_tf_available():
|
||||
|
||||
{%- if "Flax" in cookiecutter.generate_tensorflow_pytorch_and_flax %}
|
||||
{% if cookiecutter.is_encoder_decoder_model == "False" %}
|
||||
if is_flax_available():
|
||||
try:
|
||||
if not is_flax_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
_import_structure["modeling_flax_{{cookiecutter.lowercase_modelname}}"] = [
|
||||
"Flax{{cookiecutter.camelcase_modelname}}ForMaskedLM",
|
||||
"Flax{{cookiecutter.camelcase_modelname}}ForCausalLM",
|
||||
@@ -109,7 +147,12 @@ if is_flax_available():
|
||||
"Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel",
|
||||
]
|
||||
{% else %}
|
||||
if is_flax_available():
|
||||
try:
|
||||
if not is_flax_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
_import_structure["modeling_flax_{{cookiecutter.lowercase_modelname}}"] = [
|
||||
"Flax{{cookiecutter.camelcase_modelname}}ForConditionalGeneration",
|
||||
"Flax{{cookiecutter.camelcase_modelname}}ForQuestionAnswering",
|
||||
@@ -125,12 +168,22 @@ if TYPE_CHECKING:
|
||||
from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP, {{cookiecutter.camelcase_modelname}}Config
|
||||
from .tokenization_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Tokenizer
|
||||
|
||||
if is_tokenizers_available():
|
||||
try:
|
||||
if not is_torch_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
from .tokenization_{{cookiecutter.lowercase_modelname}}_fast import {{cookiecutter.camelcase_modelname}}TokenizerFast
|
||||
|
||||
{%- if "PyTorch" in cookiecutter.generate_tensorflow_pytorch_and_flax %}
|
||||
{% if cookiecutter.is_encoder_decoder_model == "False" %}
|
||||
if is_torch_available():
|
||||
try:
|
||||
if not is_torch_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
from .modeling_{{cookiecutter.lowercase_modelname}} import (
|
||||
{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
{{cookiecutter.camelcase_modelname}}ForMaskedLM,
|
||||
@@ -145,7 +198,12 @@ if TYPE_CHECKING:
|
||||
load_tf_weights_in_{{cookiecutter.lowercase_modelname}},
|
||||
)
|
||||
{% else %}
|
||||
if is_torch_available():
|
||||
try:
|
||||
if not is_torch_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
from .modeling_{{cookiecutter.lowercase_modelname}} import (
|
||||
{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
{{cookiecutter.camelcase_modelname}}ForConditionalGeneration,
|
||||
@@ -159,7 +217,12 @@ if TYPE_CHECKING:
|
||||
{% endif %}
|
||||
{%- if "TensorFlow" in cookiecutter.generate_tensorflow_pytorch_and_flax %}
|
||||
{% if cookiecutter.is_encoder_decoder_model == "False" %}
|
||||
if is_tf_available():
|
||||
try:
|
||||
if not is_tf_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
from .modeling_tf_{{cookiecutter.lowercase_modelname}} import (
|
||||
TF_{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
TF{{cookiecutter.camelcase_modelname}}ForMaskedLM,
|
||||
@@ -173,7 +236,12 @@ if TYPE_CHECKING:
|
||||
TF{{cookiecutter.camelcase_modelname}}PreTrainedModel,
|
||||
)
|
||||
{% else %}
|
||||
if is_tf_available():
|
||||
try:
|
||||
if not is_tf_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
from .modeling_tf_{{cookiecutter.lowercase_modelname}} import (
|
||||
TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration,
|
||||
TF{{cookiecutter.camelcase_modelname}}Model,
|
||||
@@ -183,7 +251,12 @@ if TYPE_CHECKING:
|
||||
{% endif %}
|
||||
{%- if "Flax" in cookiecutter.generate_tensorflow_pytorch_and_flax %}
|
||||
{% if cookiecutter.is_encoder_decoder_model == "False" %}
|
||||
if is_flax_available():
|
||||
try:
|
||||
if not is_flax_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
from .modeling_{{cookiecutter.lowercase_modelname}} import (
|
||||
Flax{{cookiecutter.camelcase_modelname}}ForMaskedLM,
|
||||
Flax{{cookiecutter.camelcase_modelname}}ForCausalLM,
|
||||
@@ -196,7 +269,12 @@ if TYPE_CHECKING:
|
||||
Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel,
|
||||
)
|
||||
{% else %}
|
||||
if is_flax_available():
|
||||
try:
|
||||
if not is_flax_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
from .modeling_{{cookiecutter.lowercase_modelname}} import (
|
||||
Flax{{cookiecutter.camelcase_modelname}}ForConditionalGeneration,
|
||||
Flax{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
|
||||
|
||||
@@ -115,7 +115,7 @@
|
||||
{% endif -%}
|
||||
# End.
|
||||
|
||||
# Below: " # Fast tokenizers"
|
||||
# Below: " # Fast tokenizers structure"
|
||||
# Replace with:
|
||||
_import_structure["models.{{cookiecutter.lowercase_modelname}}"].append("{{cookiecutter.camelcase_modelname}}TokenizerFast")
|
||||
# End.
|
||||
@@ -126,7 +126,7 @@
|
||||
# End.
|
||||
|
||||
# To replace in: "src/transformers/__init__.py"
|
||||
# Below: " if is_torch_available():" if generating PyTorch
|
||||
# Below: " # PyTorch model imports" if generating PyTorch
|
||||
# Replace with:
|
||||
{% if cookiecutter.is_encoder_decoder_model == "False" %}
|
||||
from .models.{{cookiecutter.lowercase_modelname}} import (
|
||||
@@ -155,7 +155,7 @@
|
||||
{% endif -%}
|
||||
# End.
|
||||
|
||||
# Below: " if is_tf_available():" if generating TensorFlow
|
||||
# Below: " # TensorFlow model imports" if generating TensorFlow
|
||||
# Replace with:
|
||||
{% if cookiecutter.is_encoder_decoder_model == "False" %}
|
||||
from .models.{{cookiecutter.lowercase_modelname}} import (
|
||||
@@ -179,7 +179,7 @@
|
||||
{% endif -%}
|
||||
# End.
|
||||
|
||||
# Below: " if is_flax_available():" if generating Flax
|
||||
# Below: " # Flax model imports" if generating Flax
|
||||
# Replace with:
|
||||
{% if cookiecutter.is_encoder_decoder_model == "False" %}
|
||||
from .models.{{cookiecutter.lowercase_modelname}} import (
|
||||
@@ -204,7 +204,7 @@
|
||||
{% endif -%}
|
||||
# End.
|
||||
|
||||
# Below: " if is_tokenizers_available():"
|
||||
# Below: " # Fast tokenizers imports"
|
||||
# Replace with:
|
||||
from .models.{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}TokenizerFast
|
||||
# End.
|
||||
|
||||
Reference in New Issue
Block a user