Generate: starcoder 🤜 🤛 assisted generation (#23182)
* starcoder has joined the chat
* indexing that works for all
This commit is contained in:
@@ -4221,6 +4221,9 @@ class GenerationMixin:
|
|||||||
# keep track of which sequences are already finished
|
# keep track of which sequences are already finished
|
||||||
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
|
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
|
||||||
|
|
||||||
|
# other auxiliary variables
|
||||||
|
max_len = stopping_criteria[0].max_length
|
||||||
|
|
||||||
this_peer_finished = False # used by synced_gpus only
|
this_peer_finished = False # used by synced_gpus only
|
||||||
while True:
|
while True:
|
||||||
if synced_gpus:
|
if synced_gpus:
|
||||||
@@ -4235,7 +4238,7 @@ class GenerationMixin:
|
|||||||
|
|
||||||
# Assistant: main logic start
|
# Assistant: main logic start
|
||||||
cur_len = input_ids.shape[-1]
|
cur_len = input_ids.shape[-1]
|
||||||
max_len = stopping_criteria[0].max_length
|
assistant_kv_indexing = 0 if "bloom" not in assistant_model.__class__.__name__.lower() else 1
|
||||||
|
|
||||||
# 1. Forecast next N tokens using the assistant model. This `for` block can be replaced with a
|
# 1. Forecast next N tokens using the assistant model. This `for` block can be replaced with a
|
||||||
# `.generate()` call if we decide to add `past_key_values` as a possible output of generate, as we
|
# `.generate()` call if we decide to add `past_key_values` as a possible output of generate, as we
|
||||||
@@ -4244,7 +4247,7 @@ class GenerationMixin:
|
|||||||
for _ in range(int(assistant_model.max_assistant_tokens)):
|
for _ in range(int(assistant_model.max_assistant_tokens)):
|
||||||
# 1.1. use the assistant model to obtain the next candidate logits
|
# 1.1. use the assistant model to obtain the next candidate logits
|
||||||
if "assistant_past_key_values" in model_kwargs:
|
if "assistant_past_key_values" in model_kwargs:
|
||||||
prev_seq_len = model_kwargs["assistant_past_key_values"][0][0].shape[2]
|
prev_seq_len = model_kwargs["assistant_past_key_values"][0][assistant_kv_indexing].shape[-2]
|
||||||
# `new_token_len` can be 1 or 2 (next token in assistant + last token picked by the larger model)
|
# `new_token_len` can be 1 or 2 (next token in assistant + last token picked by the larger model)
|
||||||
new_token_len = candidate_input_ids.shape[1] - prev_seq_len
|
new_token_len = candidate_input_ids.shape[1] - prev_seq_len
|
||||||
assist_inputs = candidate_input_ids[:, -new_token_len:]
|
assist_inputs = candidate_input_ids[:, -new_token_len:]
|
||||||
@@ -4505,6 +4508,13 @@ def _crop_past_key_values(model, past_key_values, maximum_length):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
past_key_values = tuple(new_past)
|
past_key_values = tuple(new_past)
|
||||||
|
elif "gptbigcode" in model.__class__.__name__.lower(): # gptbigcode is too
|
||||||
|
if model.config.multi_query:
|
||||||
|
for idx in range(len(past_key_values)):
|
||||||
|
past_key_values[idx] = past_key_values[idx][:, :maximum_length, :]
|
||||||
|
else:
|
||||||
|
for idx in range(len(past_key_values)):
|
||||||
|
past_key_values[idx] = past_key_values[idx][:, :, :maximum_length, :]
|
||||||
else:
|
else:
|
||||||
for idx in range(len(past_key_values)):
|
for idx in range(len(past_key_values)):
|
||||||
new_past.append(
|
new_past.append(
|
||||||
|
|||||||
@@ -1473,7 +1473,7 @@ class GenerationTesterMixin:
|
|||||||
# may fix in the future: the following models fail with assisted decoding, and need model-specific fixes
|
# may fix in the future: the following models fail with assisted decoding, and need model-specific fixes
|
||||||
if any(
|
if any(
|
||||||
model_name in model_class.__name__.lower()
|
model_name in model_class.__name__.lower()
|
||||||
for model_name in ["bigbirdpegasus", "gptbigcode", "led", "mega", "speech2text", "git", "prophetnet"]
|
for model_name in ["bigbirdpegasus", "led", "mega", "speech2text", "git", "prophetnet"]
|
||||||
):
|
):
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -1529,7 +1529,7 @@ class GenerationTesterMixin:
|
|||||||
# may fix in the future: the following models fail with assisted decoding, and need model-specific fixes
|
# may fix in the future: the following models fail with assisted decoding, and need model-specific fixes
|
||||||
if any(
|
if any(
|
||||||
model_name in model_class.__name__.lower()
|
model_name in model_class.__name__.lower()
|
||||||
for model_name in ["bigbirdpegasus", "gptbigcode", "led", "mega", "speech2text", "git", "prophetnet"]
|
for model_name in ["bigbirdpegasus", "led", "mega", "speech2text", "git", "prophetnet"]
|
||||||
):
|
):
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user