Generate: legacy mode is only triggered when generation_config is untouched (#25962)
This commit is contained in:
@@ -55,12 +55,10 @@ When you load a model explicitly, you can inspect the generation configuration t
|
|||||||
>>> from transformers import AutoModelForCausalLM
|
>>> from transformers import AutoModelForCausalLM
|
||||||
|
|
||||||
>>> model = AutoModelForCausalLM.from_pretrained("distilgpt2")
|
>>> model = AutoModelForCausalLM.from_pretrained("distilgpt2")
|
||||||
>>> model.generation_config # doctest: +IGNORE_RESULT
|
>>> model.generation_config
|
||||||
GenerationConfig {
|
GenerationConfig {
|
||||||
"_from_model_config": true,
|
|
||||||
"bos_token_id": 50256,
|
"bos_token_id": 50256,
|
||||||
"eos_token_id": 50256,
|
"eos_token_id": 50256,
|
||||||
"transformers_version": "4.26.0.dev0"
|
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
@@ -34,6 +34,7 @@ from ..utils import (
|
|||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
METADATA_FIELDS = ("_from_model_config", "_commit_hash", "_original_object_hash", "transformers_version")
|
||||||
|
|
||||||
|
|
||||||
class GenerationConfig(PushToHubMixin):
|
class GenerationConfig(PushToHubMixin):
|
||||||
@@ -315,20 +316,19 @@ class GenerationConfig(PushToHubMixin):
|
|||||||
# Validate the values of the attributes
|
# Validate the values of the attributes
|
||||||
self.validate(is_init=True)
|
self.validate(is_init=True)
|
||||||
|
|
||||||
|
def __hash__(self):
|
||||||
|
return hash(self.to_json_string(ignore_metadata=True))
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
if not isinstance(other, GenerationConfig):
|
if not isinstance(other, GenerationConfig):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
self_dict = self.__dict__.copy()
|
self_without_metadata = self.to_json_string(use_diff=False, ignore_metadata=True)
|
||||||
other_dict = other.__dict__.copy()
|
other_without_metadata = other.to_json_string(use_diff=False, ignore_metadata=True)
|
||||||
# ignore metadata
|
return self_without_metadata == other_without_metadata
|
||||||
for metadata_field in ("_from_model_config", "_commit_hash", "transformers_version"):
|
|
||||||
self_dict.pop(metadata_field, None)
|
|
||||||
other_dict.pop(metadata_field, None)
|
|
||||||
return self_dict == other_dict
|
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"{self.__class__.__name__} {self.to_json_string()}"
|
return f"{self.__class__.__name__} {self.to_json_string(ignore_metadata=True)}"
|
||||||
|
|
||||||
def validate(self, is_init=False):
|
def validate(self, is_init=False):
|
||||||
"""
|
"""
|
||||||
@@ -729,7 +729,9 @@ class GenerationConfig(PushToHubMixin):
|
|||||||
else:
|
else:
|
||||||
logger.info(f"loading configuration file {configuration_file} from cache at {resolved_config_file}")
|
logger.info(f"loading configuration file {configuration_file} from cache at {resolved_config_file}")
|
||||||
|
|
||||||
return cls.from_dict(config_dict, **kwargs)
|
config = cls.from_dict(config_dict, **kwargs)
|
||||||
|
config._original_object_hash = hash(config) # Hash to detect whether the instance was modified
|
||||||
|
return config
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
|
def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
|
||||||
@@ -814,8 +816,12 @@ class GenerationConfig(PushToHubMixin):
|
|||||||
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
|
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
|
||||||
"""
|
"""
|
||||||
output = copy.deepcopy(self.__dict__)
|
output = copy.deepcopy(self.__dict__)
|
||||||
|
|
||||||
|
# Fields to ignore at serialization time
|
||||||
if "_commit_hash" in output:
|
if "_commit_hash" in output:
|
||||||
del output["_commit_hash"]
|
del output["_commit_hash"]
|
||||||
|
if "_original_object_hash" in output:
|
||||||
|
del output["_original_object_hash"]
|
||||||
|
|
||||||
# Transformers version when serializing this file
|
# Transformers version when serializing this file
|
||||||
output["transformers_version"] = __version__
|
output["transformers_version"] = __version__
|
||||||
@@ -823,7 +829,7 @@ class GenerationConfig(PushToHubMixin):
|
|||||||
self.dict_torch_dtype_to_str(output)
|
self.dict_torch_dtype_to_str(output)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def to_json_string(self, use_diff: bool = True) -> str:
|
def to_json_string(self, use_diff: bool = True, ignore_metadata: bool = False) -> str:
|
||||||
"""
|
"""
|
||||||
Serializes this instance to a JSON string.
|
Serializes this instance to a JSON string.
|
||||||
|
|
||||||
@@ -831,6 +837,8 @@ class GenerationConfig(PushToHubMixin):
|
|||||||
use_diff (`bool`, *optional*, defaults to `True`):
|
use_diff (`bool`, *optional*, defaults to `True`):
|
||||||
If set to `True`, only the difference between the config instance and the default `GenerationConfig()`
|
If set to `True`, only the difference between the config instance and the default `GenerationConfig()`
|
||||||
is serialized to JSON string.
|
is serialized to JSON string.
|
||||||
|
ignore_metadata (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to ignore the metadata fields present in the instance
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
`str`: String containing all the attributes that make up this configuration instance in JSON format.
|
`str`: String containing all the attributes that make up this configuration instance in JSON format.
|
||||||
@@ -839,6 +847,11 @@ class GenerationConfig(PushToHubMixin):
|
|||||||
config_dict = self.to_diff_dict()
|
config_dict = self.to_diff_dict()
|
||||||
else:
|
else:
|
||||||
config_dict = self.to_dict()
|
config_dict = self.to_dict()
|
||||||
|
|
||||||
|
if ignore_metadata:
|
||||||
|
for metadata_field in METADATA_FIELDS:
|
||||||
|
config_dict.pop(metadata_field, None)
|
||||||
|
|
||||||
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
|
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
|
||||||
|
|
||||||
def to_json_file(self, json_file_path: Union[str, os.PathLike], use_diff: bool = True):
|
def to_json_file(self, json_file_path: Union[str, os.PathLike], use_diff: bool = True):
|
||||||
@@ -882,6 +895,7 @@ class GenerationConfig(PushToHubMixin):
|
|||||||
if attr in decoder_config and getattr(config, attr) == getattr(default_generation_config, attr):
|
if attr in decoder_config and getattr(config, attr) == getattr(default_generation_config, attr):
|
||||||
setattr(config, attr, decoder_config[attr])
|
setattr(config, attr, decoder_config[attr])
|
||||||
|
|
||||||
|
config._original_object_hash = hash(config) # Hash to detect whether the instance was modified
|
||||||
return config
|
return config
|
||||||
|
|
||||||
def update(self, **kwargs):
|
def update(self, **kwargs):
|
||||||
|
|||||||
@@ -310,16 +310,20 @@ class FlaxGenerationMixin:
|
|||||||
|
|
||||||
# priority: `generation_config` argument > `model.generation_config` (the default generation config)
|
# priority: `generation_config` argument > `model.generation_config` (the default generation config)
|
||||||
if generation_config is None:
|
if generation_config is None:
|
||||||
# legacy: users may modify the model configuration to control generation -- update the generation config
|
# legacy: users may modify the model configuration to control generation. To trigger this legacy behavior,
|
||||||
# model attribute accordingly, if it was created from the model config
|
# two conditions must be met
|
||||||
if self.generation_config._from_model_config:
|
# 1) the generation config must have been created from the model config (`_from_model_config` field);
|
||||||
|
# 2) the generation config must have seen no modification since its creation (the hash is the same).
|
||||||
|
if self.generation_config._from_model_config and self.generation_config._original_object_hash == hash(
|
||||||
|
self.generation_config
|
||||||
|
):
|
||||||
new_generation_config = GenerationConfig.from_model_config(self.config)
|
new_generation_config = GenerationConfig.from_model_config(self.config)
|
||||||
if new_generation_config != self.generation_config:
|
if new_generation_config != self.generation_config:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"You have modified the pretrained model configuration to control generation. This is a"
|
"You have modified the pretrained model configuration to control generation. This is a"
|
||||||
" deprecated strategy to control generation and will be removed soon, in a future version."
|
" deprecated strategy to control generation and will be removed soon, in a future version."
|
||||||
" Please use a generation configuration file (see"
|
" Please use and modify the model generation configuration (see"
|
||||||
" https://huggingface.co/docs/transformers/main_classes/text_generation )"
|
" https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )"
|
||||||
)
|
)
|
||||||
self.generation_config = new_generation_config
|
self.generation_config = new_generation_config
|
||||||
generation_config = self.generation_config
|
generation_config = self.generation_config
|
||||||
|
|||||||
@@ -716,16 +716,20 @@ class TFGenerationMixin:
|
|||||||
|
|
||||||
# priority: `generation_config` argument > `model.generation_config` (the default generation config)
|
# priority: `generation_config` argument > `model.generation_config` (the default generation config)
|
||||||
if generation_config is None:
|
if generation_config is None:
|
||||||
# legacy: users may modify the model configuration to control generation -- update the generation config
|
# legacy: users may modify the model configuration to control generation. To trigger this legacy behavior,
|
||||||
# model attribute accordingly, if it was created from the model config
|
# two conditions must be met
|
||||||
if self.generation_config._from_model_config:
|
# 1) the generation config must have been created from the model config (`_from_model_config` field);
|
||||||
|
# 2) the generation config must have seen no modification since its creation (the hash is the same).
|
||||||
|
if self.generation_config._from_model_config and self.generation_config._original_object_hash == hash(
|
||||||
|
self.generation_config
|
||||||
|
):
|
||||||
new_generation_config = GenerationConfig.from_model_config(self.config)
|
new_generation_config = GenerationConfig.from_model_config(self.config)
|
||||||
if new_generation_config != self.generation_config:
|
if new_generation_config != self.generation_config:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"You have modified the pretrained model configuration to control generation. This is a"
|
"You have modified the pretrained model configuration to control generation. This is a"
|
||||||
" deprecated strategy to control generation and will be removed soon, in a future version."
|
" deprecated strategy to control generation and will be removed soon, in a future version."
|
||||||
" Please use a generation configuration file (see"
|
" Please use and modify the model generation configuration (see"
|
||||||
" https://huggingface.co/docs/transformers/main_classes/text_generation )"
|
" https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )"
|
||||||
)
|
)
|
||||||
self.generation_config = new_generation_config
|
self.generation_config = new_generation_config
|
||||||
generation_config = self.generation_config
|
generation_config = self.generation_config
|
||||||
|
|||||||
@@ -1409,16 +1409,20 @@ class GenerationMixin:
|
|||||||
|
|
||||||
# priority: `generation_config` argument > `model.generation_config` (the default generation config)
|
# priority: `generation_config` argument > `model.generation_config` (the default generation config)
|
||||||
if generation_config is None:
|
if generation_config is None:
|
||||||
# legacy: users may modify the model configuration to control generation -- update the generation config
|
# legacy: users may modify the model configuration to control generation. To trigger this legacy behavior,
|
||||||
# model attribute accordingly, if it was created from the model config
|
# two conditions must be met
|
||||||
if self.generation_config._from_model_config:
|
# 1) the generation config must have been created from the model config (`_from_model_config` field);
|
||||||
|
# 2) the generation config must have seen no modification since its creation (the hash is the same).
|
||||||
|
if self.generation_config._from_model_config and self.generation_config._original_object_hash == hash(
|
||||||
|
self.generation_config
|
||||||
|
):
|
||||||
new_generation_config = GenerationConfig.from_model_config(self.config)
|
new_generation_config = GenerationConfig.from_model_config(self.config)
|
||||||
if new_generation_config != self.generation_config:
|
if new_generation_config != self.generation_config:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"You have modified the pretrained model configuration to control generation. This is a"
|
"You have modified the pretrained model configuration to control generation. This is a"
|
||||||
" deprecated strategy to control generation and will be removed soon, in a future version."
|
" deprecated strategy to control generation and will be removed soon, in a future version."
|
||||||
" Please use a generation configuration file (see"
|
" Please use and modify the model generation configuration (see"
|
||||||
" https://huggingface.co/docs/transformers/main_classes/text_generation )"
|
" https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )"
|
||||||
)
|
)
|
||||||
self.generation_config = new_generation_config
|
self.generation_config = new_generation_config
|
||||||
generation_config = self.generation_config
|
generation_config = self.generation_config
|
||||||
|
|||||||
@@ -2880,7 +2880,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
|||||||
|
|
||||||
# Generation config max_length != 20 -> no warning
|
# Generation config max_length != 20 -> no warning
|
||||||
with warnings.catch_warnings(record=True) as warning_list:
|
with warnings.catch_warnings(record=True) as warning_list:
|
||||||
|
# generation_config is modified -> legacy mode is disabled = generation_config takes precedence
|
||||||
model.generation_config.max_length = 10
|
model.generation_config.max_length = 10
|
||||||
model.generation_config._from_model_config = False # otherwise model.config.max_length=20 takes precedence
|
|
||||||
model.generate(input_ids)
|
model.generate(input_ids)
|
||||||
self.assertEqual(len(warning_list), 0)
|
self.assertEqual(len(warning_list), 0)
|
||||||
|
|||||||
Reference in New Issue
Block a user