From 363301f2219fff6ad05bd98a51faa1cc510b040e Mon Sep 17 00:00:00 2001 From: Ita Zaporozhets <31893021+itazap@users.noreply.github.com> Date: Fri, 6 Sep 2024 07:49:47 -0400 Subject: [PATCH] support loading model without config.json file (#32356) * support loading model without config.json file * fix condition * update tests * add test * ruff * ruff * ruff --- src/transformers/configuration_utils.py | 4 ++++ .../models/wav2vec2/processing_wav2vec2.py | 2 +- src/transformers/tokenization_utils_base.py | 7 +++++++ src/transformers/utils/hub.py | 4 +++- tests/models/auto/test_configuration_auto.py | 7 ------- tests/models/llama/test_tokenization_llama.py | 10 ++++++++++ tests/utils/test_configuration_utils.py | 6 ++---- 7 files changed, 27 insertions(+), 13 deletions(-) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index b07d2224af..2339c4cd6b 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -565,6 +565,8 @@ class PretrainedConfig(PushToHubMixin): original_kwargs = copy.deepcopy(kwargs) # Get config dict associated with the base config file config_dict, kwargs = cls._get_config_dict(pretrained_model_name_or_path, **kwargs) + if config_dict is None: + return {}, kwargs if "_commit_hash" in config_dict: original_kwargs["_commit_hash"] = config_dict["_commit_hash"] @@ -635,6 +637,8 @@ class PretrainedConfig(PushToHubMixin): subfolder=subfolder, _commit_hash=commit_hash, ) + if resolved_config_file is None: + return None, kwargs commit_hash = extract_commit_hash(resolved_config_file, commit_hash) except EnvironmentError: # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to diff --git a/src/transformers/models/wav2vec2/processing_wav2vec2.py b/src/transformers/models/wav2vec2/processing_wav2vec2.py index 1b47c0a988..6fe960c78e 100644 --- a/src/transformers/models/wav2vec2/processing_wav2vec2.py +++ b/src/transformers/models/wav2vec2/processing_wav2vec2.py @@ -51,7 +51,7 @@ class Wav2Vec2Processor(ProcessorMixin): def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): try: return super().from_pretrained(pretrained_model_name_or_path, **kwargs) - except OSError: + except (OSError, ValueError): warnings.warn( f"Loading a tokenizer inside {cls.__name__} from a config that does not" " include a `tokenizer_class` attribute is deprecated and will be " diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index dc0af00ced..3b2704498c 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -2439,6 +2439,13 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): "Unable to load vocabulary from file. " "Please check that the provided vocabulary is accessible and not corrupted." ) + except RuntimeError as e: + if "sentencepiece_processor.cc" in str(e): + logger.info( + "Unable to load tokenizer model from SPM, loading from TikToken will be attempted instead." + "(SentencePiece RuntimeError: Tried to load SPM model with non-SPM vocab file).", + ) + return False if added_tokens_decoder != {} and max(list(added_tokens_decoder.keys())[-1], 0) > tokenizer.vocab_size: logger.info( diff --git a/src/transformers/utils/hub.py b/src/transformers/utils/hub.py index 8a75e4c2e8..92be9c0b05 100644 --- a/src/transformers/utils/hub.py +++ b/src/transformers/utils/hub.py @@ -370,7 +370,7 @@ def cached_file( if os.path.isdir(path_or_repo_id): resolved_file = os.path.join(os.path.join(path_or_repo_id, subfolder), filename) if not os.path.isfile(resolved_file): - if _raise_exceptions_for_missing_entries: + if _raise_exceptions_for_missing_entries and filename not in ["config.json", f"{subfolder}/config.json"]: raise EnvironmentError( f"{path_or_repo_id} does not appear to have a file named {full_filename}. Checkout " f"'https://huggingface.co/{path_or_repo_id}/tree/{revision}' for available files." @@ -454,6 +454,8 @@ def cached_file( return None if revision is None: revision = "main" + if filename in ["config.json", f"{subfolder}/config.json"]: + return None raise EnvironmentError( f"{path_or_repo_id} does not appear to have a file named {full_filename}. Checkout " f"'https://huggingface.co/{path_or_repo_id}/tree/{revision}' for available files." diff --git a/tests/models/auto/test_configuration_auto.py b/tests/models/auto/test_configuration_auto.py index c208985ef6..69f9029f90 100644 --- a/tests/models/auto/test_configuration_auto.py +++ b/tests/models/auto/test_configuration_auto.py @@ -104,13 +104,6 @@ class AutoConfigTest(unittest.TestCase): ): _ = AutoConfig.from_pretrained(DUMMY_UNKNOWN_IDENTIFIER, revision="aaaaaa") - def test_configuration_not_found(self): - with self.assertRaisesRegex( - EnvironmentError, - "hf-internal-testing/no-config-test-repo does not appear to have a file named config.json.", - ): - _ = AutoConfig.from_pretrained("hf-internal-testing/no-config-test-repo") - def test_from_pretrained_dynamic_config(self): # If remote code is not set, we will time out when asking whether to load the model. with self.assertRaises(ValueError): diff --git a/tests/models/llama/test_tokenization_llama.py b/tests/models/llama/test_tokenization_llama.py index e45149672a..094a511616 100644 --- a/tests/models/llama/test_tokenization_llama.py +++ b/tests/models/llama/test_tokenization_llama.py @@ -20,6 +20,7 @@ import tempfile import unittest from datasets import load_dataset +from huggingface_hub import hf_hub_download from transformers import ( SPIECE_UNDERLINE, @@ -330,6 +331,15 @@ class LlamaTokenizationTest(TokenizerTesterMixin, unittest.TestCase): fast_.decode(EXPECTED_WITH_SPACE, skip_special_tokens=True), ) + def test_load_tokenizer_with_model_file_only(self): + with tempfile.TemporaryDirectory() as tmp_dir: + hf_hub_download(repo_id="huggyllama/llama-7b", filename="tokenizer.model", local_dir=tmp_dir) + tokenizer_fast = self.rust_tokenizer_class.from_pretrained(tmp_dir) + self.assertEqual(tokenizer_fast.encode("This is a test"), [1, 910, 338, 263, 1243]) + + tokenizer_slow = self.tokenizer_class.from_pretrained(tmp_dir) + self.assertEqual(tokenizer_slow.encode("This is a test"), [1, 910, 338, 263, 1243]) + @require_torch @require_sentencepiece diff --git a/tests/utils/test_configuration_utils.py b/tests/utils/test_configuration_utils.py index 6b684867eb..76394daf9c 100644 --- a/tests/utils/test_configuration_utils.py +++ b/tests/utils/test_configuration_utils.py @@ -247,12 +247,10 @@ class ConfigTestUtils(unittest.TestCase): self.assertEqual(config.text_config.__class__.__name__, "CLIPTextConfig") def test_from_pretrained_subfolder(self): - with self.assertRaises(OSError): - # config is in subfolder, the following should not work without specifying the subfolder - _ = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert-subfolder") + config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert-subfolder") + self.assertIsNotNone(config) config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert-subfolder", subfolder="bert") - self.assertIsNotNone(config) def test_cached_files_are_used_when_internet_is_down(self):