Generate: legacy mode is only triggered when generation_config is untouched (#25962)
This commit is contained in:
@@ -55,12 +55,10 @@ When you load a model explicitly, you can inspect the generation configuration t
|
||||
>>> from transformers import AutoModelForCausalLM
|
||||
|
||||
>>> model = AutoModelForCausalLM.from_pretrained("distilgpt2")
|
||||
>>> model.generation_config # doctest: +IGNORE_RESULT
|
||||
>>> model.generation_config
|
||||
GenerationConfig {
|
||||
"_from_model_config": true,
|
||||
"bos_token_id": 50256,
|
||||
"eos_token_id": 50256,
|
||||
"transformers_version": "4.26.0.dev0"
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
@@ -34,6 +34,7 @@ from ..utils import (
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
METADATA_FIELDS = ("_from_model_config", "_commit_hash", "_original_object_hash", "transformers_version")
|
||||
|
||||
|
||||
class GenerationConfig(PushToHubMixin):
|
||||
@@ -315,20 +316,19 @@ class GenerationConfig(PushToHubMixin):
|
||||
# Validate the values of the attributes
|
||||
self.validate(is_init=True)
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.to_json_string(ignore_metadata=True))
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, GenerationConfig):
|
||||
return False
|
||||
|
||||
self_dict = self.__dict__.copy()
|
||||
other_dict = other.__dict__.copy()
|
||||
# ignore metadata
|
||||
for metadata_field in ("_from_model_config", "_commit_hash", "transformers_version"):
|
||||
self_dict.pop(metadata_field, None)
|
||||
other_dict.pop(metadata_field, None)
|
||||
return self_dict == other_dict
|
||||
self_without_metadata = self.to_json_string(use_diff=False, ignore_metadata=True)
|
||||
other_without_metadata = other.to_json_string(use_diff=False, ignore_metadata=True)
|
||||
return self_without_metadata == other_without_metadata
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.__class__.__name__} {self.to_json_string()}"
|
||||
return f"{self.__class__.__name__} {self.to_json_string(ignore_metadata=True)}"
|
||||
|
||||
def validate(self, is_init=False):
|
||||
"""
|
||||
@@ -729,7 +729,9 @@ class GenerationConfig(PushToHubMixin):
|
||||
else:
|
||||
logger.info(f"loading configuration file {configuration_file} from cache at {resolved_config_file}")
|
||||
|
||||
return cls.from_dict(config_dict, **kwargs)
|
||||
config = cls.from_dict(config_dict, **kwargs)
|
||||
config._original_object_hash = hash(config) # Hash to detect whether the instance was modified
|
||||
return config
|
||||
|
||||
@classmethod
|
||||
def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
|
||||
@@ -814,8 +816,12 @@ class GenerationConfig(PushToHubMixin):
|
||||
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
|
||||
"""
|
||||
output = copy.deepcopy(self.__dict__)
|
||||
|
||||
# Fields to ignore at serialization time
|
||||
if "_commit_hash" in output:
|
||||
del output["_commit_hash"]
|
||||
if "_original_object_hash" in output:
|
||||
del output["_original_object_hash"]
|
||||
|
||||
# Transformers version when serializing this file
|
||||
output["transformers_version"] = __version__
|
||||
@@ -823,7 +829,7 @@ class GenerationConfig(PushToHubMixin):
|
||||
self.dict_torch_dtype_to_str(output)
|
||||
return output
|
||||
|
||||
def to_json_string(self, use_diff: bool = True) -> str:
|
||||
def to_json_string(self, use_diff: bool = True, ignore_metadata: bool = False) -> str:
|
||||
"""
|
||||
Serializes this instance to a JSON string.
|
||||
|
||||
@@ -831,6 +837,8 @@ class GenerationConfig(PushToHubMixin):
|
||||
use_diff (`bool`, *optional*, defaults to `True`):
|
||||
If set to `True`, only the difference between the config instance and the default `GenerationConfig()`
|
||||
is serialized to JSON string.
|
||||
ignore_metadata (`bool`, *optional*, defaults to `False`):
|
||||
Whether to ignore the metadata fields present in the instance
|
||||
|
||||
Returns:
|
||||
`str`: String containing all the attributes that make up this configuration instance in JSON format.
|
||||
@@ -839,6 +847,11 @@ class GenerationConfig(PushToHubMixin):
|
||||
config_dict = self.to_diff_dict()
|
||||
else:
|
||||
config_dict = self.to_dict()
|
||||
|
||||
if ignore_metadata:
|
||||
for metadata_field in METADATA_FIELDS:
|
||||
config_dict.pop(metadata_field, None)
|
||||
|
||||
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
|
||||
|
||||
def to_json_file(self, json_file_path: Union[str, os.PathLike], use_diff: bool = True):
|
||||
@@ -882,6 +895,7 @@ class GenerationConfig(PushToHubMixin):
|
||||
if attr in decoder_config and getattr(config, attr) == getattr(default_generation_config, attr):
|
||||
setattr(config, attr, decoder_config[attr])
|
||||
|
||||
config._original_object_hash = hash(config) # Hash to detect whether the instance was modified
|
||||
return config
|
||||
|
||||
def update(self, **kwargs):
|
||||
|
||||
@@ -310,16 +310,20 @@ class FlaxGenerationMixin:
|
||||
|
||||
# priority: `generation_config` argument > `model.generation_config` (the default generation config)
|
||||
if generation_config is None:
|
||||
# legacy: users may modify the model configuration to control generation -- update the generation config
|
||||
# model attribute accordingly, if it was created from the model config
|
||||
if self.generation_config._from_model_config:
|
||||
# legacy: users may modify the model configuration to control generation. To trigger this legacy behavior,
|
||||
# two conditions must be met
|
||||
# 1) the generation config must have been created from the model config (`_from_model_config` field);
|
||||
# 2) the generation config must have seen no modification since its creation (the hash is the same).
|
||||
if self.generation_config._from_model_config and self.generation_config._original_object_hash == hash(
|
||||
self.generation_config
|
||||
):
|
||||
new_generation_config = GenerationConfig.from_model_config(self.config)
|
||||
if new_generation_config != self.generation_config:
|
||||
warnings.warn(
|
||||
"You have modified the pretrained model configuration to control generation. This is a"
|
||||
" deprecated strategy to control generation and will be removed soon, in a future version."
|
||||
" Please use a generation configuration file (see"
|
||||
" https://huggingface.co/docs/transformers/main_classes/text_generation )"
|
||||
" Please use and modify the model generation configuration (see"
|
||||
" https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )"
|
||||
)
|
||||
self.generation_config = new_generation_config
|
||||
generation_config = self.generation_config
|
||||
|
||||
@@ -716,16 +716,20 @@ class TFGenerationMixin:
|
||||
|
||||
# priority: `generation_config` argument > `model.generation_config` (the default generation config)
|
||||
if generation_config is None:
|
||||
# legacy: users may modify the model configuration to control generation -- update the generation config
|
||||
# model attribute accordingly, if it was created from the model config
|
||||
if self.generation_config._from_model_config:
|
||||
# legacy: users may modify the model configuration to control generation. To trigger this legacy behavior,
|
||||
# two conditions must be met
|
||||
# 1) the generation config must have been created from the model config (`_from_model_config` field);
|
||||
# 2) the generation config must have seen no modification since its creation (the hash is the same).
|
||||
if self.generation_config._from_model_config and self.generation_config._original_object_hash == hash(
|
||||
self.generation_config
|
||||
):
|
||||
new_generation_config = GenerationConfig.from_model_config(self.config)
|
||||
if new_generation_config != self.generation_config:
|
||||
warnings.warn(
|
||||
"You have modified the pretrained model configuration to control generation. This is a"
|
||||
" deprecated strategy to control generation and will be removed soon, in a future version."
|
||||
" Please use a generation configuration file (see"
|
||||
" https://huggingface.co/docs/transformers/main_classes/text_generation )"
|
||||
" Please use and modify the model generation configuration (see"
|
||||
" https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )"
|
||||
)
|
||||
self.generation_config = new_generation_config
|
||||
generation_config = self.generation_config
|
||||
|
||||
@@ -1409,16 +1409,20 @@ class GenerationMixin:
|
||||
|
||||
# priority: `generation_config` argument > `model.generation_config` (the default generation config)
|
||||
if generation_config is None:
|
||||
# legacy: users may modify the model configuration to control generation -- update the generation config
|
||||
# model attribute accordingly, if it was created from the model config
|
||||
if self.generation_config._from_model_config:
|
||||
# legacy: users may modify the model configuration to control generation. To trigger this legacy behavior,
|
||||
# two conditions must be met
|
||||
# 1) the generation config must have been created from the model config (`_from_model_config` field);
|
||||
# 2) the generation config must have seen no modification since its creation (the hash is the same).
|
||||
if self.generation_config._from_model_config and self.generation_config._original_object_hash == hash(
|
||||
self.generation_config
|
||||
):
|
||||
new_generation_config = GenerationConfig.from_model_config(self.config)
|
||||
if new_generation_config != self.generation_config:
|
||||
warnings.warn(
|
||||
"You have modified the pretrained model configuration to control generation. This is a"
|
||||
" deprecated strategy to control generation and will be removed soon, in a future version."
|
||||
" Please use a generation configuration file (see"
|
||||
" https://huggingface.co/docs/transformers/main_classes/text_generation )"
|
||||
" Please use and modify the model generation configuration (see"
|
||||
" https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )"
|
||||
)
|
||||
self.generation_config = new_generation_config
|
||||
generation_config = self.generation_config
|
||||
|
||||
@@ -2880,7 +2880,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
|
||||
# Generation config max_length != 20 -> no warning
|
||||
with warnings.catch_warnings(record=True) as warning_list:
|
||||
# generation_config is modified -> legacy mode is disabled = generation_config takes precedence
|
||||
model.generation_config.max_length = 10
|
||||
model.generation_config._from_model_config = False # otherwise model.config.max_length=20 takes precedence
|
||||
model.generate(input_ids)
|
||||
self.assertEqual(len(warning_list), 0)
|
||||
|
||||
Reference in New Issue
Block a user