diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index ae756c2ceb..b974552460 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -118,6 +118,7 @@ _import_structure = { "is_flax_available", "is_psutil_available", "is_py3nvml_available", + "is_pyctcdecode_available", "is_scipy_available", "is_sentencepiece_available", "is_sklearn_available", @@ -2149,6 +2150,7 @@ if TYPE_CHECKING: is_flax_available, is_psutil_available, is_py3nvml_available, + is_pyctcdecode_available, is_scipy_available, is_sentencepiece_available, is_sklearn_available, diff --git a/src/transformers/models/wav2vec2_with_lm/__init__.py b/src/transformers/models/wav2vec2_with_lm/__init__.py index b7f31c5581..4b03c83252 100644 --- a/src/transformers/models/wav2vec2_with_lm/__init__.py +++ b/src/transformers/models/wav2vec2_with_lm/__init__.py @@ -20,11 +20,11 @@ from typing import TYPE_CHECKING from ...file_utils import _LazyModule, is_pyctcdecode_available -_import_structure = {} +_import_structure = {"processing_wav2vec2_with_lm": []} if is_pyctcdecode_available(): - _import_structure["processing_wav2vec2_with_lm"] = ["Wav2Vec2ProcessorWithLM"] + _import_structure["processing_wav2vec2_with_lm"].append("Wav2Vec2ProcessorWithLM") if TYPE_CHECKING: diff --git a/utils/check_inits.py b/utils/check_inits.py index 4b1dc574b5..8cfbfc18a4 100644 --- a/utils/check_inits.py +++ b/utils/check_inits.py @@ -27,8 +27,6 @@ _re_backend = re.compile(r"is\_([a-z_]*)_available()") _re_import_struct_key_value = re.compile(r'\s+"\S*":\s+\[([^\]]*)\]') # Catches a line if is_foo_available _re_test_backend = re.compile(r"^\s*if\s+is\_[a-z_]*\_available\(\)") -# Catches a line _import_struct["bla"] = ["foo"] -_re_import_struct_equal_one = re.compile(r'^\s*_import_structure\["\S*"\]\ = "\[(\S*)\]"') # Catches a line _import_struct["bla"].append("foo") _re_import_struct_add_one = re.compile(r'^\s*_import_structure\["\S*"\]\.append\("(\S*)"\)') # Catches a line _import_struct["bla"].extend(["foo", "bar"]) or _import_struct["bla"] = ["foo", "bar"] @@ -90,9 +88,7 @@ def parse_init(init_file): # Until we unindent, add backend objects to the list while len(lines[line_index]) <= 1 or lines[line_index].startswith(" " * 4): line = lines[line_index] - if _re_import_struct_equal_one.search(line) is not None: - objects.append(_re_import_struct_equal_one.search(line).groups()[0]) - elif _re_import_struct_add_one.search(line) is not None: + if _re_import_struct_add_one.search(line) is not None: objects.append(_re_import_struct_add_one.search(line).groups()[0]) elif _re_import_struct_add_many.search(line) is not None: imports = _re_import_struct_add_many.search(line).groups()[0].split(", ")