Generate: assisted generation with sample (take 2) (#22949)
* temperature controls speed
This commit is contained in:
@@ -333,15 +333,16 @@ This guide illustrates the main parameters that enable various decoding strategi
|
|||||||
[`generate`] method, which gives you even further control over the [`generate`] method's behavior.
|
[`generate`] method, which gives you even further control over the [`generate`] method's behavior.
|
||||||
For the complete list of the available parameters, refer to the [API documentation](./main_classes/text_generation.mdx).
|
For the complete list of the available parameters, refer to the [API documentation](./main_classes/text_generation.mdx).
|
||||||
|
|
||||||
### Assisted Generation
|
### Assisted Decoding
|
||||||
|
|
||||||
Assisted generation is a modification of the decoding strategies above that uses an assistant model with the same
|
Assisted decoding is a modification of the decoding strategies above that uses an assistant model with the same
|
||||||
tokenizer (ideally a much smaller model) to speed up the decoding process. Currently only assisted greedy search is
|
tokenizer (ideally a much smaller model) to greedily generate a few candidate tokens. The main model then validates
|
||||||
supported, and doesn't support batched inputs.
|
the candidate tokens in a single forward pass, which speeds up the decoding process. Currently, only greedy search
|
||||||
|
and sampling are supported with assisted decoding, and batched inputs are not supported.
|
||||||
|
|
||||||
<!-- TODO: add link to the blog post about assisted generation when it exists -->
|
<!-- TODO: add link to the blog post about assisted decoding when it exists -->
|
||||||
|
|
||||||
To enable assisted generation, set the `assistant_model` argument with a model.
|
To enable assisted decoding, set the `assistant_model` argument with a model.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
|
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
@@ -359,3 +360,25 @@ To enable assisted generation, set the `assistant_model` argument with a model.
|
|||||||
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||||
['Alice and Bob are sitting in a bar. Alice is drinking a beer and Bob is drinking a']
|
['Alice and Bob are sitting in a bar. Alice is drinking a beer and Bob is drinking a']
|
||||||
```
|
```
|
||||||
|
|
||||||
|
When using assisted decoding with sampling methods, you can use the `temperature` argument to control the randomness
|
||||||
|
just like in multinomial sampling. However, in assisted decoding, reducing the temperature will help improve latency.
|
||||||
|
|
||||||
|
<!-- TODO: link the blog post again to explain why the tradeoff exists -->
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
|
||||||
|
>>> prompt = "Alice and Bob"
|
||||||
|
>>> checkpoint = "EleutherAI/pythia-1.4b-deduped"
|
||||||
|
>>> assistant_checkpoint = "EleutherAI/pythia-160m-deduped"
|
||||||
|
|
||||||
|
>>> tokenizer = AutoTokenizer.from_pretrained(checkpoint)
|
||||||
|
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||||
|
|
||||||
|
>>> model = AutoModelForCausalLM.from_pretrained(checkpoint)
|
||||||
|
>>> assistant_model = AutoModelForCausalLM.from_pretrained(assistant_checkpoint)
|
||||||
|
>>> outputs = model.generate(**inputs, assistant_model=assistant_model, do_sample=True, temperature=0.5)
|
||||||
|
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||||
|
["Alice and Bob are sitting on the sofa. Alice says, 'I'm going to my room"]
|
||||||
|
```
|
||||||
|
|||||||
@@ -54,8 +54,10 @@ class GenerationConfig(PushToHubMixin):
|
|||||||
`num_beams>1` and `num_beam_groups>1`
|
`num_beams>1` and `num_beam_groups>1`
|
||||||
- *constrained beam-search decoding* by calling [`~generation.GenerationMixin.constrained_beam_search`], if
|
- *constrained beam-search decoding* by calling [`~generation.GenerationMixin.constrained_beam_search`], if
|
||||||
`constraints!=None` or `force_words_ids!=None`
|
`constraints!=None` or `force_words_ids!=None`
|
||||||
|
- *assisted decoding* by calling [`~generation.GenerationMixin.assisted_decoding`], if
|
||||||
|
`assistant_model` is passed to `.generate()`
|
||||||
|
|
||||||
You do not need to call any of the above methods directly. Pass custom parameter values to 'generate'. To learn
|
You do not need to call any of the above methods directly. Pass custom parameter values to '.generate()'. To learn
|
||||||
more about decoding strategies refer to the [text generation strategies guide](../generation_strategies).
|
more about decoding strategies refer to the [text generation strategies guide](../generation_strategies).
|
||||||
|
|
||||||
Arg:
|
Arg:
|
||||||
|
|||||||
@@ -492,7 +492,7 @@ class GenerationMixin:
|
|||||||
|
|
||||||
def prepare_inputs_for_generation(self, *args, **kwargs):
|
def prepare_inputs_for_generation(self, *args, **kwargs):
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"A model class needs to define a `prepare_inputs_for_generation` method in order to use `generate`."
|
"A model class needs to define a `prepare_inputs_for_generation` method in order to use `.generate()`."
|
||||||
)
|
)
|
||||||
|
|
||||||
def _prepare_model_inputs(
|
def _prepare_model_inputs(
|
||||||
@@ -962,10 +962,10 @@ class GenerationMixin:
|
|||||||
object_type = "stopping criteria" if isinstance(custom, StoppingCriteria) else "logits processor"
|
object_type = "stopping criteria" if isinstance(custom, StoppingCriteria) else "logits processor"
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"A custom {object_type} of type {type(custom)} with values {custom} has been passed to"
|
f"A custom {object_type} of type {type(custom)} with values {custom} has been passed to"
|
||||||
f" `generate`, but it has already been created with the values {default}. {default} has been"
|
f" `.generate()`, but it has already been created with the values {default}. {default} has been"
|
||||||
" created by passing the corresponding arguments to generate or by the model's config default"
|
" created by passing the corresponding arguments to generate or by the model's config default"
|
||||||
f" values. If you just want to change the default values of {object_type} consider passing"
|
f" values. If you just want to change the default values of {object_type} consider passing"
|
||||||
f" them as arguments to `generate` instead of using a custom {object_type}."
|
f" them as arguments to `.generate()` instead of using a custom {object_type}."
|
||||||
)
|
)
|
||||||
default_list.extend(custom_list)
|
default_list.extend(custom_list)
|
||||||
return default_list
|
return default_list
|
||||||
@@ -1418,14 +1418,14 @@ class GenerationMixin:
|
|||||||
and not is_constraint_gen_mode
|
and not is_constraint_gen_mode
|
||||||
and not is_contrastive_search_gen_mode
|
and not is_contrastive_search_gen_mode
|
||||||
)
|
)
|
||||||
is_assisted_greedy_gen_mode = False
|
is_assisted_gen_mode = False
|
||||||
if assistant_model is not None:
|
if assistant_model is not None:
|
||||||
if not is_greedy_gen_mode:
|
if not (is_greedy_gen_mode or is_sample_gen_mode):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"You've set `assistant_model`, which triggers assisted generation. Currently, assisted generation "
|
"You've set `assistant_model`, which triggers assisted generate. Currently, assisted generate "
|
||||||
"is only supported with Greedy Search."
|
"is only supported with Greedy Search and Sample."
|
||||||
)
|
)
|
||||||
is_assisted_greedy_gen_mode = True
|
is_assisted_gen_mode = True
|
||||||
|
|
||||||
if generation_config.num_beam_groups > generation_config.num_beams:
|
if generation_config.num_beam_groups > generation_config.num_beams:
|
||||||
raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`")
|
raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`")
|
||||||
@@ -1464,16 +1464,16 @@ class GenerationMixin:
|
|||||||
generation_config=generation_config, stopping_criteria=stopping_criteria
|
generation_config=generation_config, stopping_criteria=stopping_criteria
|
||||||
)
|
)
|
||||||
# 10. go into different generation modes
|
# 10. go into different generation modes
|
||||||
if is_assisted_greedy_gen_mode:
|
if is_assisted_gen_mode:
|
||||||
if generation_config.num_return_sequences > 1:
|
if generation_config.num_return_sequences > 1:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"num_return_sequences has to be 1 when doing assisted greedy search, "
|
"num_return_sequences has to be 1 when doing assisted generate, "
|
||||||
f"but is {generation_config.num_return_sequences}."
|
f"but is {generation_config.num_return_sequences}."
|
||||||
)
|
)
|
||||||
if batch_size > 1:
|
if batch_size > 1:
|
||||||
raise ValueError("Assisted generation is only supported for batch_size = 1")
|
raise ValueError("assisted generate is only supported for batch_size = 1")
|
||||||
if not model_kwargs["use_cache"]:
|
if not model_kwargs["use_cache"]:
|
||||||
raise ValueError("Assisted generation requires `use_cache=True`")
|
raise ValueError("assisted generate requires `use_cache=True`")
|
||||||
|
|
||||||
# 11. If the assistant model is an encoder-decoder, prepare its encoder outputs
|
# 11. If the assistant model is an encoder-decoder, prepare its encoder outputs
|
||||||
if assistant_model.config.is_encoder_decoder:
|
if assistant_model.config.is_encoder_decoder:
|
||||||
@@ -1486,11 +1486,13 @@ class GenerationMixin:
|
|||||||
)
|
)
|
||||||
model_kwargs["assistant_encoder_outputs"] = assistant_model_kwargs["encoder_outputs"]
|
model_kwargs["assistant_encoder_outputs"] = assistant_model_kwargs["encoder_outputs"]
|
||||||
|
|
||||||
# 12. run assisted greedy search
|
# 12. run assisted generate
|
||||||
return self.assisted_greedy_search(
|
return self.assisted_decoding(
|
||||||
input_ids,
|
input_ids,
|
||||||
assistant_model=assistant_model,
|
assistant_model=assistant_model,
|
||||||
|
do_sample=generation_config.do_sample,
|
||||||
logits_processor=logits_processor,
|
logits_processor=logits_processor,
|
||||||
|
logits_warper=self._get_logits_warper(generation_config) if generation_config.do_sample else None,
|
||||||
stopping_criteria=stopping_criteria,
|
stopping_criteria=stopping_criteria,
|
||||||
pad_token_id=generation_config.pad_token_id,
|
pad_token_id=generation_config.pad_token_id,
|
||||||
eos_token_id=generation_config.eos_token_id,
|
eos_token_id=generation_config.eos_token_id,
|
||||||
@@ -4059,11 +4061,13 @@ class GenerationMixin:
|
|||||||
else:
|
else:
|
||||||
return sequence_outputs["sequences"]
|
return sequence_outputs["sequences"]
|
||||||
|
|
||||||
def assisted_greedy_search(
|
def assisted_decoding(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.LongTensor,
|
input_ids: torch.LongTensor,
|
||||||
assistant_model: "PreTrainedModel",
|
assistant_model: "PreTrainedModel",
|
||||||
|
do_sample: bool = False,
|
||||||
logits_processor: Optional[LogitsProcessorList] = None,
|
logits_processor: Optional[LogitsProcessorList] = None,
|
||||||
|
logits_warper: Optional[LogitsProcessorList] = None,
|
||||||
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
||||||
pad_token_id: Optional[int] = None,
|
pad_token_id: Optional[int] = None,
|
||||||
eos_token_id: Optional[Union[int, List[int]]] = None,
|
eos_token_id: Optional[Union[int, List[int]]] = None,
|
||||||
@@ -4076,12 +4080,13 @@ class GenerationMixin:
|
|||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
Generates sequences of token ids for models with a language modeling head using **greedy decoding**, assisted
|
Generates sequences of token ids for models with a language modeling head using **greedy decoding** or
|
||||||
by a smaller model. Can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
|
**sample** (depending on `do_sample`), assisted by a smaller model. Can be used for text-decoder, text-to-text,
|
||||||
|
speech-to-text, and vision-to-text models.
|
||||||
|
|
||||||
<Tip warning={true}>
|
<Tip warning={true}>
|
||||||
|
|
||||||
In most cases, you do not need to call [`~generation.GenerationMixin.assisted_greedy_search`] directly. Use
|
In most cases, you do not need to call [`~generation.GenerationMixin.assisted_decoding`] directly. Use
|
||||||
generate() instead. For an overview of generation strategies and code examples, check the [following
|
generate() instead. For an overview of generation strategies and code examples, check the [following
|
||||||
guide](../generation_strategies).
|
guide](../generation_strategies).
|
||||||
|
|
||||||
@@ -4095,9 +4100,15 @@ class GenerationMixin:
|
|||||||
same tokenizer. The acceleration is achieved when forecasting candidate tokens with the assistant model
|
same tokenizer. The acceleration is achieved when forecasting candidate tokens with the assistant model
|
||||||
is much faster than running generation with the model you're calling generate from. As such, the
|
is much faster than running generation with the model you're calling generate from. As such, the
|
||||||
assistant model should be much smaller.
|
assistant model should be much smaller.
|
||||||
|
do_sample (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether or not to use sampling; use greedy decoding otherwise.
|
||||||
logits_processor (`LogitsProcessorList`, *optional*):
|
logits_processor (`LogitsProcessorList`, *optional*):
|
||||||
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
|
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
|
||||||
used to modify the prediction scores of the language modeling head applied at each generation step.
|
used to modify the prediction scores of the language modeling head applied at each generation step.
|
||||||
|
logits_warper (`LogitsProcessorList`, *optional*):
|
||||||
|
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
|
||||||
|
to warp the prediction score distribution of the language modeling head applied before multinomial
|
||||||
|
sampling at each generation step.
|
||||||
stopping_criteria (`StoppingCriteriaList`, *optional*):
|
stopping_criteria (`StoppingCriteriaList`, *optional*):
|
||||||
An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
|
An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
|
||||||
used to tell if the generation loop should stop.
|
used to tell if the generation loop should stop.
|
||||||
@@ -4157,7 +4168,7 @@ class GenerationMixin:
|
|||||||
... ]
|
... ]
|
||||||
... )
|
... )
|
||||||
>>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])
|
>>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])
|
||||||
>>> outputs = model.assisted_greedy_search(
|
>>> outputs = model.assisted_decoding(
|
||||||
... input_ids,
|
... input_ids,
|
||||||
... assistant_model=assistant_model,
|
... assistant_model=assistant_model,
|
||||||
... logits_processor=logits_processor,
|
... logits_processor=logits_processor,
|
||||||
@@ -4166,13 +4177,14 @@ class GenerationMixin:
|
|||||||
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||||
["It might be possible to get a better understanding of the nature of the problem, but it's not"]
|
["It might be possible to get a better understanding of the nature of the problem, but it's not"]
|
||||||
```"""
|
```"""
|
||||||
# NOTE: the code here is copy/paste from greedy search, except when clearly stated in the comments
|
# NOTE: the code here is copy/paste from greedy search/sample, except when clearly stated in the comments
|
||||||
# Assistant: initialize assistant-related variables
|
# Assistant: initialize assistant-related variables
|
||||||
if not hasattr(assistant_model, "max_assistant_tokens"):
|
if not hasattr(assistant_model, "max_assistant_tokens"):
|
||||||
assistant_model.max_assistant_tokens = 5 # this value, which will be updated, persists across calls
|
assistant_model.max_assistant_tokens = 5 # this value, which will be updated, persists across calls
|
||||||
|
|
||||||
# init values
|
# init values
|
||||||
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
|
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
|
||||||
|
logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
|
||||||
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
|
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
|
||||||
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
|
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
|
||||||
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
|
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
|
||||||
@@ -4285,6 +4297,8 @@ class GenerationMixin:
|
|||||||
|
|
||||||
# 2. Use the original model to obtain the next token logits given the candidate sequence. We obtain
|
# 2. Use the original model to obtain the next token logits given the candidate sequence. We obtain
|
||||||
# `candidate_length + 1` relevant logits from this process (see step 7 on why the +1)
|
# `candidate_length + 1` relevant logits from this process (see step 7 on why the +1)
|
||||||
|
|
||||||
|
# 2.1. Run a forward pass on the candidate sequence
|
||||||
if "past_key_values" in model_kwargs:
|
if "past_key_values" in model_kwargs:
|
||||||
og_model_attn = torch.ones_like(candidate_input_ids)
|
og_model_attn = torch.ones_like(candidate_input_ids)
|
||||||
og_model_input_ids = candidate_input_ids[:, -candidate_length - 1 :]
|
og_model_input_ids = candidate_input_ids[:, -candidate_length - 1 :]
|
||||||
@@ -4320,17 +4334,28 @@ class GenerationMixin:
|
|||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 3. Obtain the argmax from the original model logits.
|
# 2.2. Process the new logits
|
||||||
new_logits = outputs.logits[:, -candidate_length - 1 :] # excludes the input prompt if present
|
new_logits = outputs.logits[:, -candidate_length - 1 :] # excludes the input prompt if present
|
||||||
if len(logits_processor) > 0:
|
if len(logits_processor) > 0:
|
||||||
for i in range(candidate_length):
|
for i in range(candidate_length):
|
||||||
new_logits[:, i, :] = logits_processor(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :])
|
new_logits[:, i, :] = logits_processor(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :])
|
||||||
max_logits = new_logits.argmax(dim=-1)[:, -candidate_length - 1 : -1]
|
if len(logits_warper) > 0:
|
||||||
|
for i in range(candidate_length):
|
||||||
|
new_logits[:, i, :] = logits_warper(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :])
|
||||||
|
|
||||||
|
# 3. Obtain the next tokens from the original model logits. If `do_sample` is True, use multinomial
|
||||||
|
# sampling, otherwise use argmax.
|
||||||
|
if do_sample:
|
||||||
|
probs = new_logits[:, -candidate_length - 1 :, :].softmax(dim=-1)
|
||||||
|
sampled_tokens = torch.multinomial(probs[0, :, :], num_samples=1).squeeze(1)[None, :]
|
||||||
|
next_tokens = sampled_tokens[:, :-1]
|
||||||
|
else:
|
||||||
|
next_tokens = new_logits[:, -candidate_length - 1 : -1, :].argmax(dim=-1)
|
||||||
|
|
||||||
# 4. Compare the argmax from the original model logits with the assistant forecasted tokens. We can keep
|
# 4. Compare the tokens selected from the original model logits (argmax or sampled) with the assistant forecasted tokens. We can keep
|
||||||
# the assistant forecasted tokens until the first mismatch, or until the max length is reached.
|
# the assistant forecasted tokens until the first mismatch, or until the max length is reached.
|
||||||
candidate_new_tokens = candidate_input_ids[:, -candidate_length:]
|
candidate_new_tokens = candidate_input_ids[:, -candidate_length:]
|
||||||
n_matches = ((~(candidate_new_tokens == max_logits)).cumsum(dim=-1) < 1).sum()
|
n_matches = ((~(candidate_new_tokens == next_tokens)).cumsum(dim=-1) < 1).sum()
|
||||||
|
|
||||||
# 5. Adjust the max number of assistant tokens to use in the next iteration. This is a simple heuristic,
|
# 5. Adjust the max number of assistant tokens to use in the next iteration. This is a simple heuristic,
|
||||||
# probably can be improved -- we want to balance the benefits of getting assistant tokens correct with the
|
# probably can be improved -- we want to balance the benefits of getting assistant tokens correct with the
|
||||||
@@ -4360,12 +4385,17 @@ class GenerationMixin:
|
|||||||
next_token_scores = new_logits[:, n_matches, :]
|
next_token_scores = new_logits[:, n_matches, :]
|
||||||
|
|
||||||
# 7. Use the set of logits after the last matching assistant token to obtain the next token. Note that,
|
# 7. Use the set of logits after the last matching assistant token to obtain the next token. Note that,
|
||||||
# because of this step, assisted greedy search reduces to a normal greedy search if there is no match.
|
# because of this step, assisted decoding reduces to a normal greedy search/sample if there is no
|
||||||
|
# match.
|
||||||
|
if do_sample:
|
||||||
|
probs = probs[:, n_matches, :]
|
||||||
|
next_tokens = sampled_tokens[:, n_matches]
|
||||||
|
else:
|
||||||
next_tokens = torch.argmax(next_token_scores, dim=-1)
|
next_tokens = torch.argmax(next_token_scores, dim=-1)
|
||||||
|
|
||||||
# Assistant: main logic end; Compared to greedy search, the following (redundant) blocks were removed
|
# Assistant: main logic end; Compared to greedy search/sample, the following (redundant) blocks were
|
||||||
# below: (1) model input preparation; (2) model forward pass; (3) score preparation; (4) model cache
|
# removed below: (1) model input preparation; (2) model forward pass; (3) score preparation; (4) model
|
||||||
# update.
|
# cache update.
|
||||||
|
|
||||||
if synced_gpus and this_peer_finished:
|
if synced_gpus and this_peer_finished:
|
||||||
continue # don't waste resources running the code we don't need
|
continue # don't waste resources running the code we don't need
|
||||||
@@ -4378,20 +4408,18 @@ class GenerationMixin:
|
|||||||
|
|
||||||
if "past_key_values" not in model_kwargs:
|
if "past_key_values" not in model_kwargs:
|
||||||
last_matching_idx = new_cur_len - 1
|
last_matching_idx = new_cur_len - 1
|
||||||
prompt_length = cur_len
|
|
||||||
else:
|
else:
|
||||||
last_matching_idx = n_matches
|
last_matching_idx = n_matches
|
||||||
prompt_length = 0
|
|
||||||
|
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
if self.config.is_encoder_decoder:
|
if self.config.is_encoder_decoder:
|
||||||
cross_attentions = _split_model_outputs(
|
cross_attentions = _split_model_outputs(
|
||||||
cross_attentions, outputs.cross_attentions, prompt_length, last_matching_idx
|
cross_attentions, outputs.cross_attentions, cur_len, last_matching_idx
|
||||||
)
|
)
|
||||||
decoder_attentions = _split_model_outputs(
|
decoder_attentions = _split_model_outputs(
|
||||||
decoder_attentions,
|
decoder_attentions,
|
||||||
outputs.decoder_attentions,
|
outputs.decoder_attentions,
|
||||||
prompt_length,
|
cur_len,
|
||||||
last_matching_idx,
|
last_matching_idx,
|
||||||
is_decoder_attention=True,
|
is_decoder_attention=True,
|
||||||
)
|
)
|
||||||
@@ -4399,18 +4427,18 @@ class GenerationMixin:
|
|||||||
decoder_attentions = _split_model_outputs(
|
decoder_attentions = _split_model_outputs(
|
||||||
decoder_attentions,
|
decoder_attentions,
|
||||||
outputs.attentions,
|
outputs.attentions,
|
||||||
prompt_length,
|
cur_len,
|
||||||
last_matching_idx,
|
last_matching_idx,
|
||||||
is_decoder_attention=True,
|
is_decoder_attention=True,
|
||||||
)
|
)
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
if self.config.is_encoder_decoder:
|
if self.config.is_encoder_decoder:
|
||||||
decoder_hidden_states = _split_model_outputs(
|
decoder_hidden_states = _split_model_outputs(
|
||||||
decoder_hidden_states, outputs.decoder_hidden_states, prompt_length, last_matching_idx
|
decoder_hidden_states, outputs.decoder_hidden_states, cur_len, last_matching_idx
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
decoder_hidden_states = _split_model_outputs(
|
decoder_hidden_states = _split_model_outputs(
|
||||||
decoder_hidden_states, outputs.hidden_states, prompt_length, last_matching_idx
|
decoder_hidden_states, outputs.hidden_states, cur_len, last_matching_idx
|
||||||
)
|
)
|
||||||
|
|
||||||
# finished sentences should have their next token be a padding token
|
# finished sentences should have their next token be a padding token
|
||||||
@@ -4503,24 +4531,26 @@ def _crop_past_key_values(model, past_key_values, maximum_length):
|
|||||||
return past_key_values
|
return past_key_values
|
||||||
|
|
||||||
|
|
||||||
def _split_model_outputs(outputs, new_outputs, prompt_length, last_matching_idx, is_decoder_attention=False):
|
def _split_model_outputs(outputs, new_outputs, previous_cur_len, last_matching_idx, is_decoder_attention=False):
|
||||||
"""
|
"""
|
||||||
Given the (decoder/cross attentions)/(decoder hidden states) for multiple generated tokens, splits it into a tuple
|
Given the (decoder/cross attentions)/(decoder hidden states) for multiple generated tokens, splits it into a tuple
|
||||||
where each member corresponds to a single generated token.
|
where each member corresponds to a single generated token.
|
||||||
"""
|
"""
|
||||||
# Retrocompatibility: in our generation functions, the first iteration includes the attention/hidden states for the
|
# Retrocompatibility: in our generation functions, the first iteration includes the attention/hidden states for the
|
||||||
# prompt.
|
# prompt.
|
||||||
if prompt_length > 0:
|
if len(outputs) == 0:
|
||||||
new_tuple = ()
|
new_tuple = ()
|
||||||
for layer in new_outputs:
|
for layer in new_outputs:
|
||||||
last_dim_size = prompt_length if is_decoder_attention else layer.shape[-1]
|
last_dim_size = previous_cur_len if is_decoder_attention else layer.shape[-1]
|
||||||
new_tuple += (layer[..., :prompt_length, :last_dim_size],)
|
new_tuple += (layer[..., :previous_cur_len, :last_dim_size],)
|
||||||
outputs += (new_tuple,)
|
outputs += (new_tuple,)
|
||||||
|
last_matching_idx -= previous_cur_len
|
||||||
|
previous_cur_len += 1
|
||||||
|
|
||||||
for i in range(prompt_length, last_matching_idx + 1):
|
for i in range(last_matching_idx + 1):
|
||||||
new_tuple = ()
|
new_tuple = ()
|
||||||
for layer in new_outputs:
|
for layer in new_outputs:
|
||||||
last_dim_size = i + 1 if is_decoder_attention else layer.shape[-1]
|
last_dim_size = previous_cur_len + i if is_decoder_attention else layer.shape[-1]
|
||||||
new_tuple += (layer[..., i : i + 1, :last_dim_size],)
|
new_tuple += (layer[..., i : i + 1, :last_dim_size],)
|
||||||
outputs += (new_tuple,)
|
outputs += (new_tuple,)
|
||||||
return outputs
|
return outputs
|
||||||
|
|||||||
@@ -1457,22 +1457,22 @@ class GenerationTesterMixin:
|
|||||||
for output in (output_contrastive, output_generate):
|
for output in (output_contrastive, output_generate):
|
||||||
self._check_outputs(output, input_ids, model.config, use_cache=True)
|
self._check_outputs(output, input_ids, model.config, use_cache=True)
|
||||||
|
|
||||||
def test_assisted_greedy_search_matches_greedy_search(self):
|
def test_assisted_decoding_matches_greedy_search(self):
|
||||||
# This test ensures that the assisted generation does not introduce output changes over greedy search.
|
# This test ensures that the assisted generation does not introduce output changes over greedy search.
|
||||||
# It breaks the pattern in the tests above, for multiple reasons:
|
# It breaks the pattern in the tests above, for multiple reasons:
|
||||||
# - assisted_greedy_search, contrarily to the other methods, can't be called on its own (e.g. needs to
|
# - assisted_decoding, contrarily to the other methods, can't be called on its own (e.g. needs to
|
||||||
# prepare the assistant encoder outputs in the main generate body);
|
# prepare the assistant encoder outputs in the main generate body);
|
||||||
# - assisted_greedy_search does not support `use_cache = False`
|
# - assisted_decoding does not support `use_cache = False`
|
||||||
# - assisted_greedy_search does not support `batch_size > 1`
|
# - assisted_decoding does not support `batch_size > 1`
|
||||||
|
|
||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
# won't fix: FSMT and Reformer have a different cache variable type (and format).
|
# won't fix: FSMT and Reformer have a different cache variable type (and format).
|
||||||
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
|
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
|
||||||
return
|
return
|
||||||
# may fix in the future: the following models fail to pass this test, and need model-specific fixes
|
# may fix in the future: the following models fail with assisted decoding, and need model-specific fixes
|
||||||
if any(
|
if any(
|
||||||
model_name in model_class.__name__.lower()
|
model_name in model_class.__name__.lower()
|
||||||
for model_name in ["bigbirdpegasus", "gptbigcode", "led", "mega", "speech2text"]
|
for model_name in ["bigbirdpegasus", "gptbigcode", "led", "mega", "speech2text", "git", "prophetnet"]
|
||||||
):
|
):
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -1517,6 +1517,46 @@ class GenerationTesterMixin:
|
|||||||
for output in (output_greedy, output_assisted):
|
for output in (output_greedy, output_assisted):
|
||||||
self._check_outputs(output, input_ids, model.config, use_cache=True)
|
self._check_outputs(output, input_ids, model.config, use_cache=True)
|
||||||
|
|
||||||
|
def test_assisted_decoding_sample(self):
|
||||||
|
# Seeded assisted decoding will not match sample for the same seed, as there are >1 sampling steps per output
|
||||||
|
# token. As such, this test only checks that the output format is correct.
|
||||||
|
|
||||||
|
for model_class in self.all_generative_model_classes:
|
||||||
|
# won't fix: FSMT and Reformer have a different cache variable type (and format).
|
||||||
|
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
|
||||||
|
return
|
||||||
|
# may fix in the future: the following models fail with assisted decoding, and need model-specific fixes
|
||||||
|
if any(
|
||||||
|
model_name in model_class.__name__.lower()
|
||||||
|
for model_name in ["bigbirdpegasus", "gptbigcode", "led", "mega", "speech2text", "git", "prophetnet"]
|
||||||
|
):
|
||||||
|
return
|
||||||
|
|
||||||
|
# enable cache
|
||||||
|
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=1)
|
||||||
|
|
||||||
|
# NOTE: assisted generation only works with cache on at the moment.
|
||||||
|
if not hasattr(config, "use_cache"):
|
||||||
|
return
|
||||||
|
|
||||||
|
config.use_cache = True
|
||||||
|
config.is_decoder = True
|
||||||
|
model = model_class(config).to(torch_device).eval()
|
||||||
|
output_assisted = model.generate(
|
||||||
|
input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
max_length=max_length,
|
||||||
|
num_beams=1,
|
||||||
|
do_sample=True,
|
||||||
|
assistant_model=model, # triggers assisted decoding
|
||||||
|
output_scores=True,
|
||||||
|
output_hidden_states=True,
|
||||||
|
output_attentions=True,
|
||||||
|
return_dict_in_generate=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._check_outputs(output_assisted, input_ids, model.config, use_cache=True)
|
||||||
|
|
||||||
def test_generate_with_head_masking(self):
|
def test_generate_with_head_masking(self):
|
||||||
"""Test designed for encoder-decoder models to ensure the attention head masking is used."""
|
"""Test designed for encoder-decoder models to ensure the attention head masking is used."""
|
||||||
attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
|
attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
|
||||||
|
|||||||
Reference in New Issue
Block a user