Try working around the processor registration bugs (#36184)
* Try working around the processor registration bugs * oops * Update error message * Clarify error * Docstring docstring docstring * The extra content is indexed by config class, so let's grab some values out of there * Commit my confusion as a TODO * Resolve my confusion * Cleanup and mostly revert to the original * Better autoclass fallback * Don't nest f-strings you lunatic * Clearer error message * Less getattr() * Revert a lot of changes to try a different approach! * Try the global registry * Check the dynamic list as well as the transformers root * Move the dynamic list somewhere safer * Move the dynamic list somewhere even safer * More import cleanup * Simplify all the register_for_auto_class methods * Set _auto_class in the register() methods * Stop setting the cls attribute in register() * Restore specifying the model class for Model derivatives only * Fix accidentally taking the .__class__ of a class * Revert register_for_auto_class changes * Fix get_possibly_dynamic_module * No more ALL_CUSTOM_CLASSES * Fix up get_possibly_dynamic_module as well * Revert unnecessary formatting changes * Trigger tests
This commit is contained in:
@@ -382,6 +382,6 @@ class AutoProcessor:
|
|||||||
Args:
|
Args:
|
||||||
config_class ([`PretrainedConfig`]):
|
config_class ([`PretrainedConfig`]):
|
||||||
The configuration corresponding to the model to register.
|
The configuration corresponding to the model to register.
|
||||||
processor_class ([`FeatureExtractorMixin`]): The processor to register.
|
processor_class ([`ProcessorMixin`]): The processor to register.
|
||||||
"""
|
"""
|
||||||
PROCESSOR_MAPPING.register(config_class, processor_class, exist_ok=exist_ok)
|
PROCESSOR_MAPPING.register(config_class, processor_class, exist_ok=exist_ok)
|
||||||
|
|||||||
@@ -990,6 +990,7 @@ class AutoTokenizer:
|
|||||||
f"Model type should be one of {', '.join(c.__name__ for c in TOKENIZER_MAPPING.keys())}."
|
f"Model type should be one of {', '.join(c.__name__ for c in TOKENIZER_MAPPING.keys())}."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
def register(config_class, slow_tokenizer_class=None, fast_tokenizer_class=None, exist_ok=False):
|
def register(config_class, slow_tokenizer_class=None, fast_tokenizer_class=None, exist_ok=False):
|
||||||
"""
|
"""
|
||||||
Register a new tokenizer in this mapping.
|
Register a new tokenizer in this mapping.
|
||||||
|
|||||||
@@ -69,7 +69,7 @@ from .utils import (
|
|||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
# Dynamically import the Transformers module to grab the attribute classes of the processor form their names.
|
# Dynamically import the Transformers module to grab the attribute classes of the processor from their names.
|
||||||
transformers_module = direct_transformers_import(Path(__file__).parent)
|
transformers_module = direct_transformers_import(Path(__file__).parent)
|
||||||
|
|
||||||
|
|
||||||
@@ -470,9 +470,9 @@ class ProcessorMixin(PushToHubMixin):
|
|||||||
# Nothing is ever going to be an instance of "AutoXxx", in that case we check the base class.
|
# Nothing is ever going to be an instance of "AutoXxx", in that case we check the base class.
|
||||||
class_name = AUTO_TO_BASE_CLASS_MAPPING.get(class_name, class_name)
|
class_name = AUTO_TO_BASE_CLASS_MAPPING.get(class_name, class_name)
|
||||||
if isinstance(class_name, tuple):
|
if isinstance(class_name, tuple):
|
||||||
proper_class = tuple(getattr(transformers_module, n) for n in class_name if n is not None)
|
proper_class = tuple(self.get_possibly_dynamic_module(n) for n in class_name if n is not None)
|
||||||
else:
|
else:
|
||||||
proper_class = getattr(transformers_module, class_name)
|
proper_class = self.get_possibly_dynamic_module(class_name)
|
||||||
|
|
||||||
if not isinstance(arg, proper_class):
|
if not isinstance(arg, proper_class):
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
@@ -1100,11 +1100,19 @@ class ProcessorMixin(PushToHubMixin):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _get_arguments_from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
|
def _get_arguments_from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
|
||||||
|
"""
|
||||||
|
Identify and instantiate the subcomponents of Processor classes, like image processors and
|
||||||
|
tokenizers. This method uses the Processor attributes like `tokenizer_class` to figure out what class those
|
||||||
|
subcomponents should be. Note that any subcomponents must either be library classes that are accessible in
|
||||||
|
the `transformers` root, or they must be custom code that has been registered with the relevant autoclass,
|
||||||
|
via methods like `AutoTokenizer.register()`. If neither of these conditions are fulfilled, this method
|
||||||
|
will be unable to find the relevant subcomponent class and will raise an error.
|
||||||
|
"""
|
||||||
args = []
|
args = []
|
||||||
for attribute_name in cls.attributes:
|
for attribute_name in cls.attributes:
|
||||||
class_name = getattr(cls, f"{attribute_name}_class")
|
class_name = getattr(cls, f"{attribute_name}_class")
|
||||||
if isinstance(class_name, tuple):
|
if isinstance(class_name, tuple):
|
||||||
classes = tuple(getattr(transformers_module, n) if n is not None else None for n in class_name)
|
classes = tuple(cls.get_possibly_dynamic_module(n) if n is not None else None for n in class_name)
|
||||||
if attribute_name == "image_processor":
|
if attribute_name == "image_processor":
|
||||||
# TODO: @yoni, change logic in v4.50 (when use_fast set to True by default)
|
# TODO: @yoni, change logic in v4.50 (when use_fast set to True by default)
|
||||||
use_fast = kwargs.get("use_fast", None)
|
use_fast = kwargs.get("use_fast", None)
|
||||||
@@ -1121,11 +1129,35 @@ class ProcessorMixin(PushToHubMixin):
|
|||||||
else:
|
else:
|
||||||
attribute_class = classes[0]
|
attribute_class = classes[0]
|
||||||
else:
|
else:
|
||||||
attribute_class = getattr(transformers_module, class_name)
|
attribute_class = cls.get_possibly_dynamic_module(class_name)
|
||||||
|
|
||||||
args.append(attribute_class.from_pretrained(pretrained_model_name_or_path, **kwargs))
|
args.append(attribute_class.from_pretrained(pretrained_model_name_or_path, **kwargs))
|
||||||
return args
|
return args
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_possibly_dynamic_module(module_name):
|
||||||
|
if hasattr(transformers_module, module_name):
|
||||||
|
return getattr(transformers_module, module_name)
|
||||||
|
lookup_locations = [
|
||||||
|
transformers_module.IMAGE_PROCESSOR_MAPPING,
|
||||||
|
transformers_module.TOKENIZER_MAPPING,
|
||||||
|
transformers_module.FEATURE_EXTRACTOR_MAPPING,
|
||||||
|
]
|
||||||
|
for lookup_location in lookup_locations:
|
||||||
|
for custom_class in lookup_location._extra_content.values():
|
||||||
|
if isinstance(custom_class, tuple):
|
||||||
|
for custom_subclass in custom_class:
|
||||||
|
if custom_subclass is not None and custom_subclass.__name__ == module_name:
|
||||||
|
return custom_subclass
|
||||||
|
elif custom_class is not None and custom_class.__name__ == module_name:
|
||||||
|
return custom_class
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Could not find module {module_name} in `transformers`. If this is a custom class, "
|
||||||
|
f"it should be registered using the relevant `AutoClass.register()` function so that "
|
||||||
|
f"other functions can find it!"
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def model_input_names(self):
|
def model_input_names(self):
|
||||||
first_attribute = getattr(self, self.attributes[0])
|
first_attribute = getattr(self, self.attributes[0])
|
||||||
|
|||||||
@@ -354,6 +354,40 @@ class AutoFeatureExtractorTest(unittest.TestCase):
|
|||||||
if CustomConfig in PROCESSOR_MAPPING._extra_content:
|
if CustomConfig in PROCESSOR_MAPPING._extra_content:
|
||||||
del PROCESSOR_MAPPING._extra_content[CustomConfig]
|
del PROCESSOR_MAPPING._extra_content[CustomConfig]
|
||||||
|
|
||||||
|
def test_dynamic_processor_with_specific_dynamic_subcomponents(self):
|
||||||
|
class NewFeatureExtractor(Wav2Vec2FeatureExtractor):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class NewTokenizer(BertTokenizer):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class NewProcessor(ProcessorMixin):
|
||||||
|
feature_extractor_class = "NewFeatureExtractor"
|
||||||
|
tokenizer_class = "NewTokenizer"
|
||||||
|
|
||||||
|
def __init__(self, feature_extractor, tokenizer):
|
||||||
|
super().__init__(feature_extractor, tokenizer)
|
||||||
|
|
||||||
|
try:
|
||||||
|
AutoConfig.register("custom", CustomConfig)
|
||||||
|
AutoFeatureExtractor.register(CustomConfig, NewFeatureExtractor)
|
||||||
|
AutoTokenizer.register(CustomConfig, slow_tokenizer_class=NewTokenizer)
|
||||||
|
AutoProcessor.register(CustomConfig, NewProcessor)
|
||||||
|
# If remote code is not set, the default is to use local classes.
|
||||||
|
processor = AutoProcessor.from_pretrained(
|
||||||
|
"hf-internal-testing/test_dynamic_processor",
|
||||||
|
)
|
||||||
|
self.assertEqual(processor.__class__.__name__, "NewProcessor")
|
||||||
|
finally:
|
||||||
|
if "custom" in CONFIG_MAPPING._extra_content:
|
||||||
|
del CONFIG_MAPPING._extra_content["custom"]
|
||||||
|
if CustomConfig in FEATURE_EXTRACTOR_MAPPING._extra_content:
|
||||||
|
del FEATURE_EXTRACTOR_MAPPING._extra_content[CustomConfig]
|
||||||
|
if CustomConfig in TOKENIZER_MAPPING._extra_content:
|
||||||
|
del TOKENIZER_MAPPING._extra_content[CustomConfig]
|
||||||
|
if CustomConfig in PROCESSOR_MAPPING._extra_content:
|
||||||
|
del PROCESSOR_MAPPING._extra_content[CustomConfig]
|
||||||
|
|
||||||
def test_auto_processor_creates_tokenizer(self):
|
def test_auto_processor_creates_tokenizer(self):
|
||||||
processor = AutoProcessor.from_pretrained("hf-internal-testing/tiny-random-bert")
|
processor = AutoProcessor.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
self.assertEqual(processor.__class__.__name__, "BertTokenizerFast")
|
self.assertEqual(processor.__class__.__name__, "BertTokenizerFast")
|
||||||
|
|||||||
Reference in New Issue
Block a user