[generate] revert change in Aria: the maximum cache length must match max_length (#36120)
* revert inputs_embeds len * Update test_utils.py * make fixup
This commit is contained in:
@@ -1470,7 +1470,6 @@ class GenerationMixin:
|
|||||||
elif (
|
elif (
|
||||||
model_input_name == "inputs_embeds"
|
model_input_name == "inputs_embeds"
|
||||||
and input_ids_length != inputs_tensor.shape[1]
|
and input_ids_length != inputs_tensor.shape[1]
|
||||||
and input_ids_length != 0
|
|
||||||
and not self.config.is_encoder_decoder
|
and not self.config.is_encoder_decoder
|
||||||
):
|
):
|
||||||
generation_config.max_length -= inputs_tensor.shape[1]
|
generation_config.max_length -= inputs_tensor.shape[1]
|
||||||
|
|||||||
@@ -1786,12 +1786,12 @@ class GenerationTesterMixin:
|
|||||||
model.config.use_cache = True
|
model.config.use_cache = True
|
||||||
model.config.is_decoder = True
|
model.config.is_decoder = True
|
||||||
batch_size = input_ids.shape[0]
|
batch_size = input_ids.shape[0]
|
||||||
max_cache_len = 30
|
max_length = 30
|
||||||
|
|
||||||
# here we force to not stop at eos and go until max-length
|
# here we force to not stop at eos and go until max-length
|
||||||
model.generation_config.eos_token_id = model.config.get_text_config().eos_token_id = -1
|
model.generation_config.eos_token_id = model.config.get_text_config().eos_token_id = -1
|
||||||
generation_kwargs = {
|
generation_kwargs = {
|
||||||
"max_length": max_cache_len,
|
"max_length": max_length,
|
||||||
"cache_implementation": "static",
|
"cache_implementation": "static",
|
||||||
"return_dict_in_generate": True, # Required to return `past_key_values`
|
"return_dict_in_generate": True, # Required to return `past_key_values`
|
||||||
}
|
}
|
||||||
@@ -1810,11 +1810,11 @@ class GenerationTesterMixin:
|
|||||||
num_hidden_layers = text_config.num_hidden_layers
|
num_hidden_layers = text_config.num_hidden_layers
|
||||||
|
|
||||||
inputs_embeds = model.get_input_embeddings()(input_ids)
|
inputs_embeds = model.get_input_embeddings()(input_ids)
|
||||||
max_cache_len += inputs_embeds.shape[1] - 1 # the last generated token has no cache
|
|
||||||
outputs = model.generate(inputs_embeds=inputs_embeds, **generation_kwargs, **inputs_dict)
|
outputs = model.generate(inputs_embeds=inputs_embeds, **generation_kwargs, **inputs_dict)
|
||||||
|
|
||||||
# we should get `max_length` in shape, not `max_length - embeds_length`
|
# we should get `max_length - 1` in shape, not `max_length - embeds_length`.
|
||||||
cache_shape = (batch_size, num_key_value_heads, max_cache_len, head_dim)
|
# -1 because the last generated token isn't yet in the cache.
|
||||||
|
cache_shape = (batch_size, num_key_value_heads, max_length - 1, head_dim)
|
||||||
self.assertTrue(isinstance(outputs.past_key_values, StaticCache))
|
self.assertTrue(isinstance(outputs.past_key_values, StaticCache))
|
||||||
self.assertTrue(len(outputs.past_key_values.key_cache) == num_hidden_layers)
|
self.assertTrue(len(outputs.past_key_values.key_cache) == num_hidden_layers)
|
||||||
self.assertTrue(outputs.past_key_values.key_cache[0].shape == cache_shape)
|
self.assertTrue(outputs.past_key_values.key_cache[0].shape == cache_shape)
|
||||||
|
|||||||
Reference in New Issue
Block a user