Fix on "cache position" for assisted generation (#30068)
* clean commit history I hope * get kv seq length correctly * PR suggestions * Update src/transformers/testing_utils.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * add comment * give gpt bigcode it's own overriden method * remove code --------- Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
committed by
GitHub
parent
31921d8d5e
commit
77b59dce9f
@@ -1091,8 +1091,9 @@ class GenerationTesterMixin:
|
||||
)
|
||||
self.assertListEqual(low_output.tolist(), high_output.tolist())
|
||||
|
||||
@parameterized.expand([("random",), ("same",)])
|
||||
@is_flaky() # Read NOTE (1) below. If there are API issues, all attempts will fail.
|
||||
def test_assisted_decoding_matches_greedy_search(self):
|
||||
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
|
||||
# This test ensures that the assisted generation does not introduce output changes over greedy search.
|
||||
# NOTE (1): The sentence above is true most of the time, there is a tiny difference in the logits due to matmul
|
||||
# shape differences -- and it may result in a different output. The input shape difference happens in the
|
||||
@@ -1151,7 +1152,13 @@ class GenerationTesterMixin:
|
||||
}
|
||||
output_greedy = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
|
||||
|
||||
assistant_model = model
|
||||
# test with the same assistant model or randomly init one
|
||||
# in the first case all candidate tokens are accepted, in the second none is accepted
|
||||
# case when some are accepted and some not is hard to reproduce, so let's hope this catches most errors :)
|
||||
if assistant_type == "random":
|
||||
assistant_model = model_class(config).to(torch_device).eval()
|
||||
else:
|
||||
assistant_model = model
|
||||
assistant_model.generation_config.num_assistant_tokens = 2 # see b)
|
||||
assistant_model.generation_config.num_assistant_tokens_schedule = "constant" # see b)
|
||||
generation_kwargs.update({"assistant_model": assistant_model})
|
||||
|
||||
Reference in New Issue
Block a user