Add a fix for custom code tokenizers in pipelines (#32300)
* Add a fix for the case when tokenizers are passed as a string * Support image processors and feature extractors as well * Reverting load_feature_extractor and load_image_processor * Add test * Test is torch-only * Add tests for preprocessors and feature extractors and move test * Extremely experimental fix * Revert that change, wrong branch! * Typo! * Split tests
This commit is contained in:
@@ -904,7 +904,11 @@ def pipeline(
|
|||||||
|
|
||||||
model_config = model.config
|
model_config = model.config
|
||||||
hub_kwargs["_commit_hash"] = model.config._commit_hash
|
hub_kwargs["_commit_hash"] = model.config._commit_hash
|
||||||
load_tokenizer = type(model_config) in TOKENIZER_MAPPING or model_config.tokenizer_class is not None
|
load_tokenizer = (
|
||||||
|
type(model_config) in TOKENIZER_MAPPING
|
||||||
|
or model_config.tokenizer_class is not None
|
||||||
|
or isinstance(tokenizer, str)
|
||||||
|
)
|
||||||
load_feature_extractor = type(model_config) in FEATURE_EXTRACTOR_MAPPING or feature_extractor is not None
|
load_feature_extractor = type(model_config) in FEATURE_EXTRACTOR_MAPPING or feature_extractor is not None
|
||||||
load_image_processor = type(model_config) in IMAGE_PROCESSOR_MAPPING or image_processor is not None
|
load_image_processor = type(model_config) in IMAGE_PROCESSOR_MAPPING or image_processor is not None
|
||||||
|
|
||||||
|
|||||||
@@ -26,10 +26,13 @@ from huggingface_hub import HfFolder, delete_repo
|
|||||||
from requests.exceptions import HTTPError
|
from requests.exceptions import HTTPError
|
||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
|
AutomaticSpeechRecognitionPipeline,
|
||||||
AutoModelForSequenceClassification,
|
AutoModelForSequenceClassification,
|
||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
DistilBertForSequenceClassification,
|
DistilBertForSequenceClassification,
|
||||||
|
MaskGenerationPipeline,
|
||||||
TextClassificationPipeline,
|
TextClassificationPipeline,
|
||||||
|
TextGenerationPipeline,
|
||||||
TFAutoModelForSequenceClassification,
|
TFAutoModelForSequenceClassification,
|
||||||
pipeline,
|
pipeline,
|
||||||
)
|
)
|
||||||
@@ -859,6 +862,42 @@ class CustomPipelineTest(unittest.TestCase):
|
|||||||
|
|
||||||
self.assertEqual(self.COUNT, 1)
|
self.assertEqual(self.COUNT, 1)
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_custom_code_with_string_tokenizer(self):
|
||||||
|
# This test checks for an edge case - tokenizer loading used to fail when using a custom code model
|
||||||
|
# with a separate tokenizer that was passed as a repo name rather than a tokenizer object.
|
||||||
|
# See https://github.com/huggingface/transformers/issues/31669
|
||||||
|
text_generator = pipeline(
|
||||||
|
"text-generation",
|
||||||
|
model="Rocketknight1/fake-custom-model-test",
|
||||||
|
tokenizer="Rocketknight1/fake-custom-model-test",
|
||||||
|
trust_remote_code=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertIsInstance(text_generator, TextGenerationPipeline) # Assert successful loading
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_custom_code_with_string_feature_extractor(self):
|
||||||
|
speech_recognizer = pipeline(
|
||||||
|
"automatic-speech-recognition",
|
||||||
|
model="Rocketknight1/fake-custom-wav2vec2",
|
||||||
|
feature_extractor="Rocketknight1/fake-custom-wav2vec2",
|
||||||
|
trust_remote_code=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertIsInstance(speech_recognizer, AutomaticSpeechRecognitionPipeline) # Assert successful loading
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_custom_code_with_string_preprocessor(self):
|
||||||
|
mask_generator = pipeline(
|
||||||
|
"mask-generation",
|
||||||
|
model="Rocketknight1/fake-custom-sam",
|
||||||
|
processor="Rocketknight1/fake-custom-sam",
|
||||||
|
trust_remote_code=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertIsInstance(mask_generator, MaskGenerationPipeline) # Assert successful loading
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@is_staging_test
|
@is_staging_test
|
||||||
|
|||||||
Reference in New Issue
Block a user