Generate: Add assisted generation (#22211)
* working mvp * remove breakpoint * fix commit * standardize outputs * tmp commit * tests almost ready * tmp commit * skip a few models * Add streaming; Docs and examples * document limitations * PR commits * Amy PR comments
This commit is contained in:
@@ -332,3 +332,30 @@ The groups are selected to ensure they are distinct enough compared to the other
|
|||||||
This guide illustrates the main parameters that enable various decoding strategies. More advanced parameters exist for the
|
This guide illustrates the main parameters that enable various decoding strategies. More advanced parameters exist for the
|
||||||
[`generate`] method, which gives you even further control over the [`generate`] method's behavior.
|
[`generate`] method, which gives you even further control over the [`generate`] method's behavior.
|
||||||
For the complete list of the available parameters, refer to the [API documentation](./main_classes/text_generation.mdx).
|
For the complete list of the available parameters, refer to the [API documentation](./main_classes/text_generation.mdx).
|
||||||
|
|
||||||
|
### Assisted Generation
|
||||||
|
|
||||||
|
Assisted generation is a modification of the decoding strategies above that uses an assistant model with the same
|
||||||
|
tokenizer (ideally a much smaller model) to speed up the decoding process. Currently only assisted greedy search is
|
||||||
|
supported, and doesn't support batched inputs.
|
||||||
|
|
||||||
|
<!-- TODO: add link to the blog post about assisted generation when it exists -->
|
||||||
|
|
||||||
|
To enable assisted generation, set the `assistant_model` argument with a model.
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
|
||||||
|
>>> prompt = "Alice and Bob"
|
||||||
|
>>> checkpoint = "EleutherAI/pythia-1.4b-deduped"
|
||||||
|
>>> assistant_checkpoint = "EleutherAI/pythia-160m-deduped"
|
||||||
|
|
||||||
|
>>> tokenizer = AutoTokenizer.from_pretrained(checkpoint)
|
||||||
|
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||||
|
|
||||||
|
>>> model = AutoModelForCausalLM.from_pretrained(checkpoint)
|
||||||
|
>>> assistant_model = AutoModelForCausalLM.from_pretrained(assistant_checkpoint)
|
||||||
|
>>> outputs = model.generate(**inputs, assistant_model=assistant_model)
|
||||||
|
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||||
|
['Alice and Bob are sitting in a bar. Alice is drinking a beer and Bob is drinking a']
|
||||||
|
```
|
||||||
|
|||||||
@@ -73,9 +73,9 @@ from .stopping_criteria import (
|
|||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from ..modeling_utils import PreTrainedModel
|
||||||
from .streamers import BaseStreamer
|
from .streamers import BaseStreamer
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -1146,6 +1146,7 @@ class GenerationMixin:
|
|||||||
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
||||||
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
|
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
|
||||||
synced_gpus: Optional[bool] = None,
|
synced_gpus: Optional[bool] = None,
|
||||||
|
assistant_model: Optional["PreTrainedModel"] = None,
|
||||||
streamer: Optional["BaseStreamer"] = None,
|
streamer: Optional["BaseStreamer"] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Union[GenerateOutput, torch.LongTensor]:
|
) -> Union[GenerateOutput, torch.LongTensor]:
|
||||||
@@ -1196,10 +1197,14 @@ class GenerationMixin:
|
|||||||
Whether to continue running the while loop until max_length. Unless overridden this flag will be set to
|
Whether to continue running the while loop until max_length. Unless overridden this flag will be set to
|
||||||
`True` under DeepSpeed ZeRO Stage 3 multiple GPUs environment to avoid hanging if one GPU finished
|
`True` under DeepSpeed ZeRO Stage 3 multiple GPUs environment to avoid hanging if one GPU finished
|
||||||
generating before other GPUs. Otherwise it'll be set to `False`.
|
generating before other GPUs. Otherwise it'll be set to `False`.
|
||||||
|
assistant_model (`PreTrainedModel`, *optional*):
|
||||||
|
An assistant model that can be used to accelerate generation. The assistant model must have the exact
|
||||||
|
same tokenizer. The acceleration is achieved when forecasting candidate tokens with the assistent model
|
||||||
|
is much faster than running generation with the model you're calling generate from. As such, the
|
||||||
|
assistant model should be much smaller.
|
||||||
streamer (`BaseStreamer`, *optional*):
|
streamer (`BaseStreamer`, *optional*):
|
||||||
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
|
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
|
||||||
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
|
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
|
||||||
|
|
||||||
kwargs:
|
kwargs:
|
||||||
Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
|
Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
|
||||||
forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
|
forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
|
||||||
@@ -1411,6 +1416,14 @@ class GenerationMixin:
|
|||||||
and not is_constraint_gen_mode
|
and not is_constraint_gen_mode
|
||||||
and not is_contrastive_search_gen_mode
|
and not is_contrastive_search_gen_mode
|
||||||
)
|
)
|
||||||
|
is_assisted_greedy_gen_mode = False
|
||||||
|
if assistant_model is not None:
|
||||||
|
if not is_greedy_gen_mode:
|
||||||
|
raise ValueError(
|
||||||
|
"You've set `assistant_model`, which triggers assisted generation. Currently, assisted generation "
|
||||||
|
"is only supported with Greedy Search."
|
||||||
|
)
|
||||||
|
is_assisted_greedy_gen_mode = True
|
||||||
|
|
||||||
if generation_config.num_beam_groups > generation_config.num_beams:
|
if generation_config.num_beam_groups > generation_config.num_beams:
|
||||||
raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`")
|
raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`")
|
||||||
@@ -1449,11 +1462,47 @@ class GenerationMixin:
|
|||||||
generation_config=generation_config, stopping_criteria=stopping_criteria
|
generation_config=generation_config, stopping_criteria=stopping_criteria
|
||||||
)
|
)
|
||||||
# 10. go into different generation modes
|
# 10. go into different generation modes
|
||||||
|
if is_assisted_greedy_gen_mode:
|
||||||
|
if generation_config.num_return_sequences > 1:
|
||||||
|
raise ValueError(
|
||||||
|
"num_return_sequences has to be 1 when doing assisted greedy search, "
|
||||||
|
f"but is {generation_config.num_return_sequences}."
|
||||||
|
)
|
||||||
|
if batch_size > 1:
|
||||||
|
raise ValueError("Assisted generation is only supported for batch_size = 1")
|
||||||
|
if not model_kwargs["use_cache"]:
|
||||||
|
raise ValueError("Assisted generation requires `use_cache=True`")
|
||||||
|
|
||||||
|
# 11. If the assistant model is an encoder-decoder, prepare its encoder outputs
|
||||||
|
if assistant_model.config.is_encoder_decoder:
|
||||||
|
assistant_model_kwargs = copy.deepcopy(model_kwargs)
|
||||||
|
inputs_tensor, model_input_name, assistant_model_kwargs = assistant_model._prepare_model_inputs(
|
||||||
|
inputs_tensor, assistant_model.generation_config.bos_token_id, assistant_model_kwargs
|
||||||
|
)
|
||||||
|
assistant_model_kwargs = assistant_model._prepare_encoder_decoder_kwargs_for_generation(
|
||||||
|
inputs_tensor, assistant_model_kwargs, model_input_name
|
||||||
|
)
|
||||||
|
model_kwargs["assistant_encoder_outputs"] = assistant_model_kwargs["encoder_outputs"]
|
||||||
|
|
||||||
|
# 12. run assisted greedy search
|
||||||
|
return self.assisted_greedy_search(
|
||||||
|
input_ids,
|
||||||
|
assistant_model=assistant_model,
|
||||||
|
logits_processor=logits_processor,
|
||||||
|
stopping_criteria=stopping_criteria,
|
||||||
|
pad_token_id=generation_config.pad_token_id,
|
||||||
|
eos_token_id=generation_config.eos_token_id,
|
||||||
|
output_scores=generation_config.output_scores,
|
||||||
|
return_dict_in_generate=generation_config.return_dict_in_generate,
|
||||||
|
synced_gpus=synced_gpus,
|
||||||
|
streamer=streamer,
|
||||||
|
**model_kwargs,
|
||||||
|
)
|
||||||
if is_greedy_gen_mode:
|
if is_greedy_gen_mode:
|
||||||
if generation_config.num_return_sequences > 1:
|
if generation_config.num_return_sequences > 1:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing"
|
"num_return_sequences has to be 1 when doing greedy search, "
|
||||||
" greedy search."
|
f"but is {generation_config.num_return_sequences}."
|
||||||
)
|
)
|
||||||
|
|
||||||
# 11. run greedy search
|
# 11. run greedy search
|
||||||
@@ -1473,9 +1522,11 @@ class GenerationMixin:
|
|||||||
elif is_contrastive_search_gen_mode:
|
elif is_contrastive_search_gen_mode:
|
||||||
if generation_config.num_return_sequences > 1:
|
if generation_config.num_return_sequences > 1:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing"
|
"num_return_sequences has to be 1 when doing contrastive search, "
|
||||||
" contrastive search."
|
f"but is {generation_config.num_return_sequences}."
|
||||||
)
|
)
|
||||||
|
if not model_kwargs["use_cache"]:
|
||||||
|
raise ValueError("Contrastive search requires `use_cache=True`")
|
||||||
|
|
||||||
return self.contrastive_search(
|
return self.contrastive_search(
|
||||||
input_ids,
|
input_ids,
|
||||||
@@ -1745,7 +1796,7 @@ class GenerationMixin:
|
|||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
output_scores: Optional[bool] = None,
|
output_scores: Optional[bool] = None,
|
||||||
return_dict_in_generate: Optional[bool] = None,
|
return_dict_in_generate: Optional[bool] = None,
|
||||||
synced_gpus: Optional[bool] = False,
|
synced_gpus: bool = False,
|
||||||
streamer: Optional["BaseStreamer"] = None,
|
streamer: Optional["BaseStreamer"] = None,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
) -> Union[ContrastiveSearchOutput, torch.LongTensor]:
|
) -> Union[ContrastiveSearchOutput, torch.LongTensor]:
|
||||||
@@ -2112,7 +2163,7 @@ class GenerationMixin:
|
|||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
output_scores: Optional[bool] = None,
|
output_scores: Optional[bool] = None,
|
||||||
return_dict_in_generate: Optional[bool] = None,
|
return_dict_in_generate: Optional[bool] = None,
|
||||||
synced_gpus: Optional[bool] = False,
|
synced_gpus: bool = False,
|
||||||
streamer: Optional["BaseStreamer"] = None,
|
streamer: Optional["BaseStreamer"] = None,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
) -> Union[GreedySearchOutput, torch.LongTensor]:
|
) -> Union[GreedySearchOutput, torch.LongTensor]:
|
||||||
@@ -2368,7 +2419,7 @@ class GenerationMixin:
|
|||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
output_scores: Optional[bool] = None,
|
output_scores: Optional[bool] = None,
|
||||||
return_dict_in_generate: Optional[bool] = None,
|
return_dict_in_generate: Optional[bool] = None,
|
||||||
synced_gpus: Optional[bool] = False,
|
synced_gpus: bool = False,
|
||||||
streamer: Optional["BaseStreamer"] = None,
|
streamer: Optional["BaseStreamer"] = None,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
) -> Union[SampleOutput, torch.LongTensor]:
|
) -> Union[SampleOutput, torch.LongTensor]:
|
||||||
@@ -2646,7 +2697,7 @@ class GenerationMixin:
|
|||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
output_scores: Optional[bool] = None,
|
output_scores: Optional[bool] = None,
|
||||||
return_dict_in_generate: Optional[bool] = None,
|
return_dict_in_generate: Optional[bool] = None,
|
||||||
synced_gpus: Optional[bool] = False,
|
synced_gpus: bool = False,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
) -> Union[BeamSearchOutput, torch.LongTensor]:
|
) -> Union[BeamSearchOutput, torch.LongTensor]:
|
||||||
r"""
|
r"""
|
||||||
@@ -2970,7 +3021,7 @@ class GenerationMixin:
|
|||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
output_scores: Optional[bool] = None,
|
output_scores: Optional[bool] = None,
|
||||||
return_dict_in_generate: Optional[bool] = None,
|
return_dict_in_generate: Optional[bool] = None,
|
||||||
synced_gpus: Optional[bool] = False,
|
synced_gpus: bool = False,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
) -> Union[BeamSampleOutput, torch.LongTensor]:
|
) -> Union[BeamSampleOutput, torch.LongTensor]:
|
||||||
r"""
|
r"""
|
||||||
@@ -3302,7 +3353,7 @@ class GenerationMixin:
|
|||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
output_scores: Optional[bool] = None,
|
output_scores: Optional[bool] = None,
|
||||||
return_dict_in_generate: Optional[bool] = None,
|
return_dict_in_generate: Optional[bool] = None,
|
||||||
synced_gpus: Optional[bool] = False,
|
synced_gpus: bool = False,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
@@ -3994,6 +4045,468 @@ class GenerationMixin:
|
|||||||
else:
|
else:
|
||||||
return sequence_outputs["sequences"]
|
return sequence_outputs["sequences"]
|
||||||
|
|
||||||
|
def assisted_greedy_search(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor,
|
||||||
|
assistant_model: "PreTrainedModel",
|
||||||
|
logits_processor: Optional[LogitsProcessorList] = None,
|
||||||
|
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
||||||
|
pad_token_id: Optional[int] = None,
|
||||||
|
eos_token_id: Optional[Union[int, List[int]]] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
output_scores: Optional[bool] = None,
|
||||||
|
return_dict_in_generate: Optional[bool] = None,
|
||||||
|
synced_gpus: bool = False,
|
||||||
|
streamer: Optional["BaseStreamer"] = None,
|
||||||
|
**model_kwargs,
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
Generates sequences of token ids for models with a language modeling head using **greedy decoding**, assisted
|
||||||
|
by a smaller model. Can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
|
||||||
|
|
||||||
|
<Tip warning={true}>
|
||||||
|
|
||||||
|
In most cases, you do not need to call [`~generation.GenerationMixin.assisted_greedy_search`] directly. Use
|
||||||
|
generate() instead. For an overview of generation strategies and code examples, check the [following
|
||||||
|
guide](../generation_strategies).
|
||||||
|
|
||||||
|
</Tip>
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||||
|
The sequence used as a prompt for the generation.
|
||||||
|
assistant_model (`PreTrainedModel`, *optional*):
|
||||||
|
An assistant model that can be used to accelerate generation. The assistant model must have the exact
|
||||||
|
same tokenizer. The acceleration is achieved when forecasting candidate tokens with the assistent model
|
||||||
|
is much faster than running generation with the model you're calling generate from. As such, the
|
||||||
|
assistant model should be much smaller.
|
||||||
|
logits_processor (`LogitsProcessorList`, *optional*):
|
||||||
|
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
|
||||||
|
used to modify the prediction scores of the language modeling head applied at each generation step.
|
||||||
|
stopping_criteria (`StoppingCriteriaList`, *optional*):
|
||||||
|
An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
|
||||||
|
used to tell if the generation loop should stop.
|
||||||
|
pad_token_id (`int`, *optional*):
|
||||||
|
The id of the *padding* token.
|
||||||
|
eos_token_id (`Union[int, List[int]]`, *optional*):
|
||||||
|
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
|
||||||
|
output_attentions (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||||
|
returned tensors for more details.
|
||||||
|
output_hidden_states (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
|
||||||
|
for more details.
|
||||||
|
output_scores (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
|
||||||
|
return_dict_in_generate (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||||
|
synced_gpus (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
|
||||||
|
streamer (`BaseStreamer`, *optional*):
|
||||||
|
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
|
||||||
|
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
|
||||||
|
model_kwargs:
|
||||||
|
Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
|
||||||
|
If model is an encoder-decoder model the kwargs should include `encoder_outputs`.
|
||||||
|
|
||||||
|
Return:
|
||||||
|
[`~generation.GreedySearchDecoderOnlyOutput`], [`~generation.GreedySearchEncoderDecoderOutput`] or
|
||||||
|
`torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
|
||||||
|
[`~generation.GreedySearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
|
||||||
|
`return_dict_in_generate=True` or a [`~generation.GreedySearchEncoderDecoderOutput`] if
|
||||||
|
`model.config.is_encoder_decoder=True`.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import (
|
||||||
|
... AutoTokenizer,
|
||||||
|
... AutoModelForCausalLM,
|
||||||
|
... LogitsProcessorList,
|
||||||
|
... MinLengthLogitsProcessor,
|
||||||
|
... StoppingCriteriaList,
|
||||||
|
... MaxLengthCriteria,
|
||||||
|
... )
|
||||||
|
|
||||||
|
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
||||||
|
>>> model = AutoModelForCausalLM.from_pretrained("gpt2")
|
||||||
|
>>> assistant_model = AutoModelForCausalLM.from_pretrained("distilgpt2")
|
||||||
|
>>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token
|
||||||
|
>>> model.generation_config.pad_token_id = model.generation_config.eos_token_id
|
||||||
|
>>> input_prompt = "It might be possible to"
|
||||||
|
>>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids
|
||||||
|
>>> # instantiate logits processors
|
||||||
|
>>> logits_processor = LogitsProcessorList(
|
||||||
|
... [
|
||||||
|
... MinLengthLogitsProcessor(10, eos_token_id=model.generation_config.eos_token_id),
|
||||||
|
... ]
|
||||||
|
... )
|
||||||
|
>>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])
|
||||||
|
>>> outputs = model.assisted_greedy_search(
|
||||||
|
... input_ids,
|
||||||
|
... assistant_model=assistant_model,
|
||||||
|
... logits_processor=logits_processor,
|
||||||
|
... stopping_criteria=stopping_criteria,
|
||||||
|
... )
|
||||||
|
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||||
|
["It might be possible to get a better understanding of the nature of the problem, but it's not"]
|
||||||
|
```"""
|
||||||
|
# NOTE: the code here is copy/paste from greedy search, except when clearly stated in the comments
|
||||||
|
# Assistant: initialize assistant-related variables
|
||||||
|
if not hasattr(assistant_model, "max_assistant_tokens"):
|
||||||
|
assistant_model.max_assistant_tokens = 5 # this value, which will be updated, persists across calls
|
||||||
|
|
||||||
|
# init values
|
||||||
|
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
|
||||||
|
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
|
||||||
|
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
|
||||||
|
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
|
||||||
|
if eos_token_id is not None and pad_token_id is None:
|
||||||
|
raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
|
||||||
|
if isinstance(eos_token_id, int):
|
||||||
|
eos_token_id = [eos_token_id]
|
||||||
|
eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
|
||||||
|
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
|
||||||
|
output_attentions = (
|
||||||
|
output_attentions if output_attentions is not None else self.generation_config.output_attentions
|
||||||
|
)
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
|
||||||
|
)
|
||||||
|
return_dict_in_generate = (
|
||||||
|
return_dict_in_generate
|
||||||
|
if return_dict_in_generate is not None
|
||||||
|
else self.generation_config.return_dict_in_generate
|
||||||
|
)
|
||||||
|
|
||||||
|
# init attention / hidden states / scores tuples
|
||||||
|
scores = () if (return_dict_in_generate and output_scores) else None
|
||||||
|
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
|
||||||
|
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
|
||||||
|
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
|
||||||
|
|
||||||
|
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
|
||||||
|
if return_dict_in_generate and self.config.is_encoder_decoder:
|
||||||
|
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
|
||||||
|
encoder_hidden_states = (
|
||||||
|
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
|
||||||
|
)
|
||||||
|
|
||||||
|
# keep track of which sequences are already finished
|
||||||
|
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
|
||||||
|
|
||||||
|
this_peer_finished = False # used by synced_gpus only
|
||||||
|
while True:
|
||||||
|
if synced_gpus:
|
||||||
|
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
|
||||||
|
# The following logic allows an early break if all peers finished generating their sequence
|
||||||
|
this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
|
||||||
|
# send 0.0 if we finished, 1.0 otherwise
|
||||||
|
dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
|
||||||
|
# did all peers finish? the reduced sum will be 0.0 then
|
||||||
|
if this_peer_finished_flag.item() == 0.0:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Assistant: main logic start
|
||||||
|
cur_len = input_ids.shape[-1]
|
||||||
|
max_len = stopping_criteria[0].max_length
|
||||||
|
|
||||||
|
# 1. Forecast next N tokens using the assistant model. This `for` block can be replaced with a
|
||||||
|
# `.generate()` call if we decide to add `past_key_values` as a possible output of generate, as we
|
||||||
|
# need access to the assistant cache to secure strong speedups.
|
||||||
|
candidate_input_ids = input_ids
|
||||||
|
for _ in range(int(assistant_model.max_assistant_tokens)):
|
||||||
|
# 1.1. use the assistant model to obtain the next candidate logits
|
||||||
|
if "assistant_past_key_values" in model_kwargs:
|
||||||
|
prev_seq_len = model_kwargs["assistant_past_key_values"][0][0].shape[2]
|
||||||
|
# `new_token_len` can be 1 or 2 (next token in assistant + last token picked by the larger model)
|
||||||
|
new_token_len = candidate_input_ids.shape[1] - prev_seq_len
|
||||||
|
tmp_inputs = candidate_input_ids[:, -new_token_len:]
|
||||||
|
tmp_attn = torch.ones_like(candidate_input_ids)
|
||||||
|
# TODO (joao): make it compatible with models that use unconventional fwd pass logic, like blip2
|
||||||
|
if assistant_model.config.is_encoder_decoder:
|
||||||
|
assistant_model_outputs = assistant_model(
|
||||||
|
decoder_input_ids=tmp_inputs,
|
||||||
|
decoder_attention_mask=tmp_attn,
|
||||||
|
past_key_values=model_kwargs["assistant_past_key_values"],
|
||||||
|
encoder_outputs=model_kwargs["assistant_encoder_outputs"],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assistant_model_outputs = assistant_model(
|
||||||
|
tmp_inputs,
|
||||||
|
attention_mask=tmp_attn,
|
||||||
|
past_key_values=model_kwargs["assistant_past_key_values"],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if assistant_model.config.is_encoder_decoder:
|
||||||
|
assistant_model_outputs = assistant_model(
|
||||||
|
decoder_input_ids=candidate_input_ids,
|
||||||
|
encoder_outputs=model_kwargs["assistant_encoder_outputs"],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assistant_model_outputs = assistant_model(candidate_input_ids)
|
||||||
|
|
||||||
|
# 1.2. greedily select the next candidate token
|
||||||
|
model_kwargs["assistant_past_key_values"] = assistant_model_outputs.past_key_values
|
||||||
|
if len(logits_processor) > 0:
|
||||||
|
assistant_model_outputs.logits[:, -1, :] = logits_processor(
|
||||||
|
candidate_input_ids, assistant_model_outputs.logits[:, -1, :]
|
||||||
|
)
|
||||||
|
new_token = assistant_model_outputs.logits[:, -1, :].argmax(dim=-1)
|
||||||
|
candidate_input_ids = torch.cat((candidate_input_ids, new_token[:, None]), dim=-1)
|
||||||
|
|
||||||
|
# 1.3. stop assistant generation on EOS
|
||||||
|
if eos_token_id_tensor is not None:
|
||||||
|
last_assistant_token_is_eos = new_token.tile(eos_token_id_tensor.shape[0], 1)
|
||||||
|
last_assistant_token_is_eos = (
|
||||||
|
~last_assistant_token_is_eos.ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0).bool()
|
||||||
|
)
|
||||||
|
if last_assistant_token_is_eos:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
last_assistant_token_is_eos = False
|
||||||
|
|
||||||
|
candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1]
|
||||||
|
|
||||||
|
# 2. Use the original model to obtain the next token logits given the candidate sequence. We obtain
|
||||||
|
# `candidate_length + 1` relevant logits from this process (see step 7 on why the +1)
|
||||||
|
if "past_key_values" in model_kwargs:
|
||||||
|
og_model_attn = torch.ones_like(candidate_input_ids)
|
||||||
|
og_model_input_ids = candidate_input_ids[:, -candidate_length - 1 :]
|
||||||
|
if self.config.is_encoder_decoder:
|
||||||
|
outputs = self(
|
||||||
|
decoder_input_ids=og_model_input_ids,
|
||||||
|
decoder_attention_mask=og_model_attn,
|
||||||
|
past_key_values=model_kwargs["past_key_values"],
|
||||||
|
encoder_outputs=model_kwargs["encoder_outputs"],
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
outputs = self(
|
||||||
|
og_model_input_ids,
|
||||||
|
attention_mask=og_model_attn,
|
||||||
|
past_key_values=model_kwargs["past_key_values"],
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if self.config.is_encoder_decoder:
|
||||||
|
outputs = self(
|
||||||
|
decoder_input_ids=candidate_input_ids,
|
||||||
|
encoder_outputs=model_kwargs["encoder_outputs"],
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
outputs = self(
|
||||||
|
candidate_input_ids,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. Obtain the argmax from the original model logits.
|
||||||
|
new_logits = outputs.logits[:, -candidate_length - 1 :] # excludes the input prompt if present
|
||||||
|
if len(logits_processor) > 0:
|
||||||
|
for i in range(candidate_length):
|
||||||
|
new_logits[:, i, :] = logits_processor(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :])
|
||||||
|
max_logits = new_logits.argmax(dim=-1)[:, -candidate_length - 1 : -1]
|
||||||
|
|
||||||
|
# 4. Compare the argmax from the original model logits with the assistant forecasted tokens. We can keep
|
||||||
|
# the assistant forecasted tokens until the first mismatch, or until the max length is reached.
|
||||||
|
candidate_new_tokens = candidate_input_ids[:, -candidate_length:]
|
||||||
|
n_matches = ((~(candidate_new_tokens == max_logits)).cumsum(dim=-1) < 1).sum()
|
||||||
|
|
||||||
|
# 5. Adjust the max number of assistant tokens to use in the next iteration. This is a simple heuristic,
|
||||||
|
# probably can be improved -- we want to balance the benefits of getting assistant tokens correct with the
|
||||||
|
# cost of forecasting incorrect assistant tokens.
|
||||||
|
if n_matches == int(assistant_model.max_assistant_tokens):
|
||||||
|
assistant_model.max_assistant_tokens += 2.0
|
||||||
|
else:
|
||||||
|
assistant_model.max_assistant_tokens = max(1.0, assistant_model.max_assistant_tokens - 1.0)
|
||||||
|
|
||||||
|
# 6. Update variables according to the number of matching assistant tokens.
|
||||||
|
# 6.1. Ensure we don't generate beyond max_len or an EOS token (remember: one token will be added below)
|
||||||
|
n_matches = min(n_matches, max_len - cur_len - 1)
|
||||||
|
if last_assistant_token_is_eos and n_matches == candidate_length:
|
||||||
|
n_matches -= 1
|
||||||
|
input_ids = candidate_input_ids[:, 0 : cur_len + n_matches]
|
||||||
|
new_cur_len = input_ids.shape[-1]
|
||||||
|
if streamer is not None:
|
||||||
|
streamer.put(candidate_input_ids[:, cur_len : cur_len + n_matches])
|
||||||
|
|
||||||
|
# 6.2. Discard past key values relative to unused assistant tokens
|
||||||
|
outputs.past_key_values = _crop_past_key_values(self, outputs.past_key_values, new_cur_len)
|
||||||
|
model_kwargs["assistant_past_key_values"] = _crop_past_key_values(
|
||||||
|
assistant_model, model_kwargs["assistant_past_key_values"], new_cur_len
|
||||||
|
)
|
||||||
|
|
||||||
|
# 6.3. Extract the logits for the next token
|
||||||
|
next_token_scores = new_logits[:, n_matches, :]
|
||||||
|
|
||||||
|
# 7. Use the set of logits after the last matching assistant token to obtain the next token. Note that,
|
||||||
|
# because of this step, assisted greedy search reduces to a normal greedy search if there is no match.
|
||||||
|
next_tokens = torch.argmax(next_token_scores, dim=-1)
|
||||||
|
|
||||||
|
# Assistant: main logic end; Compared to greedy search, the following (redundant) blocks were removed
|
||||||
|
# below: (1) model input preparation; (2) model forward pass; (3) score preparation; (4) model cache
|
||||||
|
# update.
|
||||||
|
|
||||||
|
if synced_gpus and this_peer_finished:
|
||||||
|
continue # don't waste resources running the code we don't need
|
||||||
|
|
||||||
|
# Store scores, attentions and hidden_states when required
|
||||||
|
# Assistant: modified to append one tuple element per token, as in the other generation methods.
|
||||||
|
if return_dict_in_generate:
|
||||||
|
if output_scores:
|
||||||
|
scores += tuple(new_logits[:, i, :] for i in range(n_matches + 1))
|
||||||
|
|
||||||
|
if "past_key_values" not in model_kwargs:
|
||||||
|
last_matching_idx = new_cur_len - 1
|
||||||
|
prompt_length = cur_len
|
||||||
|
else:
|
||||||
|
last_matching_idx = n_matches
|
||||||
|
prompt_length = 0
|
||||||
|
|
||||||
|
if output_attentions:
|
||||||
|
if self.config.is_encoder_decoder:
|
||||||
|
cross_attentions = _split_model_outputs(
|
||||||
|
cross_attentions, outputs.cross_attentions, prompt_length, last_matching_idx
|
||||||
|
)
|
||||||
|
decoder_attentions = _split_model_outputs(
|
||||||
|
decoder_attentions,
|
||||||
|
outputs.decoder_attentions,
|
||||||
|
prompt_length,
|
||||||
|
last_matching_idx,
|
||||||
|
is_decoder_attention=True,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
decoder_attentions = _split_model_outputs(
|
||||||
|
decoder_attentions,
|
||||||
|
outputs.attentions,
|
||||||
|
prompt_length,
|
||||||
|
last_matching_idx,
|
||||||
|
is_decoder_attention=True,
|
||||||
|
)
|
||||||
|
if output_hidden_states:
|
||||||
|
if self.config.is_encoder_decoder:
|
||||||
|
decoder_hidden_states = _split_model_outputs(
|
||||||
|
decoder_hidden_states, outputs.decoder_hidden_states, prompt_length, last_matching_idx
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
decoder_hidden_states = _split_model_outputs(
|
||||||
|
decoder_hidden_states, outputs.hidden_states, prompt_length, last_matching_idx
|
||||||
|
)
|
||||||
|
|
||||||
|
# finished sentences should have their next token be a padding token
|
||||||
|
if eos_token_id is not None:
|
||||||
|
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
|
||||||
|
|
||||||
|
# update generated ids, model inputs, and length for next step
|
||||||
|
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
|
||||||
|
if streamer is not None:
|
||||||
|
streamer.put(next_tokens.cpu())
|
||||||
|
|
||||||
|
model_kwargs = self._update_model_kwargs_for_generation(
|
||||||
|
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
|
||||||
|
)
|
||||||
|
|
||||||
|
# if eos_token was found in one sentence, set sentence to finished
|
||||||
|
if eos_token_id_tensor is not None:
|
||||||
|
unfinished_sequences = unfinished_sequences.mul(
|
||||||
|
next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
|
||||||
|
)
|
||||||
|
|
||||||
|
# stop when each sentence is finished, or if we exceed the maximum length
|
||||||
|
if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
|
||||||
|
if not synced_gpus:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
this_peer_finished = True
|
||||||
|
|
||||||
|
if streamer is not None:
|
||||||
|
streamer.end()
|
||||||
|
|
||||||
|
if return_dict_in_generate:
|
||||||
|
if self.config.is_encoder_decoder:
|
||||||
|
return GreedySearchEncoderDecoderOutput(
|
||||||
|
sequences=input_ids,
|
||||||
|
scores=scores,
|
||||||
|
encoder_attentions=encoder_attentions,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
decoder_attentions=decoder_attentions,
|
||||||
|
cross_attentions=cross_attentions,
|
||||||
|
decoder_hidden_states=decoder_hidden_states,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return GreedySearchDecoderOnlyOutput(
|
||||||
|
sequences=input_ids,
|
||||||
|
scores=scores,
|
||||||
|
attentions=decoder_attentions,
|
||||||
|
hidden_states=decoder_hidden_states,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return input_ids
|
||||||
|
|
||||||
|
|
||||||
|
def _crop_past_key_values(model, past_key_values, maximum_length):
|
||||||
|
"""Crops the past key values up to a certain maximum length."""
|
||||||
|
new_past = []
|
||||||
|
if model.config.is_encoder_decoder:
|
||||||
|
for idx in range(len(past_key_values)):
|
||||||
|
new_past.append(
|
||||||
|
(
|
||||||
|
past_key_values[idx][0][:, :, :maximum_length, :],
|
||||||
|
past_key_values[idx][1][:, :, :maximum_length, :],
|
||||||
|
past_key_values[idx][2],
|
||||||
|
past_key_values[idx][3],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
past_key_values = tuple(new_past)
|
||||||
|
elif "bloom" in model.__class__.__name__.lower(): # bloom is special
|
||||||
|
for idx in range(len(past_key_values)):
|
||||||
|
new_past.append(
|
||||||
|
(
|
||||||
|
past_key_values[idx][0][:, :, :maximum_length],
|
||||||
|
past_key_values[idx][1][:, :maximum_length, :],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
past_key_values = tuple(new_past)
|
||||||
|
else:
|
||||||
|
for idx in range(len(past_key_values)):
|
||||||
|
new_past.append(
|
||||||
|
(
|
||||||
|
past_key_values[idx][0][:, :, :maximum_length, :],
|
||||||
|
past_key_values[idx][1][:, :, :maximum_length, :],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
past_key_values = tuple(new_past)
|
||||||
|
return past_key_values
|
||||||
|
|
||||||
|
|
||||||
|
def _split_model_outputs(outputs, new_outputs, prompt_length, last_matching_idx, is_decoder_attention=False):
|
||||||
|
"""
|
||||||
|
Given the (decoder/cross attentions)/(decoder hidden states) for multiple generated tokens, splits it into a tuple
|
||||||
|
where each member corresponds to a single generated token.
|
||||||
|
"""
|
||||||
|
# Retrocompatibility: in our generation functions, the first iteration includes the attention/hidden states for the
|
||||||
|
# prompt.
|
||||||
|
if prompt_length > 0:
|
||||||
|
new_tuple = ()
|
||||||
|
for layer in new_outputs:
|
||||||
|
last_dim_size = prompt_length if is_decoder_attention else layer.shape[-1]
|
||||||
|
new_tuple += (layer[..., :prompt_length, :last_dim_size],)
|
||||||
|
outputs += (new_tuple,)
|
||||||
|
|
||||||
|
for i in range(prompt_length, last_matching_idx + 1):
|
||||||
|
new_tuple = ()
|
||||||
|
for layer in new_outputs:
|
||||||
|
last_dim_size = i + 1 if is_decoder_attention else layer.shape[-1]
|
||||||
|
new_tuple += (layer[..., i : i + 1, :last_dim_size],)
|
||||||
|
outputs += (new_tuple,)
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
def top_k_top_p_filtering(
|
def top_k_top_p_filtering(
|
||||||
logits: torch.FloatTensor,
|
logits: torch.FloatTensor,
|
||||||
|
|||||||
@@ -79,14 +79,13 @@ class GenerationTesterMixin:
|
|||||||
all_generative_model_classes = ()
|
all_generative_model_classes = ()
|
||||||
input_name = "input_ids"
|
input_name = "input_ids"
|
||||||
|
|
||||||
def _get_input_ids_and_config(self):
|
def _get_input_ids_and_config(self, batch_size=2):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
input_ids = inputs_dict[self.input_name]
|
input_ids = inputs_dict[self.input_name]
|
||||||
|
|
||||||
# cut to half length & take max batch_size 3
|
# cut to half length & take max batch_size 3
|
||||||
max_batch_size = 2
|
|
||||||
sequence_length = input_ids.shape[-1] // 2
|
sequence_length = input_ids.shape[-1] // 2
|
||||||
input_ids = input_ids[:max_batch_size, :sequence_length]
|
input_ids = input_ids[:batch_size, :sequence_length]
|
||||||
|
|
||||||
# generate max 3 tokens
|
# generate max 3 tokens
|
||||||
max_length = input_ids.shape[-1] + 3
|
max_length = input_ids.shape[-1] + 3
|
||||||
@@ -99,7 +98,7 @@ class GenerationTesterMixin:
|
|||||||
if "transfoxl" in config.__class__.__name__.lower():
|
if "transfoxl" in config.__class__.__name__.lower():
|
||||||
attention_mask = None
|
attention_mask = None
|
||||||
else:
|
else:
|
||||||
attention_mask = torch.ones_like(input_ids, dtype=torch.long)[:max_batch_size, :sequence_length]
|
attention_mask = torch.ones_like(input_ids, dtype=torch.long)[:batch_size, :sequence_length]
|
||||||
|
|
||||||
return config, input_ids, attention_mask, max_length
|
return config, input_ids, attention_mask, max_length
|
||||||
|
|
||||||
@@ -1458,6 +1457,66 @@ class GenerationTesterMixin:
|
|||||||
for output in (output_contrastive, output_generate):
|
for output in (output_contrastive, output_generate):
|
||||||
self._check_outputs(output, input_ids, model.config, use_cache=True)
|
self._check_outputs(output, input_ids, model.config, use_cache=True)
|
||||||
|
|
||||||
|
def test_assisted_greedy_search_matches_greedy_search(self):
|
||||||
|
# This test ensures that the assisted generation does not introduce output changes over greedy search.
|
||||||
|
# It breaks the pattern in the tests above, for multiple reasons:
|
||||||
|
# - assisted_greedy_search, contrarily to the other methods, can't be called on its own (e.g. needs to
|
||||||
|
# prepare the assistant encoder outputs in the main generate body);
|
||||||
|
# - assisted_greedy_search does not support `use_cache = False`
|
||||||
|
# - assisted_greedy_search does not support `batch_size > 1`
|
||||||
|
|
||||||
|
for model_class in self.all_generative_model_classes:
|
||||||
|
# won't fix: FSMT and Reformer have a different cache variable type (and format).
|
||||||
|
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
|
||||||
|
return
|
||||||
|
# may fix in the future: the following models fail to pass this test, and need model-specific fixes
|
||||||
|
if any(
|
||||||
|
model_name in model_class.__name__.lower()
|
||||||
|
for model_name in ["bigbirdpegasus", "gptbigcode", "led", "mega", "speech2text"]
|
||||||
|
):
|
||||||
|
return
|
||||||
|
|
||||||
|
# enable cache
|
||||||
|
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=1)
|
||||||
|
|
||||||
|
# NOTE: assisted generation only works with cache on at the moment.
|
||||||
|
if not hasattr(config, "use_cache"):
|
||||||
|
return
|
||||||
|
|
||||||
|
config.use_cache = True
|
||||||
|
config.is_decoder = True
|
||||||
|
model = model_class(config).to(torch_device).eval()
|
||||||
|
output_greedy = model.generate(
|
||||||
|
input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
max_length=max_length,
|
||||||
|
num_beams=1,
|
||||||
|
do_sample=False,
|
||||||
|
output_scores=True,
|
||||||
|
output_hidden_states=True,
|
||||||
|
output_attentions=True,
|
||||||
|
return_dict_in_generate=True,
|
||||||
|
)
|
||||||
|
# Note: with assisted generate, if the same model is used as assistant, then all assistant tokens will
|
||||||
|
# be correct
|
||||||
|
output_assisted = model.generate(
|
||||||
|
input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
max_length=max_length,
|
||||||
|
num_beams=1,
|
||||||
|
do_sample=False,
|
||||||
|
assistant_model=model,
|
||||||
|
output_scores=True,
|
||||||
|
output_hidden_states=True,
|
||||||
|
output_attentions=True,
|
||||||
|
return_dict_in_generate=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertListEqual(output_greedy.sequences.tolist(), output_assisted.sequences.tolist())
|
||||||
|
|
||||||
|
for output in (output_greedy, output_assisted):
|
||||||
|
self._check_outputs(output, input_ids, model.config, use_cache=True)
|
||||||
|
|
||||||
def test_generate_with_head_masking(self):
|
def test_generate_with_head_masking(self):
|
||||||
"""Test designed for encoder-decoder models to ensure the attention head masking is used."""
|
"""Test designed for encoder-decoder models to ensure the attention head masking is used."""
|
||||||
attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
|
attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
|
||||||
|
|||||||
@@ -280,7 +280,7 @@ class BigBirdPegasusModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
|
|||||||
|
|
||||||
# overwrite from GenerationTesterMixin to solve problem
|
# overwrite from GenerationTesterMixin to solve problem
|
||||||
# with conflicting random seeds
|
# with conflicting random seeds
|
||||||
def _get_input_ids_and_config(self):
|
def _get_input_ids_and_config(self, batch_size=2):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
config.attention_type = "original_full"
|
config.attention_type = "original_full"
|
||||||
|
|
||||||
@@ -288,10 +288,9 @@ class BigBirdPegasusModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
|
|||||||
attention_mask = torch.ones_like(input_ids, dtype=torch.long)
|
attention_mask = torch.ones_like(input_ids, dtype=torch.long)
|
||||||
|
|
||||||
# cut to half length & take max batch_size 3
|
# cut to half length & take max batch_size 3
|
||||||
max_batch_size = 2
|
|
||||||
sequence_length = input_ids.shape[-1] // 2
|
sequence_length = input_ids.shape[-1] // 2
|
||||||
input_ids = input_ids[:max_batch_size, :sequence_length]
|
input_ids = input_ids[:batch_size, :sequence_length]
|
||||||
attention_mask = attention_mask[:max_batch_size, :sequence_length]
|
attention_mask = attention_mask[:batch_size, :sequence_length]
|
||||||
|
|
||||||
# generate max 3 tokens
|
# generate max 3 tokens
|
||||||
max_length = input_ids.shape[-1] + 3
|
max_length = input_ids.shape[-1] + 3
|
||||||
|
|||||||
@@ -303,7 +303,7 @@ class TFWhisperModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestC
|
|||||||
input_ids = input_ids[:max_batch_size, :, :]
|
input_ids = input_ids[:max_batch_size, :, :]
|
||||||
|
|
||||||
# generate max 3 tokens
|
# generate max 3 tokens
|
||||||
max_length = input_ids.shape[-1] + 3
|
max_length = 4
|
||||||
if config.eos_token_id is not None and config.pad_token_id is None:
|
if config.eos_token_id is not None and config.pad_token_id is None:
|
||||||
# hack to allow generate for models such as GPT2 as is done in `generate()`
|
# hack to allow generate for models such as GPT2 as is done in `generate()`
|
||||||
config.pad_token_id = config.eos_token_id
|
config.pad_token_id = config.eos_token_id
|
||||||
|
|||||||
@@ -359,16 +359,15 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
self.model_tester.check_encoder_decoder_model_standalone(*config_and_inputs)
|
self.model_tester.check_encoder_decoder_model_standalone(*config_and_inputs)
|
||||||
|
|
||||||
def _get_input_ids_and_config(self):
|
def _get_input_ids_and_config(self, batch_size=3):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
input_ids = inputs_dict[self.input_name]
|
input_ids = inputs_dict[self.input_name]
|
||||||
|
|
||||||
# cut to half length & take max batch_size 3
|
# cut to half length & take max batch_size=batch_size
|
||||||
max_batch_size = 3
|
input_ids = input_ids[:batch_size, :, :]
|
||||||
input_ids = input_ids[:max_batch_size, :, :]
|
|
||||||
|
|
||||||
# generate max 3 tokens
|
# generate max 3 tokens
|
||||||
max_length = input_ids.shape[-1] + 3
|
max_length = 4
|
||||||
if config.eos_token_id is not None and config.pad_token_id is None:
|
if config.eos_token_id is not None and config.pad_token_id is None:
|
||||||
# hack to allow generate for models such as GPT2 as is done in `generate()`
|
# hack to allow generate for models such as GPT2 as is done in `generate()`
|
||||||
config.pad_token_id = config.eos_token_id
|
config.pad_token_id = config.eos_token_id
|
||||||
|
|||||||
Reference in New Issue
Block a user