From 0a9300f474e17dd6e05635d1742273cf0396d9ea Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Fri, 17 May 2024 16:51:31 +0200 Subject: [PATCH] Support arbitrary processor (#30875) * Support arbitrary processor * fix * nit * update * nit * nit * fix and revert * add a small test * better check * fixup * bug so let's just use class for now * oups * . --- .../models/llava/processing_llava.py | 4 +-- tests/models/llava/test_processor_llava.py | 30 +++++++++++++++++++ 2 files changed, 32 insertions(+), 2 deletions(-) create mode 100644 tests/models/llava/test_processor_llava.py diff --git a/src/transformers/models/llava/processing_llava.py b/src/transformers/models/llava/processing_llava.py index 62a46acd39..ff010f7442 100644 --- a/src/transformers/models/llava/processing_llava.py +++ b/src/transformers/models/llava/processing_llava.py @@ -41,8 +41,8 @@ class LlavaProcessor(ProcessorMixin): """ attributes = ["image_processor", "tokenizer"] - image_processor_class = "CLIPImageProcessor" - tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast") + image_processor_class = "AutoImageProcessor" + tokenizer_class = "AutoTokenizer" def __init__(self, image_processor=None, tokenizer=None): super().__init__(image_processor, tokenizer) diff --git a/tests/models/llava/test_processor_llava.py b/tests/models/llava/test_processor_llava.py new file mode 100644 index 0000000000..068971015e --- /dev/null +++ b/tests/models/llava/test_processor_llava.py @@ -0,0 +1,30 @@ +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest + +from transformers.testing_utils import require_vision +from transformers.utils import is_vision_available + + +if is_vision_available(): + from transformers import AutoTokenizer, LlavaProcessor + + +@require_vision +class LlavaProcessorTest(unittest.TestCase): + def test_can_load_various_tokenizers(self): + for checkpoint in ["Intel/llava-gemma-2b", "llava-hf/llava-1.5-7b-hf"]: + processor = LlavaProcessor.from_pretrained(checkpoint) + tokenizer = AutoTokenizer.from_pretrained(checkpoint) + self.assertEqual(processor.tokenizer.__class__, tokenizer.__class__)