Support for version spec in requires & arbitrary mismatching depths across folders (#37854)

* Support for version spec in requires & arbitrary mismatching depths

* Quality

* Testing
This commit is contained in:
Lysandre Debut
2025-05-09 15:26:27 +02:00
committed by GitHub
parent 774dc274ac
commit 23d79cea75
4 changed files with 391 additions and 45 deletions

View File

@@ -1,11 +1,19 @@
import os
import unittest
from pathlib import Path
from typing import Callable
from transformers.utils.import_utils import define_import_structure, spread_import_structure
import pytest
from transformers.utils.import_utils import (
Backend,
VersionComparison,
define_import_structure,
spread_import_structure,
)
import_structures = Path("import_structures")
import_structures = Path(__file__).parent / "import_structures"
def fetch__all__(file_content):
@@ -36,26 +44,39 @@ class TestImportStructures(unittest.TestCase):
models_path = base_transformers_path / "src" / "transformers" / "models"
models_import_structure = spread_import_structure(define_import_structure(models_path))
# TODO: Lysandre
# See https://app.circleci.com/pipelines/github/huggingface/transformers/104762/workflows/7ba9c6f7-a3b2-44e6-8eaf-749c7b7261f7/jobs/1393260/tests
@unittest.skip(reason="failing")
def test_definition(self):
import_structure = define_import_structure(import_structures)
import_structure_definition = {
frozenset(()): {
"import_structure_raw_register": {"A0", "a0", "A4"},
valid_frozensets: dict[frozenset | frozenset[str], dict[str, set[str]]] = {
frozenset(): {
"import_structure_raw_register": {"A0", "A4", "a0"},
"import_structure_register_with_comments": {"B0", "b0"},
},
frozenset(("tf", "torch")): {
"import_structure_raw_register": {"A1", "a1", "A2", "a2", "A3", "a3"},
"import_structure_register_with_comments": {"B1", "b1", "B2", "b2", "B3", "b3"},
frozenset({"random_item_that_should_not_exist"}): {"failing_export": {"A0"}},
frozenset({"torch"}): {
"import_structure_register_with_duplicates": {"C0", "C1", "C2", "C3", "c0", "c1", "c2", "c3"}
},
frozenset(("torch",)): {
"import_structure_register_with_duplicates": {"C0", "c0", "C1", "c1", "C2", "c2", "C3", "c3"},
frozenset({"tf", "torch"}): {
"import_structure_raw_register": {"A1", "A2", "A3", "a1", "a2", "a3"},
"import_structure_register_with_comments": {"B1", "B2", "B3", "b1", "b2", "b3"},
},
frozenset({"torch>=2.5"}): {"import_structure_raw_register_with_versions": {"D0", "d0"}},
frozenset({"torch>2.5"}): {"import_structure_raw_register_with_versions": {"D1", "d1"}},
frozenset({"torch<=2.5"}): {"import_structure_raw_register_with_versions": {"D2", "d2"}},
frozenset({"torch<2.5"}): {"import_structure_raw_register_with_versions": {"D3", "d3"}},
frozenset({"torch==2.5"}): {"import_structure_raw_register_with_versions": {"D4", "d4"}},
frozenset({"torch!=2.5"}): {"import_structure_raw_register_with_versions": {"D5", "d5"}},
frozenset({"torch>=2.5", "accelerate<0.20"}): {
"import_structure_raw_register_with_versions": {"D6", "d6"}
},
}
self.assertDictEqual(import_structure, import_structure_definition)
self.assertEqual(len(import_structure.keys()), len(valid_frozensets.keys()))
for _frozenset in valid_frozensets.keys():
self.assertTrue(_frozenset in import_structure)
self.assertListEqual(list(import_structure[_frozenset].keys()), list(valid_frozensets[_frozenset].keys()))
for module, objects in valid_frozensets[_frozenset].items():
self.assertTrue(module in import_structure[_frozenset])
self.assertSetEqual(objects, import_structure[_frozenset][module])
def test_transformers_specific_model_import(self):
"""
@@ -96,9 +117,92 @@ class TestImportStructures(unittest.TestCase):
)
self.assertListEqual(sorted(objects), sorted(_all), msg=error_message)
# TODO: Lysandre
# See https://app.circleci.com/pipelines/github/huggingface/transformers/104762/workflows/7ba9c6f7-a3b2-44e6-8eaf-749c7b7261f7/jobs/1393260/tests
@unittest.skip(reason="failing")
def test_export_backend_should_be_defined(self):
with self.assertRaisesRegex(ValueError, "Backend should be defined in the BACKENDS_MAPPING"):
pass
def test_import_spread(self):
"""
This test is specifically designed to test that varying levels of depth across import structures are
respected.
In this instance, frozensets are at respective depths of 1, 2 and 3, for example:
- models.{frozensets}
- models.albert.{frozensets}
- models.deprecated.transfo_xl.{frozensets}
"""
initial_import_structure = {
frozenset(): {"dummy_non_model": {"DummyObject"}},
"models": {
frozenset(): {"dummy_config": {"DummyConfig"}},
"albert": {
frozenset(): {"configuration_albert": {"AlbertConfig", "AlbertOnnxConfig"}},
frozenset({"torch"}): {
"modeling_albert": {
"AlbertForMaskedLM",
}
},
},
"llama": {
frozenset(): {"configuration_llama": {"LlamaConfig"}},
frozenset({"torch"}): {
"modeling_llama": {
"LlamaForCausalLM",
}
},
},
"deprecated": {
"transfo_xl": {
frozenset({"torch"}): {
"modeling_transfo_xl": {
"TransfoXLModel",
}
},
frozenset(): {
"configuration_transfo_xl": {"TransfoXLConfig"},
"tokenization_transfo_xl": {"TransfoXLCorpus", "TransfoXLTokenizer"},
},
},
"deta": {
frozenset({"torch"}): {
"modeling_deta": {"DetaForObjectDetection", "DetaModel", "DetaPreTrainedModel"}
},
frozenset(): {"configuration_deta": {"DetaConfig"}},
frozenset({"vision"}): {"image_processing_deta": {"DetaImageProcessor"}},
},
},
},
}
ground_truth_spread_import_structure = {
frozenset(): {
"dummy_non_model": {"DummyObject"},
"models.dummy_config": {"DummyConfig"},
"models.albert.configuration_albert": {"AlbertConfig", "AlbertOnnxConfig"},
"models.llama.configuration_llama": {"LlamaConfig"},
"models.deprecated.transfo_xl.configuration_transfo_xl": {"TransfoXLConfig"},
"models.deprecated.transfo_xl.tokenization_transfo_xl": {"TransfoXLCorpus", "TransfoXLTokenizer"},
"models.deprecated.deta.configuration_deta": {"DetaConfig"},
},
frozenset({"torch"}): {
"models.albert.modeling_albert": {"AlbertForMaskedLM"},
"models.llama.modeling_llama": {"LlamaForCausalLM"},
"models.deprecated.transfo_xl.modeling_transfo_xl": {"TransfoXLModel"},
"models.deprecated.deta.modeling_deta": {"DetaForObjectDetection", "DetaModel", "DetaPreTrainedModel"},
},
frozenset({"vision"}): {"models.deprecated.deta.image_processing_deta": {"DetaImageProcessor"}},
}
newly_spread_import_structure = spread_import_structure(initial_import_structure)
self.assertEqual(ground_truth_spread_import_structure, newly_spread_import_structure)
@pytest.mark.parametrize(
"backend,package_name,version_comparison,version",
[
pytest.param(Backend("torch>=2.5 "), "torch", VersionComparison.GREATER_THAN_OR_EQUAL.value, "2.5"),
pytest.param(Backend("tf<=1"), "tf", VersionComparison.LESS_THAN_OR_EQUAL.value, "1"),
pytest.param(Backend("torchvision==0.19.1"), "torchvision", VersionComparison.EQUAL.value, "0.19.1"),
],
)
def test_backend_specification(backend: Backend, package_name: str, version_comparison: Callable, version: str):
assert backend.package_name == package_name
assert VersionComparison.from_string(backend.version_comparison) == version_comparison
assert backend.version == version