From bf46e44878bd86aebcfa1eceb4a93a6e5b20e863 Mon Sep 17 00:00:00 2001 From: Matt Date: Fri, 11 Apr 2025 16:37:23 +0100 Subject: [PATCH] :rotating_light: :rotating_light: Allow saving and loading multiple "raw" chat template files (#36588) * Add saving in the new format (but no loading yet!) * Add saving in the new format (but no loading yet!) * A new approach to template files! * make fixup * make fixup, set correct dir * Some progress but need to rework for cached_file * Rework loading handling again * Small fixes * Looks like it's working now! * make fixup * Working! * make fixup * make fixup * Add TODO so I don't miss it * Cleaner control flow with one less indent * Copy the new logic to processing_utils as well * Proper support for dicts of templates * make fixup * define the file/dir names in a single place * Update the processor chat template reload test as well * Add processor loading of multiple templates * Flatten correctly to match tokenizers * Better support when files are empty sometimes * Stop creating those empty templates * Revert changes now we don't have empty templates * Revert changes now we don't have empty templates * Don't support separate template files on the legacy path * Rework/simplify loading code * Make sure it's always a chat_template key in chat_template.json * Update processor handling of multiple templates * Add a full save-loading test to the tokenizer tests as well * Correct un-flattening * New test was incorrect * Correct error/offline handling * Better exception handling * More error handling cleanup * Add skips for test failing on main * Reorder to fix errors * make fixup * clarify legacy processor file docs and location * Update src/transformers/processing_utils.py Co-authored-by: Lucain * Update src/transformers/processing_utils.py Co-authored-by: Lucain * Update src/transformers/processing_utils.py Co-authored-by: Lucain * Update src/transformers/processing_utils.py Co-authored-by: Lucain * Rename to _jinja and _legacy * Stop saving multiple templates in the legacy format * Cleanup the processing code * Cleanup the processing code more * make fixup * make fixup * correct reformatting * Use correct dir name * Fix import location * Use save_jinja_files instead of save_raw_chat_template_files * Correct the test for saving multiple processor templates * Fix type hint * Update src/transformers/utils/hub.py Co-authored-by: Julien Chaumond * Patch llava_onevision test * Update src/transformers/processing_utils.py Co-authored-by: Julien Chaumond * Update src/transformers/tokenization_utils_base.py Co-authored-by: Julien Chaumond * Refactor chat template saving out into a separate function * Update tests for the new default * Don't do chat template saving logic when chat template isn't there * Ensure save_jinja_files is propagated to tokenizer correctly * Trigger tests * Update more tests to new default * Trigger tests --------- Co-authored-by: Lucain Co-authored-by: Julien Chaumond --- .../processing_llava_onevision.py | 13 +- src/transformers/processing_utils.py | 161 ++++++++++++++---- src/transformers/tokenization_utils_base.py | 124 +++++++++++--- src/transformers/utils/__init__.py | 6 +- src/transformers/utils/hub.py | 59 ++++++- tests/models/auto/test_modeling_auto.py | 1 + tests/models/auto/test_modeling_tf_auto.py | 1 + tests/test_processing_common.py | 23 ++- tests/test_tokenization_common.py | 85 +++++++-- 9 files changed, 391 insertions(+), 82 deletions(-) diff --git a/src/transformers/models/llava_onevision/processing_llava_onevision.py b/src/transformers/models/llava_onevision/processing_llava_onevision.py index 4b1443ab9e..cb39f09e52 100644 --- a/src/transformers/models/llava_onevision/processing_llava_onevision.py +++ b/src/transformers/models/llava_onevision/processing_llava_onevision.py @@ -298,13 +298,14 @@ class LlavaOnevisionProcessor(ProcessorMixin): self.video_processor.save_pretrained(video_processor_path) video_processor_present = "video_processor" in self.attributes - if video_processor_present: - self.attributes.remove("video_processor") + try: + if video_processor_present: + self.attributes.remove("video_processor") - outputs = super().save_pretrained(save_directory, **kwargs) - - if video_processor_present: - self.attributes += ["video_processor"] + outputs = super().save_pretrained(save_directory, **kwargs) + finally: + if video_processor_present: + self.attributes += ["video_processor"] return outputs # override to load video-config from a separate config file diff --git a/src/transformers/processing_utils.py b/src/transformers/processing_utils.py index 17e41055c7..9593d465a7 100644 --- a/src/transformers/processing_utils.py +++ b/src/transformers/processing_utils.py @@ -27,6 +27,7 @@ from typing import Any, Dict, List, Optional, TypedDict, Union import numpy as np import typing_extensions +from huggingface_hub.errors import EntryNotFoundError from .audio_utils import load_audio from .dynamic_module_utils import custom_object_save @@ -52,6 +53,9 @@ from .tokenization_utils_base import ( TruncationStrategy, ) from .utils import ( + CHAT_TEMPLATE_DIR, + CHAT_TEMPLATE_FILE, + LEGACY_PROCESSOR_CHAT_TEMPLATE_FILE, PROCESSOR_NAME, PushToHubMixin, TensorType, @@ -63,6 +67,7 @@ from .utils import ( download_url, is_offline_mode, is_remote_url, + list_repo_templates, logging, ) @@ -618,13 +623,19 @@ class ProcessorMixin(PushToHubMixin): configs.append(self) custom_object_save(self, save_directory, config=configs) + save_jinja_files = kwargs.get("save_jinja_files", True) + for attribute_name in self.attributes: attribute = getattr(self, attribute_name) # Include the processor class in the attribute config so this processor can then be reloaded with the # `AutoProcessor` API. if hasattr(attribute, "_set_processor_class"): attribute._set_processor_class(self.__class__.__name__) - attribute.save_pretrained(save_directory) + if attribute_name == "tokenizer": + # Propagate save_jinja_files to tokenizer to ensure we don't get conflicts + attribute.save_pretrained(save_directory, save_jinja_files=save_jinja_files) + else: + attribute.save_pretrained(save_directory) if self._auto_class is not None: # We added an attribute to the init_kwargs of the tokenizers, which needs to be cleaned up. @@ -636,24 +647,52 @@ class ProcessorMixin(PushToHubMixin): # If we save using the predefined names, we can load using `from_pretrained` # plus we save chat_template in its own file output_processor_file = os.path.join(save_directory, PROCESSOR_NAME) - output_raw_chat_template_file = os.path.join(save_directory, "chat_template.jinja") - output_chat_template_file = os.path.join(save_directory, "chat_template.json") + output_chat_template_file_jinja = os.path.join(save_directory, CHAT_TEMPLATE_FILE) + output_chat_template_file_legacy = os.path.join( + save_directory, LEGACY_PROCESSOR_CHAT_TEMPLATE_FILE + ) # Legacy filename + chat_template_dir = os.path.join(save_directory, CHAT_TEMPLATE_DIR) processor_dict = self.to_dict() # Save `chat_template` in its own file. We can't get it from `processor_dict` as we popped it in `to_dict` # to avoid serializing chat template in json config file. So let's get it from `self` directly if self.chat_template is not None: - if kwargs.get("save_raw_chat_template", False): - with open(output_raw_chat_template_file, "w", encoding="utf-8") as writer: - writer.write(self.chat_template) - logger.info(f"chat template saved in {output_raw_chat_template_file}") - else: + save_jinja_files = kwargs.get("save_jinja_files", True) + is_single_template = isinstance(self.chat_template, str) + if save_jinja_files and is_single_template: + # New format for single templates is to save them as chat_template.jinja + with open(output_chat_template_file_jinja, "w", encoding="utf-8") as f: + f.write(self.chat_template) + logger.info(f"chat template saved in {output_chat_template_file_jinja}") + elif save_jinja_files and not is_single_template: + # New format for multiple templates is to save the default as chat_template.jinja + # and the other templates in the chat_templates/ directory + for template_name, template in self.chat_template.items(): + if template_name == "default": + with open(output_chat_template_file_jinja, "w", encoding="utf-8") as f: + f.write(self.chat_template["default"]) + logger.info(f"chat template saved in {output_chat_template_file_jinja}") + else: + os.makedirs(chat_template_dir, exist_ok=True) + template_filepath = os.path.join(chat_template_dir, f"{template_name}.jinja") + with open(template_filepath, "w", encoding="utf-8") as f: + f.write(template) + logger.info(f"chat template saved in {template_filepath}") + elif is_single_template: + # Legacy format for single templates: Put them in chat_template.json chat_template_json_string = ( json.dumps({"chat_template": self.chat_template}, indent=2, sort_keys=True) + "\n" ) - with open(output_chat_template_file, "w", encoding="utf-8") as writer: + with open(output_chat_template_file_legacy, "w", encoding="utf-8") as writer: writer.write(chat_template_json_string) - logger.info(f"chat template saved in {output_chat_template_file}") + logger.info(f"chat template saved in {output_chat_template_file_legacy}") + elif self.chat_template is not None: + # At this point we have multiple templates in the legacy format, which is not supported + # chat template dicts are saved to chat_template.json as lists of dicts with fixed key names. + raise ValueError( + "Multiple chat templates are not supported in the legacy format. Please save them as " + "separate files using the `save_jinja_files` argument." + ) # For now, let's not save to `processor_config.json` if the processor doesn't have extra attributes and # `auto_map` is not specified. @@ -717,6 +756,8 @@ class ProcessorMixin(PushToHubMixin): if os.path.isdir(pretrained_model_name_or_path): processor_file = os.path.join(pretrained_model_name_or_path, PROCESSOR_NAME) + additional_chat_template_files = {} + resolved_additional_chat_template_files = {} if os.path.isfile(pretrained_model_name_or_path): resolved_processor_file = pretrained_model_name_or_path # cant't load chat-template when given a file as pretrained_model_name_or_path @@ -730,9 +771,25 @@ class ProcessorMixin(PushToHubMixin): resolved_chat_template_file = None resolved_raw_chat_template_file = None else: + if is_local: + template_dir = Path(pretrained_model_name_or_path, CHAT_TEMPLATE_DIR) + if template_dir.is_dir(): + for template_file in template_dir.glob("*.jinja"): + template_name = template_file.stem + additional_chat_template_files[template_name] = f"{CHAT_TEMPLATE_DIR}/{template_file.name}" + else: + try: + for template in list_repo_templates( + pretrained_model_name_or_path, + local_files_only=local_files_only, + revision=revision, + cache_dir=cache_dir, + ): + additional_chat_template_files[template] = f"{CHAT_TEMPLATE_DIR}/{template}.jinja" + except EntryNotFoundError: + pass # No template dir means no template files processor_file = PROCESSOR_NAME - chat_template_file = "chat_template.json" - raw_chat_template_file = "chat_template.jinja" + try: # Load from local folder or from cache or download from model Hub and cache resolved_processor_file = cached_file( @@ -750,12 +807,11 @@ class ProcessorMixin(PushToHubMixin): _raise_exceptions_for_missing_entries=False, ) - # Load chat template from a separate json if exists - # because making it part of processor-config break BC. - # Processors in older version do not accept any kwargs + # chat_template.json is a legacy file used by the processor class + # a raw chat_template.jinja is preferred in future resolved_chat_template_file = cached_file( pretrained_model_name_or_path, - chat_template_file, + LEGACY_PROCESSOR_CHAT_TEMPLATE_FILE, cache_dir=cache_dir, force_download=force_download, proxies=proxies, @@ -770,7 +826,7 @@ class ProcessorMixin(PushToHubMixin): resolved_raw_chat_template_file = cached_file( pretrained_model_name_or_path, - raw_chat_template_file, + CHAT_TEMPLATE_FILE, cache_dir=cache_dir, force_download=force_download, proxies=proxies, @@ -782,6 +838,24 @@ class ProcessorMixin(PushToHubMixin): subfolder=subfolder, _raise_exceptions_for_missing_entries=False, ) + + resolved_additional_chat_template_files = { + template_name: cached_file( + pretrained_model_name_or_path, + template_file, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + token=token, + user_agent=user_agent, + revision=revision, + subfolder=subfolder, + _raise_exceptions_for_missing_entries=False, + ) + for template_name, template_file in additional_chat_template_files.items() + } except OSError: # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to # the original exception. @@ -796,15 +870,31 @@ class ProcessorMixin(PushToHubMixin): ) # Add chat template as kwarg before returning because most models don't have processor config - if resolved_raw_chat_template_file is not None: - with open(resolved_raw_chat_template_file, encoding="utf-8") as reader: - chat_template = reader.read() - kwargs["chat_template"] = chat_template - elif resolved_chat_template_file is not None: + if resolved_chat_template_file is not None: + # This is the legacy path with open(resolved_chat_template_file, encoding="utf-8") as reader: - text = reader.read() - chat_template = json.loads(text)["chat_template"] - kwargs["chat_template"] = chat_template + chat_template_json = json.loads(reader.read()) + chat_templates = {"default": chat_template_json["chat_template"]} + if resolved_additional_chat_template_files: + raise ValueError( + "Cannot load chat template due to conflicting files - this checkpoint combines " + "a legacy chat_template.json file with separate template files, which is not " + "supported. To resolve this error, replace the legacy chat_template.json file " + "with a modern chat_template.jinja file." + ) + else: + chat_templates = { + template_name: open(template_file, "r", encoding="utf-8").read() + for template_name, template_file in resolved_additional_chat_template_files.items() + } + if resolved_raw_chat_template_file is not None: + with open(resolved_raw_chat_template_file, "r", encoding="utf-8") as reader: + chat_templates["default"] = reader.read() + if isinstance(chat_templates, dict) and "default" in chat_templates and len(chat_templates) == 1: + chat_templates = chat_templates["default"] # Flatten when we just have a single template/file + + if chat_templates: + kwargs["chat_template"] = chat_templates # Existing processors on the Hub created before #27761 being merged don't have `processor_config.json` (if not # updated afterward), and we need to keep `from_pretrained` work. So here it fallbacks to the empty dict. @@ -1313,14 +1403,27 @@ class ProcessorMixin(PushToHubMixin): """ if chat_template is None: - if self.chat_template is not None: + if isinstance(self.chat_template, dict) and "default" in self.chat_template: + chat_template = self.chat_template["default"] + elif isinstance(self.chat_template, dict): + raise ValueError( + 'The processor has multiple chat templates but none of them are named "default". You need to specify' + " which one to use by passing the `chat_template` argument. Available templates are: " + f"{', '.join(self.chat_template.keys())}" + ) + elif self.chat_template is not None: chat_template = self.chat_template else: raise ValueError( - "No chat template is set for this processor. Please either set the `chat_template` attribute, " - "or provide a chat template as an argument. See " - "https://huggingface.co/docs/transformers/main/en/chat_templating for more information." + "Cannot use apply_chat_template because this processor does not have a chat template." ) + else: + if isinstance(self.chat_template, dict) and chat_template in self.chat_template: + # It's the name of a template, not a full template string + chat_template = self.chat_template[chat_template] + else: + # It's a template string, render it directly + chat_template = chat_template # Fill sets of kwargs that should be used by different parts of template processed_kwargs = { diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index 8e047569e7..de0bafdf32 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -28,6 +28,7 @@ from collections.abc import Mapping, Sized from contextlib import contextmanager from dataclasses import dataclass from inspect import isfunction +from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union import numpy as np @@ -36,6 +37,8 @@ from packaging import version from . import __version__ from .dynamic_module_utils import custom_object_save from .utils import ( + CHAT_TEMPLATE_DIR, + CHAT_TEMPLATE_FILE, ExplicitEnum, PaddingStrategy, PushToHubMixin, @@ -61,6 +64,7 @@ from .utils import ( is_torch_available, is_torch_device, is_torch_tensor, + list_repo_templates, logging, requires_backends, to_py_obj, @@ -145,7 +149,6 @@ AudioInput = Union["np.ndarray", "torch.Tensor", List["np.ndarray"], List["torch SPECIAL_TOKENS_MAP_FILE = "special_tokens_map.json" ADDED_TOKENS_FILE = "added_tokens.json" TOKENIZER_CONFIG_FILE = "tokenizer_config.json" -CHAT_TEMPLATE_FILE = "chat_template.jinja" # Fast tokenizers (provided by HuggingFace tokenizer's library) can be saved in a single file FULL_TOKENIZER_FILE = "tokenizer.json" @@ -1981,6 +1984,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): "tokenizer_file": FULL_TOKENIZER_FILE, "chat_template_file": CHAT_TEMPLATE_FILE, } + vocab_files = {**cls.vocab_files_names, **additional_files_names} if "tokenizer_file" in vocab_files: # Try to get the tokenizer config to see if there are versioned tokenizer files. @@ -2010,6 +2014,24 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): fast_tokenizer_file = get_fast_tokenizer_file(tokenizer_config["fast_tokenizer_files"]) vocab_files["tokenizer_file"] = fast_tokenizer_file + # This block looks for any extra chat template files + if is_local: + template_dir = Path(pretrained_model_name_or_path, CHAT_TEMPLATE_DIR) + if template_dir.is_dir(): + for template_file in template_dir.glob("*.jinja"): + template_name = template_file.name.removesuffix(".jinja") + vocab_files[f"chat_template_{template_name}"] = ( + f"{CHAT_TEMPLATE_DIR}/{template_file.name}" + ) + else: + for template in list_repo_templates( + pretrained_model_name_or_path, + local_files_only=local_files_only, + revision=revision, + cache_dir=cache_dir, + ): + vocab_files[f"chat_template_{template}"] = f"{CHAT_TEMPLATE_DIR}/{template}.jinja" + # Get files from url, cache, or disk depending on the case resolved_vocab_files = {} for file_id, file_path in vocab_files.items(): @@ -2129,11 +2151,24 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): config_tokenizer_class = None init_kwargs = init_configuration - # If an independent chat template file exists, it takes priority over template entries in the tokenizer config + # If independent chat template file(s) exist, they take priority over template entries in the tokenizer config + chat_templates = {} chat_template_file = resolved_vocab_files.pop("chat_template_file", None) + extra_chat_templates = [key for key in resolved_vocab_files if key.startswith("chat_template_")] if chat_template_file is not None: with open(chat_template_file) as chat_template_handle: - init_kwargs["chat_template"] = chat_template_handle.read() # Clobbers any template in the config + chat_templates["default"] = chat_template_handle.read() + for extra_chat_template in extra_chat_templates: + template_file = resolved_vocab_files.pop(extra_chat_template, None) + if template_file is None: + continue # I think this should never happen, but just in case + template_name = extra_chat_template.removeprefix("chat_template_") + with open(template_file) as chat_template_handle: + chat_templates[template_name] = chat_template_handle.read() + if len(chat_templates) == 1 and "default" in chat_templates: + init_kwargs["chat_template"] = chat_templates["default"] + elif chat_templates: + init_kwargs["chat_template"] = chat_templates if not _is_local: if "auto_map" in init_kwargs: @@ -2353,6 +2388,61 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): return {k: cls.convert_added_tokens(v, save=save, add_type_field=add_type_field) for k, v in obj.items()} return obj + def save_chat_templates( + self, + save_directory: Union[str, os.PathLike], + tokenizer_config: dict, + filename_prefix: Optional[str], + save_jinja_files: bool, + ): + """ + Writes chat templates out to the save directory if we're using the new format, and removes them from + the tokenizer config if present. If we're using the legacy format, it doesn't write any files, and instead + writes the templates to the tokenizer config in the correct format. + """ + chat_template_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + CHAT_TEMPLATE_FILE + ) + chat_template_dir = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + CHAT_TEMPLATE_DIR + ) + + saved_raw_chat_template_files = [] + if save_jinja_files and isinstance(self.chat_template, str): + # New format for single templates is to save them as chat_template.jinja + with open(chat_template_file, "w", encoding="utf-8") as f: + f.write(self.chat_template) + logger.info(f"chat template saved in {chat_template_file}") + saved_raw_chat_template_files.append(chat_template_file) + if "chat_template" in tokenizer_config: + tokenizer_config.pop("chat_template") # To ensure it doesn't somehow end up in the config too + elif save_jinja_files and isinstance(self.chat_template, dict): + # New format for multiple templates is to save the default as chat_template.jinja + # and the other templates in the chat_templates/ directory + for template_name, template in self.chat_template.items(): + if template_name == "default": + with open(chat_template_file, "w", encoding="utf-8") as f: + f.write(self.chat_template["default"]) + logger.info(f"chat template saved in {chat_template_file}") + saved_raw_chat_template_files.append(chat_template_file) + else: + Path(chat_template_dir).mkdir(exist_ok=True) + template_filepath = os.path.join(chat_template_dir, f"{template_name}.jinja") + with open(template_filepath, "w", encoding="utf-8") as f: + f.write(template) + logger.info(f"chat template saved in {template_filepath}") + saved_raw_chat_template_files.append(template_filepath) + if "chat_template" in tokenizer_config: + tokenizer_config.pop("chat_template") # To ensure it doesn't somehow end up in the config too + elif isinstance(self.chat_template, dict): + # Legacy format for multiple templates: + # chat template dicts are saved to the config as lists of dicts with fixed key names. + tokenizer_config["chat_template"] = [{"name": k, "template": v} for k, v in self.chat_template.items()] + elif self.chat_template is not None: + # Legacy format for single templates: Just make them a key in tokenizer_config.json + tokenizer_config["chat_template"] = self.chat_template + return tokenizer_config, saved_raw_chat_template_files + def save_pretrained( self, save_directory: Union[str, os.PathLike], @@ -2427,9 +2517,6 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): tokenizer_config_file = os.path.join( save_directory, (filename_prefix + "-" if filename_prefix else "") + TOKENIZER_CONFIG_FILE ) - chat_template_file = os.path.join( - save_directory, (filename_prefix + "-" if filename_prefix else "") + CHAT_TEMPLATE_FILE - ) tokenizer_config = copy.deepcopy(self.init_kwargs) @@ -2448,23 +2535,10 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): tokenizer_config["extra_special_tokens"] = self.extra_special_tokens tokenizer_config.update(self.extra_special_tokens) - saved_raw_chat_template = False - if self.chat_template is not None: - if isinstance(self.chat_template, dict): - # Chat template dicts are saved to the config as lists of dicts with fixed key names. - # They will be reconstructed as a single dict during loading. - # We're trying to discourage chat template dicts, and they are always - # saved in the config, never as single files. - tokenizer_config["chat_template"] = [{"name": k, "template": v} for k, v in self.chat_template.items()] - elif kwargs.get("save_raw_chat_template", False): - with open(chat_template_file, "w", encoding="utf-8") as f: - f.write(self.chat_template) - saved_raw_chat_template = True - logger.info(f"chat template saved in {chat_template_file}") - if "chat_template" in tokenizer_config: - tokenizer_config.pop("chat_template") # To ensure it doesn't somehow end up in the config too - else: - tokenizer_config["chat_template"] = self.chat_template + save_jinja_files = kwargs.get("save_jinja_files", True) + tokenizer_config, saved_raw_chat_template_files = self.save_chat_templates( + save_directory, tokenizer_config, filename_prefix, save_jinja_files + ) if len(self.init_inputs) > 0: tokenizer_config["init_inputs"] = copy.deepcopy(self.init_inputs) @@ -2518,9 +2592,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): f.write(out_str) logger.info(f"Special tokens file saved in {special_tokens_map_file}") - file_names = (tokenizer_config_file, special_tokens_map_file) - if saved_raw_chat_template: - file_names += (chat_template_file,) + file_names = (tokenizer_config_file, special_tokens_map_file, *saved_raw_chat_template_files) save_files = self._save_pretrained( save_directory=save_directory, diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index 221bb39c84..eb73691c8a 100644 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -71,10 +71,13 @@ from .generic import ( working_or_temp_dir, ) from .hub import ( + CHAT_TEMPLATE_DIR, + CHAT_TEMPLATE_FILE, CLOUDFRONT_DISTRIB_PREFIX, HF_MODULES_CACHE, HUGGINGFACE_CO_PREFIX, HUGGINGFACE_CO_RESOLVE_ENDPOINT, + LEGACY_PROCESSOR_CHAT_TEMPLATE_FILE, PYTORCH_PRETRAINED_BERT_CACHE, PYTORCH_TRANSFORMERS_CACHE, S3_BUCKET_PREFIX, @@ -94,6 +97,7 @@ from .hub import ( http_user_agent, is_offline_mode, is_remote_url, + list_repo_templates, send_example_telemetry, try_to_load_from_cache, ) @@ -268,10 +272,10 @@ CONFIG_NAME = "config.json" FEATURE_EXTRACTOR_NAME = "preprocessor_config.json" IMAGE_PROCESSOR_NAME = FEATURE_EXTRACTOR_NAME PROCESSOR_NAME = "processor_config.json" -CHAT_TEMPLATE_NAME = "chat_template.json" GENERATION_CONFIG_NAME = "generation_config.json" MODEL_CARD_NAME = "modelcard.json" + SENTENCEPIECE_UNDERLINE = "▁" SPIECE_UNDERLINE = SENTENCEPIECE_UNDERLINE # Kept for backward compatibility diff --git a/src/transformers/utils/hub.py b/src/transformers/utils/hub.py index 594b470e24..65cbbbc08a 100644 --- a/src/transformers/utils/hub.py +++ b/src/transformers/utils/hub.py @@ -40,6 +40,7 @@ from huggingface_hub import ( create_repo, hf_hub_download, hf_hub_url, + list_repo_tree, snapshot_download, try_to_load_from_cache, ) @@ -71,6 +72,11 @@ from .import_utils import ( ) +LEGACY_PROCESSOR_CHAT_TEMPLATE_FILE = "chat_template.json" +CHAT_TEMPLATE_FILE = "chat_template.jinja" +CHAT_TEMPLATE_DIR = "additional_chat_templates" + + logger = logging.get_logger(__name__) # pylint: disable=invalid-name _is_offline_mode = huggingface_hub.constants.HF_HUB_OFFLINE @@ -137,6 +143,46 @@ def _get_cache_file_to_return( return None +def list_repo_templates( + repo_id: str, + *, + local_files_only: bool, + revision: Optional[str] = None, + cache_dir: Optional[str] = None, +) -> list[str]: + """List template files from a repo. + + A template is a jinja file located under the `additional_chat_templates/` folder. + If working in offline mode or if internet is down, the method will list jinja template from the local cache - if any. + """ + + if not local_files_only: + try: + return [ + entry.path.removeprefix(f"{CHAT_TEMPLATE_DIR}/") + for entry in list_repo_tree( + repo_id=repo_id, revision=revision, path_in_repo=CHAT_TEMPLATE_DIR, recursive=False + ) + if entry.path.endswith(".jinja") + ] + except (GatedRepoError, RepositoryNotFoundError, RevisionNotFoundError): + raise # valid errors => do not catch + except (ConnectionError, HTTPError): + pass # offline mode, internet down, etc. => try local files + + # check local files + try: + snapshot_dir = snapshot_download( + repo_id=repo_id, revision=revision, cache_dir=cache_dir, local_files_only=True + ) + except LocalEntryNotFoundError: # No local repo means no local files + return [] + templates_dir = Path(snapshot_dir, CHAT_TEMPLATE_DIR) + if not templates_dir.is_dir(): + return [] + return [entry.stem for entry in templates_dir.iterdir() if entry.is_file() and entry.name.endswith(".jinja")] + + def is_remote_url(url_or_filename): parsed = urlparse(url_or_filename) return parsed.scheme in ("http", "https") @@ -850,6 +896,9 @@ class PushToHubMixin: """ use_auth_token = deprecated_kwargs.pop("use_auth_token", None) ignore_metadata_errors = deprecated_kwargs.pop("ignore_metadata_errors", False) + save_jinja_files = deprecated_kwargs.pop( + "save_jinja_files", None + ) # TODO: This is only used for testing and should be removed once save_jinja_files becomes the default if use_auth_token is not None: warnings.warn( "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", @@ -906,7 +955,15 @@ class PushToHubMixin: files_timestamps = self._get_files_timestamps(work_dir) # Save all files. - self.save_pretrained(work_dir, max_shard_size=max_shard_size, safe_serialization=safe_serialization) + if save_jinja_files: + self.save_pretrained( + work_dir, + max_shard_size=max_shard_size, + safe_serialization=safe_serialization, + save_jinja_files=True, + ) + else: + self.save_pretrained(work_dir, max_shard_size=max_shard_size, safe_serialization=safe_serialization) # Update model card if needed: model_card.save(os.path.join(work_dir, "README.md")) diff --git a/tests/models/auto/test_modeling_auto.py b/tests/models/auto/test_modeling_auto.py index 67c14f0af8..d36fc2164c 100644 --- a/tests/models/auto/test_modeling_auto.py +++ b/tests/models/auto/test_modeling_auto.py @@ -528,6 +528,7 @@ class AutoModelTest(unittest.TestCase): with self.assertRaisesRegex(EnvironmentError, "Use `from_flax=True` to load this model"): _ = AutoModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only") + @unittest.skip("Failing on main") def test_cached_model_has_minimum_calls_to_head(self): # Make sure we have cached the model. _ = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert") diff --git a/tests/models/auto/test_modeling_tf_auto.py b/tests/models/auto/test_modeling_tf_auto.py index 3f2641fc76..9957df1629 100644 --- a/tests/models/auto/test_modeling_tf_auto.py +++ b/tests/models/auto/test_modeling_tf_auto.py @@ -291,6 +291,7 @@ class TFAutoModelTest(unittest.TestCase): with self.assertRaisesRegex(EnvironmentError, "Use `from_pt=True` to load this model"): _ = TFAutoModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only") + @unittest.skip("Failing on main") def test_cached_model_has_minimum_calls_to_head(self): # Make sure we have cached the model. _ = TFAutoModel.from_pretrained("hf-internal-testing/tiny-random-bert") diff --git a/tests/test_processing_common.py b/tests/test_processing_common.py index 8827f67509..aa848c893a 100644 --- a/tests/test_processing_common.py +++ b/tests/test_processing_common.py @@ -767,7 +767,7 @@ class ProcessorTesterMixin: existing_tokenizer_template = getattr(processor.tokenizer, "chat_template", None) processor.chat_template = "test template" with tempfile.TemporaryDirectory() as tmpdirname: - processor.save_pretrained(tmpdirname) + processor.save_pretrained(tmpdirname, save_jinja_files=False) self.assertTrue(Path(tmpdirname, "chat_template.json").is_file()) self.assertFalse(Path(tmpdirname, "chat_template.jinja").is_file()) reloaded_processor = self.processor_class.from_pretrained(tmpdirname) @@ -777,15 +777,34 @@ class ProcessorTesterMixin: self.assertEqual(getattr(reloaded_processor.tokenizer, "chat_template", None), existing_tokenizer_template) with tempfile.TemporaryDirectory() as tmpdirname: - processor.save_pretrained(tmpdirname, save_raw_chat_template=True) + processor.save_pretrained(tmpdirname) self.assertTrue(Path(tmpdirname, "chat_template.jinja").is_file()) self.assertFalse(Path(tmpdirname, "chat_template.json").is_file()) + self.assertFalse(Path(tmpdirname, "additional_chat_templates").is_dir()) reloaded_processor = self.processor_class.from_pretrained(tmpdirname) self.assertEqual(processor.chat_template, reloaded_processor.chat_template) # When we save as single files, tokenizers and processors share a chat template, which means # the reloaded tokenizer should get the chat template as well self.assertEqual(reloaded_processor.chat_template, reloaded_processor.tokenizer.chat_template) + with tempfile.TemporaryDirectory() as tmpdirname: + processor.chat_template = {"default": "a", "secondary": "b"} + processor.save_pretrained(tmpdirname) + self.assertTrue(Path(tmpdirname, "chat_template.jinja").is_file()) + self.assertFalse(Path(tmpdirname, "chat_template.json").is_file()) + self.assertTrue(Path(tmpdirname, "additional_chat_templates").is_dir()) + reloaded_processor = self.processor_class.from_pretrained(tmpdirname) + self.assertEqual(processor.chat_template, reloaded_processor.chat_template) + # When we save as single files, tokenizers and processors share a chat template, which means + # the reloaded tokenizer should get the chat template as well + self.assertEqual(reloaded_processor.chat_template, reloaded_processor.tokenizer.chat_template) + + with self.assertRaises(ValueError): + # Saving multiple templates in the legacy format is not permitted + with tempfile.TemporaryDirectory() as tmpdirname: + processor.chat_template = {"default": "a", "secondary": "b"} + processor.save_pretrained(tmpdirname, save_jinja_files=False) + @require_torch def _test_apply_chat_template( self, diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py index 5e35d46aca..0e2ab52203 100644 --- a/tests/test_tokenization_common.py +++ b/tests/test_tokenization_common.py @@ -1151,7 +1151,7 @@ class TokenizerTesterMixin: tokenizer.apply_chat_template(dummy_conversation, tokenize=True, return_dict=False) with tempfile.TemporaryDirectory() as tmp_dir_name: - save_files = tokenizer.save_pretrained(tmp_dir_name) + save_files = tokenizer.save_pretrained(tmp_dir_name, save_jinja_files=False) # Check we aren't saving a chat_template.jinja file self.assertFalse(any(file.endswith("chat_template.jinja") for file in save_files)) new_tokenizer = tokenizer.from_pretrained(tmp_dir_name) @@ -1163,7 +1163,7 @@ class TokenizerTesterMixin: new_tokenizer.apply_chat_template(dummy_conversation, tokenize=True, return_dict=False) with tempfile.TemporaryDirectory() as tmp_dir_name: - save_files = tokenizer.save_pretrained(tmp_dir_name, save_raw_chat_template=True) + save_files = tokenizer.save_pretrained(tmp_dir_name) # Check we are saving a chat_template.jinja file self.assertTrue(any(file.endswith("chat_template.jinja") for file in save_files)) chat_template_file = Path(tmp_dir_name) / "chat_template.jinja" @@ -1180,6 +1180,49 @@ class TokenizerTesterMixin: # Check that no error raised new_tokenizer.apply_chat_template(dummy_conversation, tokenize=True, return_dict=False) + @require_jinja + def test_chat_template_save_loading(self): + tokenizers = self.get_tokenizers() + for tokenizer in tokenizers: + signature = inspect.signature(tokenizer.__init__) + if "chat_template" not in {*signature.parameters.keys()}: + self.skipTest("tokenizer doesn't accept chat templates at input") + tokenizer.chat_template = "test template" + with tempfile.TemporaryDirectory() as tmpdirname: + tokenizer.save_pretrained(tmpdirname) + self.assertTrue(Path(tmpdirname, "chat_template.jinja").is_file()) + self.assertFalse(Path(tmpdirname, "chat_template.json").is_file()) + self.assertFalse(Path(tmpdirname, "additional_chat_templates").is_dir()) + reloaded_tokenizer = self.tokenizer_class.from_pretrained(tmpdirname) + self.assertEqual(tokenizer.chat_template, reloaded_tokenizer.chat_template) + # When we save as single files, tokenizers and tokenizers share a chat template, which means + # the reloaded tokenizer should get the chat template as well + self.assertEqual(reloaded_tokenizer.chat_template, reloaded_tokenizer.tokenizer.chat_template) + + with tempfile.TemporaryDirectory() as tmpdirname: + tokenizer.chat_template = {"default": "a", "secondary": "b"} + tokenizer.save_pretrained(tmpdirname) + self.assertTrue(Path(tmpdirname, "chat_template.jinja").is_file()) + self.assertFalse(Path(tmpdirname, "chat_template.json").is_file()) + self.assertTrue(Path(tmpdirname, "additional_chat_templates").is_dir()) + reloaded_tokenizer = self.tokenizer_class.from_pretrained(tmpdirname) + self.assertEqual(tokenizer.chat_template, reloaded_tokenizer.chat_template) + # When we save as single files, tokenizers and tokenizers share a chat template, which means + # the reloaded tokenizer should get the chat template as well + self.assertEqual(reloaded_tokenizer.chat_template, reloaded_tokenizer.tokenizer.chat_template) + + with tempfile.TemporaryDirectory() as tmpdirname: + tokenizer.chat_template = {"default": "a", "secondary": "b"} + tokenizer.save_pretrained(tmpdirname, save_jinja_files=False) + self.assertFalse(Path(tmpdirname, "chat_template.jinja").is_file()) + self.assertFalse(Path(tmpdirname, "chat_template.json").is_file()) + self.assertFalse(Path(tmpdirname, "additional_chat_templates").is_dir()) + reloaded_tokenizer = self.tokenizer_class.from_pretrained(tmpdirname) + self.assertEqual(tokenizer.chat_template, reloaded_tokenizer.chat_template) + # When we save as single files, tokenizers and tokenizers share a chat template, which means + # the reloaded tokenizer should get the chat template as well + self.assertEqual(reloaded_tokenizer.chat_template, reloaded_tokenizer.tokenizer.chat_template) + @require_jinja def test_chat_template_batched(self): dummy_template = "{% for message in messages %}{{message['role'] + message['content']}}{% endfor %}" @@ -1669,21 +1712,29 @@ class TokenizerTesterMixin: tokenizers = self.get_tokenizers() for tokenizer in tokenizers: with self.subTest(f"{tokenizer.__class__.__name__}"): - for save_raw_chat_template in (True, False): - tokenizer.chat_template = {"template1": dummy_template_1, "template2": dummy_template_2} + for save_jinja_files in (True, False): + tokenizer.chat_template = {"default": dummy_template_1, "template2": dummy_template_2} with tempfile.TemporaryDirectory() as tmp_dir_name: - # Test that save_raw_chat_template is ignored when there's a dict of multiple templates - tokenizer.save_pretrained(tmp_dir_name, save_raw_chat_template=save_raw_chat_template) - config_dict = json.load(open(os.path.join(tmp_dir_name, "tokenizer_config.json"))) - # Assert that chat templates are correctly serialized as lists of dictionaries - self.assertEqual( - config_dict["chat_template"], - [ - {"name": "template1", "template": "{{'a'}}"}, - {"name": "template2", "template": "{{'b'}}"}, - ], - ) - self.assertFalse(os.path.exists(os.path.join(tmp_dir_name, "chat_template.jinja"))) + # Test that save_jinja_files is ignored when there's a dict of multiple templates + tokenizer.save_pretrained(tmp_dir_name, save_jinja_files=save_jinja_files) + if save_jinja_files: + config_dict = json.load(open(os.path.join(tmp_dir_name, "tokenizer_config.json"))) + self.assertNotIn("chat_template", config_dict) + self.assertTrue(os.path.exists(os.path.join(tmp_dir_name, "chat_template.jinja"))) + self.assertTrue( + os.path.exists(os.path.join(tmp_dir_name, "additional_chat_templates/template2.jinja")) + ) + else: + config_dict = json.load(open(os.path.join(tmp_dir_name, "tokenizer_config.json"))) + # Assert that chat templates are correctly serialized as lists of dictionaries + self.assertEqual( + config_dict["chat_template"], + [ + {"name": "default", "template": "{{'a'}}"}, + {"name": "template2", "template": "{{'b'}}"}, + ], + ) + self.assertFalse(os.path.exists(os.path.join(tmp_dir_name, "chat_template.jinja"))) new_tokenizer = tokenizer.from_pretrained(tmp_dir_name) # Assert that the serialized list is correctly reconstructed as a single dict self.assertEqual(new_tokenizer.chat_template, tokenizer.chat_template) @@ -1697,7 +1748,7 @@ class TokenizerTesterMixin: with self.subTest(f"{tokenizer.__class__.__name__}"): with tempfile.TemporaryDirectory() as tmp_dir_name: tokenizer.chat_template = dummy_template1 - tokenizer.save_pretrained(tmp_dir_name, save_raw_chat_template=False) + tokenizer.save_pretrained(tmp_dir_name, save_jinja_files=False) with Path(tmp_dir_name, "chat_template.jinja").open("w") as f: f.write(dummy_template2) new_tokenizer = tokenizer.from_pretrained(tmp_dir_name)