Import structure & first three model refactors (#31329)
* Import structure & first three model refactors * Register -> Export. Export all in __all__. Sensible defaults according to filename. * Apply most comments from Amy and some comments from Lucain Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> Co-authored-by: Lucain Pouget <lucainp@gmail.com> * Style * Add comment * Clearer .py management * Raise if not in backend mapping * More specific type * More efficient listdir * Misc fixes --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> Co-authored-by: Lucain Pouget <lucainp@gmail.com>
This commit is contained in:
@@ -34,8 +34,8 @@ if is_torch_available():
|
||||
LongformerForSequenceClassification,
|
||||
LongformerForTokenClassification,
|
||||
LongformerModel,
|
||||
LongformerSelfAttention,
|
||||
)
|
||||
from transformers.models.longformer.modeling_longformer import LongformerSelfAttention
|
||||
|
||||
|
||||
class LongformerModelTester:
|
||||
|
||||
@@ -37,8 +37,8 @@ if is_tf_available():
|
||||
TFLongformerForSequenceClassification,
|
||||
TFLongformerForTokenClassification,
|
||||
TFLongformerModel,
|
||||
TFLongformerSelfAttention,
|
||||
)
|
||||
from transformers.models.longformer.modeling_tf_longformer import TFLongformerSelfAttention
|
||||
from transformers.tf_utils import shape_list
|
||||
|
||||
|
||||
|
||||
@@ -40,11 +40,11 @@ if is_torch_available():
|
||||
ReformerForMaskedLM,
|
||||
ReformerForQuestionAnswering,
|
||||
ReformerForSequenceClassification,
|
||||
ReformerLayer,
|
||||
ReformerModel,
|
||||
ReformerModelWithLMHead,
|
||||
ReformerTokenizer,
|
||||
)
|
||||
from transformers.models.reformer.modeling_reformer import ReformerLayer
|
||||
|
||||
|
||||
class ReformerModelTester:
|
||||
|
||||
23
tests/utils/import_structures/failing_export.py
Normal file
23
tests/utils/import_structures/failing_export.py
Normal file
@@ -0,0 +1,23 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# fmt: off
|
||||
|
||||
from transformers.utils.import_utils import export
|
||||
|
||||
|
||||
@export(backends=("random_item_that_should_not_exist",))
|
||||
class A0:
|
||||
def __init__(self):
|
||||
pass
|
||||
@@ -0,0 +1,80 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# fmt: off
|
||||
|
||||
from transformers.utils.import_utils import export
|
||||
|
||||
|
||||
@export()
|
||||
class A0:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
@export()
|
||||
def a0():
|
||||
pass
|
||||
|
||||
|
||||
@export(backends=("torch", "tf"))
|
||||
class A1:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
@export(backends=("torch", "tf"))
|
||||
def a1():
|
||||
pass
|
||||
|
||||
|
||||
@export(
|
||||
backends=("torch", "tf")
|
||||
)
|
||||
class A2:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
@export(
|
||||
backends=("torch", "tf")
|
||||
)
|
||||
def a2():
|
||||
pass
|
||||
|
||||
|
||||
@export(
|
||||
backends=(
|
||||
"torch",
|
||||
"tf"
|
||||
)
|
||||
)
|
||||
class A3:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
@export(
|
||||
backends=(
|
||||
"torch",
|
||||
"tf"
|
||||
)
|
||||
)
|
||||
def a3():
|
||||
pass
|
||||
|
||||
@export(backends=())
|
||||
class A4:
|
||||
def __init__(self):
|
||||
pass
|
||||
@@ -0,0 +1,79 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# fmt: off
|
||||
|
||||
from transformers.utils.import_utils import export
|
||||
|
||||
|
||||
@export()
|
||||
# That's a statement
|
||||
class B0:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
@export()
|
||||
# That's a statement
|
||||
def b0():
|
||||
pass
|
||||
|
||||
|
||||
@export(backends=("torch", "tf"))
|
||||
# That's a statement
|
||||
class B1:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
@export(backends=("torch", "tf"))
|
||||
# That's a statement
|
||||
def b1():
|
||||
pass
|
||||
|
||||
|
||||
@export(backends=("torch", "tf"))
|
||||
# That's a statement
|
||||
class B2:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
@export(backends=("torch", "tf"))
|
||||
# That's a statement
|
||||
def b2():
|
||||
pass
|
||||
|
||||
|
||||
@export(
|
||||
backends=(
|
||||
"torch",
|
||||
"tf"
|
||||
)
|
||||
)
|
||||
# That's a statement
|
||||
class B3:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
@export(
|
||||
backends=(
|
||||
"torch",
|
||||
"tf"
|
||||
)
|
||||
)
|
||||
# That's a statement
|
||||
def b3():
|
||||
pass
|
||||
@@ -0,0 +1,77 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# fmt: off
|
||||
|
||||
from transformers.utils.import_utils import export
|
||||
|
||||
|
||||
@export(backends=("torch", "torch"))
|
||||
class C0:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
@export(backends=("torch", "torch"))
|
||||
def c0():
|
||||
pass
|
||||
|
||||
|
||||
@export(backends=("torch", "torch"))
|
||||
# That's a statement
|
||||
class C1:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
@export(backends=("torch", "torch"))
|
||||
# That's a statement
|
||||
def c1():
|
||||
pass
|
||||
|
||||
|
||||
@export(backends=("torch", "torch"))
|
||||
# That's a statement
|
||||
class C2:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
@export(backends=("torch", "torch"))
|
||||
# That's a statement
|
||||
def c2():
|
||||
pass
|
||||
|
||||
|
||||
@export(
|
||||
backends=(
|
||||
"torch",
|
||||
"torch"
|
||||
)
|
||||
)
|
||||
# That's a statement
|
||||
class C3:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
@export(
|
||||
backends=(
|
||||
"torch",
|
||||
"torch"
|
||||
)
|
||||
)
|
||||
# That's a statement
|
||||
def c3():
|
||||
pass
|
||||
98
tests/utils/test_import_structure.py
Normal file
98
tests/utils/test_import_structure.py
Normal file
@@ -0,0 +1,98 @@
|
||||
import os
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from transformers.utils.import_utils import define_import_structure, spread_import_structure
|
||||
|
||||
|
||||
# Directory holding the fixture modules parsed by `TestImportStructures.test_definition`.
# NOTE(review): this is a relative path, so it is resolved against the current working
# directory rather than this file's location — confirm the tests are run from the right place.
import_structures = Path("import_structures")
|
||||
|
||||
|
||||
def fetch__all__(file_content):
    """
    Returns the content of the __all__ variable in the file content.

    Only a top-level assignment starting with `__all__ = ` is recognised. Two layouts are supported:
    the whole list on one line (`__all__ = ["A", "B"]`), or one entry per line terminated by a line
    containing only `]`.

    Args:
        file_content (`str`): The full text of a Python source file.

    Returns:
        `None` if `__all__` is not defined (or its multi-line list is never closed), otherwise the
        list of names as strings. Empty entries (from `__all__ = []`, trailing commas or blank
        lines) are dropped.
    """
    # `splitlines` also handles "\r\n" endings, which would otherwise defeat the `endswith("]")` check.
    lines = file_content.splitlines()

    for line_index, line in enumerate(lines):
        if not line.startswith("__all__ = "):
            continue

        # __all__ is defined on a single line
        if line.endswith("]"):
            raw_entries = line.split("=", 1)[1].strip(" []").split(",")
            entries = (entry.strip("\"' ") for entry in raw_entries)
            return [entry for entry in entries if entry]

        # __all__ is defined on multiple lines: collect every line until the closing bracket
        collected = []
        for candidate in lines[line_index + 1 :]:
            if candidate.strip() == "]":
                return collected

            name = candidate.strip("\"', ")
            if name:
                collected.append(name)

    return None
|
||||
|
||||
|
||||
class TestImportStructures(unittest.TestCase):
    """
    Tests for `define_import_structure`, run against the fixture modules and the real model folders.
    """

    # Repository root, derived from this file's location (three levels up).
    base_transformers_path = Path(__file__).parent.parent.parent
    models_path = base_transformers_path / "src" / "transformers" / "models"
    # Computed once at class-definition time.
    models_import_structure = spread_import_structure(define_import_structure(models_path))

    def test_definition(self):
        """
        The import structure computed from the fixture directory must match the expected mapping of
        backend set -> module name -> exported object names.
        """
        import_structure = define_import_structure(import_structures)
        import_structure_definition = {
            frozenset(()): {
                "import_structure_raw_register": {"A0", "a0", "A4"},
                "import_structure_register_with_comments": {"B0", "b0"},
            },
            frozenset(("tf", "torch")): {
                "import_structure_raw_register": {"A1", "a1", "A2", "a2", "A3", "a3"},
                "import_structure_register_with_comments": {"B1", "b1", "B2", "b2", "B3", "b3"},
            },
            frozenset(("torch",)): {
                "import_structure_register_with_duplicates": {"C0", "c0", "C1", "c1", "C2", "c2", "C3", "c3"},
            },
        }

        self.assertDictEqual(import_structure, import_structure_definition)

    def test_transformers_specific_model_import(self):
        """
        This test ensures that there is equivalence between what is written down in __all__ and what is
        written down with the `export` decorator.

        It doesn't test the backends attributed to the exported objects: objects are merged per module
        across all backend sets before being compared.
        """
        for architecture in os.listdir(self.models_path):
            # Skip plain files, private/dunder entries and the deprecated models folder.
            if (
                os.path.isfile(self.models_path / architecture)
                or architecture.startswith("_")
                or architecture == "deprecated"
            ):
                continue

            with self.subTest(f"Testing arch {architecture}"):
                import_structure = define_import_structure(self.models_path / architecture)

                # Flatten {backends: {module: objects}} into {module: objects}.
                backend_agnostic_import_structure = {}
                for module_object_mapping in import_structure.values():
                    for module, objects in module_object_mapping.items():
                        backend_agnostic_import_structure.setdefault(module, []).extend(objects)

                for module, objects in backend_agnostic_import_structure.items():
                    module_path = self.models_path / architecture / f"{module}.py"
                    with open(module_path) as f:
                        content = f.read()

                    _all = fetch__all__(content)

                    if _all is None:
                        raise ValueError(f"{module} doesn't have __all__ defined.")

                    # Report the actual file path (previously the path expression was emitted verbatim).
                    error_message = (
                        f"{module_path} doesn't seem to be defined correctly:\n"
                        f"Defined in __all__: {sorted(_all)}\nDefined with register: {sorted(objects)}"
                    )
                    self.assertListEqual(sorted(objects), sorted(_all), msg=error_message)

    def test_export_backend_should_be_defined(self):
        # NOTE(review): the body below raises nothing, so `assertRaisesRegex` cannot observe the
        # expected ValueError as written. Presumably this was meant to exercise the
        # `failing_export.py` fixture — TODO confirm and add the triggering call.
        with self.assertRaisesRegex(ValueError, "Backend should be defined in the BACKENDS_MAPPING"):
            pass
|
||||
Reference in New Issue
Block a user