From bbfb9fc22bdd49a45dd6ed850fc78c4d99b59afb Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Mon, 8 May 2023 10:45:40 +0100 Subject: [PATCH] =?UTF-8?q?Generate:=20starcoder=20=F0=9F=A4=9C=20?= =?UTF-8?q?=F0=9F=A4=9B=20assisted=20generation=20(#23182)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * starcoder has joined the chat * indexing that works for all --- src/transformers/generation/utils.py | 14 ++++++++++++-- tests/generation/test_utils.py | 4 ++-- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 0f0191fb14..8c8a67fa5c 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -4221,6 +4221,9 @@ class GenerationMixin: # keep track of which sequences are already finished unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) + # other auxiliary variables + max_len = stopping_criteria[0].max_length + this_peer_finished = False # used by synced_gpus only while True: if synced_gpus: @@ -4235,7 +4238,7 @@ class GenerationMixin: # Assistant: main logic start cur_len = input_ids.shape[-1] - max_len = stopping_criteria[0].max_length + assistant_kv_indexing = 0 if "bloom" not in assistant_model.__class__.__name__.lower() else 1 # 1. Forecast next N tokens using the assistant model. This `for` block can be replaced with a # `.generate()` call if we decide to add `past_key_values` as a possible output of generate, as we @@ -4244,7 +4247,7 @@ class GenerationMixin: for _ in range(int(assistant_model.max_assistant_tokens)): # 1.1. use the assistant model to obtain the next candidate logits if "assistant_past_key_values" in model_kwargs: - prev_seq_len = model_kwargs["assistant_past_key_values"][0][0].shape[2] + prev_seq_len = model_kwargs["assistant_past_key_values"][0][assistant_kv_indexing].shape[-2] # `new_token_len` can be 1 or 2 (next token in assistant + last token picked by the larger model) new_token_len = candidate_input_ids.shape[1] - prev_seq_len assist_inputs = candidate_input_ids[:, -new_token_len:] @@ -4505,6 +4508,13 @@ def _crop_past_key_values(model, past_key_values, maximum_length): ) ) past_key_values = tuple(new_past) + elif "gptbigcode" in model.__class__.__name__.lower(): # gptbigcode is too + if model.config.multi_query: + for idx in range(len(past_key_values)): + past_key_values[idx] = past_key_values[idx][:, :maximum_length, :] + else: + for idx in range(len(past_key_values)): + past_key_values[idx] = past_key_values[idx][:, :, :maximum_length, :] else: for idx in range(len(past_key_values)): new_past.append( diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 3b96f2b2bd..70de057d5f 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1473,7 +1473,7 @@ class GenerationTesterMixin: # may fix in the future: the following models fail with assisted decoding, and need model-specific fixes if any( model_name in model_class.__name__.lower() - for model_name in ["bigbirdpegasus", "gptbigcode", "led", "mega", "speech2text", "git", "prophetnet"] + for model_name in ["bigbirdpegasus", "led", "mega", "speech2text", "git", "prophetnet"] ): return @@ -1529,7 +1529,7 @@ class GenerationTesterMixin: # may fix in the future: the following models fail with assisted decoding, and need model-specific fixes if any( model_name in model_class.__name__.lower() - for model_name in ["bigbirdpegasus", "gptbigcode", "led", "mega", "speech2text", "git", "prophetnet"] + for model_name in ["bigbirdpegasus", "led", "mega", "speech2text", "git", "prophetnet"] ): return