diff --git a/docs/source/en/generation_strategies.md b/docs/source/en/generation_strategies.md index 5cdecfacad..44fd662304 100644 --- a/docs/source/en/generation_strategies.md +++ b/docs/source/en/generation_strategies.md @@ -55,12 +55,10 @@ When you load a model explicitly, you can inspect the generation configuration t >>> from transformers import AutoModelForCausalLM >>> model = AutoModelForCausalLM.from_pretrained("distilgpt2") ->>> model.generation_config # doctest: +IGNORE_RESULT +>>> model.generation_config GenerationConfig { - "_from_model_config": true, "bos_token_id": 50256, "eos_token_id": 50256, - "transformers_version": "4.26.0.dev0" } ``` diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index ef0963f675..94d0f823ed 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -34,6 +34,7 @@ from ..utils import ( logger = logging.get_logger(__name__) +METADATA_FIELDS = ("_from_model_config", "_commit_hash", "_original_object_hash", "transformers_version") class GenerationConfig(PushToHubMixin): @@ -315,20 +316,19 @@ class GenerationConfig(PushToHubMixin): # Validate the values of the attributes self.validate(is_init=True) + def __hash__(self): + return hash(self.to_json_string(ignore_metadata=True)) + def __eq__(self, other): if not isinstance(other, GenerationConfig): return False - self_dict = self.__dict__.copy() - other_dict = other.__dict__.copy() - # ignore metadata - for metadata_field in ("_from_model_config", "_commit_hash", "transformers_version"): - self_dict.pop(metadata_field, None) - other_dict.pop(metadata_field, None) - return self_dict == other_dict + self_without_metadata = self.to_json_string(use_diff=False, ignore_metadata=True) + other_without_metadata = other.to_json_string(use_diff=False, ignore_metadata=True) + return self_without_metadata == other_without_metadata def __repr__(self): - return f"{self.__class__.__name__} {self.to_json_string()}" + return f"{self.__class__.__name__} {self.to_json_string(ignore_metadata=True)}" def validate(self, is_init=False): """ @@ -729,7 +729,9 @@ class GenerationConfig(PushToHubMixin): else: logger.info(f"loading configuration file {configuration_file} from cache at {resolved_config_file}") - return cls.from_dict(config_dict, **kwargs) + config = cls.from_dict(config_dict, **kwargs) + config._original_object_hash = hash(config) # Hash to detect whether the instance was modified + return config @classmethod def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]): @@ -814,8 +816,12 @@ class GenerationConfig(PushToHubMixin): `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance. """ output = copy.deepcopy(self.__dict__) + + # Fields to ignore at serialization time if "_commit_hash" in output: del output["_commit_hash"] + if "_original_object_hash" in output: + del output["_original_object_hash"] # Transformers version when serializing this file output["transformers_version"] = __version__ @@ -823,7 +829,7 @@ class GenerationConfig(PushToHubMixin): self.dict_torch_dtype_to_str(output) return output - def to_json_string(self, use_diff: bool = True) -> str: + def to_json_string(self, use_diff: bool = True, ignore_metadata: bool = False) -> str: """ Serializes this instance to a JSON string. @@ -831,6 +837,8 @@ class GenerationConfig(PushToHubMixin): use_diff (`bool`, *optional*, defaults to `True`): If set to `True`, only the difference between the config instance and the default `GenerationConfig()` is serialized to JSON string. + ignore_metadata (`bool`, *optional*, defaults to `False`): + Whether to ignore the metadata fields present in the instance Returns: `str`: String containing all the attributes that make up this configuration instance in JSON format. @@ -839,6 +847,11 @@ class GenerationConfig(PushToHubMixin): config_dict = self.to_diff_dict() else: config_dict = self.to_dict() + + if ignore_metadata: + for metadata_field in METADATA_FIELDS: + config_dict.pop(metadata_field, None) + return json.dumps(config_dict, indent=2, sort_keys=True) + "\n" def to_json_file(self, json_file_path: Union[str, os.PathLike], use_diff: bool = True): @@ -882,6 +895,7 @@ class GenerationConfig(PushToHubMixin): if attr in decoder_config and getattr(config, attr) == getattr(default_generation_config, attr): setattr(config, attr, decoder_config[attr]) + config._original_object_hash = hash(config) # Hash to detect whether the instance was modified return config def update(self, **kwargs): diff --git a/src/transformers/generation/flax_utils.py b/src/transformers/generation/flax_utils.py index 284e0f51cd..40dfaacf58 100644 --- a/src/transformers/generation/flax_utils.py +++ b/src/transformers/generation/flax_utils.py @@ -310,16 +310,20 @@ class FlaxGenerationMixin: # priority: `generation_config` argument > `model.generation_config` (the default generation config) if generation_config is None: - # legacy: users may modify the model configuration to control generation -- update the generation config - # model attribute accordingly, if it was created from the model config - if self.generation_config._from_model_config: + # legacy: users may modify the model configuration to control generation. To trigger this legacy behavior, + # two conditions must be met + # 1) the generation config must have been created from the model config (`_from_model_config` field); + # 2) the generation config must have seen no modification since its creation (the hash is the same). + if self.generation_config._from_model_config and self.generation_config._original_object_hash == hash( + self.generation_config + ): new_generation_config = GenerationConfig.from_model_config(self.config) if new_generation_config != self.generation_config: warnings.warn( "You have modified the pretrained model configuration to control generation. This is a" " deprecated strategy to control generation and will be removed soon, in a future version." - " Please use a generation configuration file (see" - " https://huggingface.co/docs/transformers/main_classes/text_generation )" + " Please use and modify the model generation configuration (see" + " https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )" ) self.generation_config = new_generation_config generation_config = self.generation_config diff --git a/src/transformers/generation/tf_utils.py b/src/transformers/generation/tf_utils.py index df392cf5ca..c12a84b98c 100644 --- a/src/transformers/generation/tf_utils.py +++ b/src/transformers/generation/tf_utils.py @@ -716,16 +716,20 @@ class TFGenerationMixin: # priority: `generation_config` argument > `model.generation_config` (the default generation config) if generation_config is None: - # legacy: users may modify the model configuration to control generation -- update the generation config - # model attribute accordingly, if it was created from the model config - if self.generation_config._from_model_config: + # legacy: users may modify the model configuration to control generation. To trigger this legacy behavior, + # two conditions must be met + # 1) the generation config must have been created from the model config (`_from_model_config` field); + # 2) the generation config must have seen no modification since its creation (the hash is the same). + if self.generation_config._from_model_config and self.generation_config._original_object_hash == hash( + self.generation_config + ): new_generation_config = GenerationConfig.from_model_config(self.config) if new_generation_config != self.generation_config: warnings.warn( "You have modified the pretrained model configuration to control generation. This is a" " deprecated strategy to control generation and will be removed soon, in a future version." - " Please use a generation configuration file (see" - " https://huggingface.co/docs/transformers/main_classes/text_generation )" + " Please use and modify the model generation configuration (see" + " https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )" ) self.generation_config = new_generation_config generation_config = self.generation_config diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index dab69fc943..a7501d43b7 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1409,16 +1409,20 @@ class GenerationMixin: # priority: `generation_config` argument > `model.generation_config` (the default generation config) if generation_config is None: - # legacy: users may modify the model configuration to control generation -- update the generation config - # model attribute accordingly, if it was created from the model config - if self.generation_config._from_model_config: + # legacy: users may modify the model configuration to control generation. To trigger this legacy behavior, + # two conditions must be met + # 1) the generation config must have been created from the model config (`_from_model_config` field); + # 2) the generation config must have seen no modification since its creation (the hash is the same). + if self.generation_config._from_model_config and self.generation_config._original_object_hash == hash( + self.generation_config + ): new_generation_config = GenerationConfig.from_model_config(self.config) if new_generation_config != self.generation_config: warnings.warn( "You have modified the pretrained model configuration to control generation. This is a" " deprecated strategy to control generation and will be removed soon, in a future version." - " Please use a generation configuration file (see" - " https://huggingface.co/docs/transformers/main_classes/text_generation )" + " Please use and modify the model generation configuration (see" + " https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )" ) self.generation_config = new_generation_config generation_config = self.generation_config diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index f983f527a8..7373ed6cb8 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -2880,7 +2880,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi # Generation config max_length != 20 -> no warning with warnings.catch_warnings(record=True) as warning_list: + # generation_config is modified -> legacy mode is disabled = generation_config takes precedence model.generation_config.max_length = 10 - model.generation_config._from_model_config = False # otherwise model.config.max_length=20 takes precedence model.generate(input_ids) self.assertEqual(len(warning_list), 0)