Generate: assisted decoding now uses generate for the assistant (#28030)

generate refactor
This commit is contained in:
Joao Gante
2023-12-14 13:31:13 +00:00
committed by GitHub
parent dde6c427a1
commit 9e5c28c573
3 changed files with 116 additions and 120 deletions

View File

@@ -15,7 +15,7 @@
import copy import copy
import warnings import warnings
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import torch import torch
@@ -28,7 +28,7 @@ if TYPE_CHECKING:
class CandidateGenerator: class CandidateGenerator:
"""Abstract base class for all candidate generators that can be applied during assisted generation.""" """Abstract base class for all candidate generators that can be applied during assisted generation."""
def get_candidates(self, input_ids: torch.LongTensor) -> torch.LongTensor: def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
""" """
Fetches the candidates to be tried for the current input. Fetches the candidates to be tried for the current input.
@@ -37,8 +37,9 @@ class CandidateGenerator:
Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids) Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
Return: Return:
`torch.LongTensor` of shape `(num_candidates, candidate_length)`: The candidate sequences to be assessed by `torch.LongTensor` of shape `(batch_size, candidate_length)` containing the candidate sequences to be
the model. assessed by the model and, optionally, a `torch.FloatTensor` of shape `(batch_size, candidate_length,
vocabulary_size)` containing the logits associated to each candidate.
""" """
raise NotImplementedError( raise NotImplementedError(
f"{self.__class__} is an abstract class. Only classes inheriting this class can call `get_candidates`." f"{self.__class__} is an abstract class. Only classes inheriting this class can call `get_candidates`."
@@ -152,7 +153,7 @@ class AssistedCandidateGenerator(CandidateGenerator):
) )
self.logits_processor = logits_processor self.logits_processor = logits_processor
def get_candidates(self, input_ids: torch.LongTensor) -> torch.LongTensor: def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
""" """
Fetches the candidates to be tried for the current input. Fetches the candidates to be tried for the current input.
@@ -161,7 +162,9 @@ class AssistedCandidateGenerator(CandidateGenerator):
Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids) Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
Return: Return:
`torch.LongTensor` of shape `(num_candidates, candidate_length)`: The candidate sequences to be tried. `torch.LongTensor` of shape `(batch_size, candidate_length)` containing the candidate sequences to be
assessed by the model and a `torch.FloatTensor` of shape `(batch_size, candidate_length,
vocabulary_size)` containing the logits associated to each candidate.
""" """
# 1. If it is not the first round of candidate generation, prepare the inputs based on the input_ids length # 1. If it is not the first round of candidate generation, prepare the inputs based on the input_ids length
# (which implicitly contains the number of accepted candidates from the previous round) # (which implicitly contains the number of accepted candidates from the previous round)
@@ -179,51 +182,24 @@ class AssistedCandidateGenerator(CandidateGenerator):
) )
self.assistant_kwargs = _prepare_token_type_ids(self.assistant_kwargs, new_cur_len) self.assistant_kwargs = _prepare_token_type_ids(self.assistant_kwargs, new_cur_len)
# 2. Forecast next N tokens using the assistant model. This `for` block can be replaced with a `.generate()` # 2. Forecast next N tokens using the assistant model.
# call if we decide to add `past_key_values` as a possible output of generate, as we need access to the assistant_generation_kwargs = {
# assistant cache to secure strong speedups. self.input_ids_key: input_ids,
candidate_input_ids = input_ids "do_sample": False,
for _ in range(int(self.num_assistant_tokens)): "num_beams": 1,
# 2.1 prepare assistant model inputs "max_new_tokens": int(self.num_assistant_tokens),
assistant_inputs = self.assistant_model.prepare_inputs_for_generation( "return_dict_in_generate": True,
candidate_input_ids, "output_scores": True,
**self.assistant_kwargs, }
) assistant_output = self.assistant_model.generate(**assistant_generation_kwargs, **self.assistant_kwargs)
# 2.2. check if the input ids length is correct # 3. Update variables for the next round of candidate generation
has_past_key_values = assistant_inputs.get("past_key_values", None) is not None self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values
if has_past_key_values and assistant_inputs[self.input_ids_key].shape[-1] not in (1, 2):
raise ValueError("The length of the input ids in assistant inputs should be 1 or 2")
# 2.3. use the assistant model to obtain the next candidate logits # 4. Prepare variables for output
assistant_model_outputs = self.assistant_model(**assistant_inputs) candidate_logits = torch.stack(assistant_output.scores, dim=1)
candidate_ids = assistant_output.sequences
# 2.4. greedily select the next candidate token return candidate_ids, candidate_logits
if len(self.logits_processor) > 0:
assistant_model_outputs.logits[:, -1, :] = self.logits_processor(
candidate_input_ids, assistant_model_outputs.logits[:, -1, :]
)
new_token = assistant_model_outputs.logits[:, -1, :].argmax(dim=-1)
candidate_input_ids = torch.cat((candidate_input_ids, new_token[:, None]), dim=-1)
# 2.5. update assistant model inputs
if self.assistant_kwargs.get(self.attention_key, None) is not None:
mask = self.assistant_kwargs[self.attention_key]
self.assistant_kwargs[self.attention_key] = torch.cat(
[mask, mask.new_ones((mask.shape[0], 1))], dim=-1
)
self.assistant_kwargs["past_key_values"] = assistant_model_outputs.past_key_values
# 2.6. stop assistant generation on EOS
if self.eos_token_id_tensor is not None:
last_assistant_token_is_eos = new_token.tile(self.eos_token_id_tensor.shape[0], 1)
last_assistant_token_is_eos = (
~last_assistant_token_is_eos.ne(self.eos_token_id_tensor.unsqueeze(1)).prod(dim=0).bool()
)
if last_assistant_token_is_eos:
break
return candidate_input_ids
def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int): def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int):
""" """

View File

@@ -4585,7 +4585,7 @@ class GenerationMixin:
cur_len = input_ids.shape[-1] cur_len = input_ids.shape[-1]
# 1. Fetch candidate sequences from a `CandidateGenerator` # 1. Fetch candidate sequences from a `CandidateGenerator`
candidate_input_ids = candidate_generator.get_candidates(input_ids) candidate_input_ids, candidate_logits = candidate_generator.get_candidates(input_ids)
candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1] candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1]
last_assistant_token_is_eos = ( last_assistant_token_is_eos = (
~candidate_input_ids[:, -1] ~candidate_input_ids[:, -1]

View File

@@ -3128,85 +3128,26 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
self.assertListEqual(outputs_assisted.tolist(), outputs_tti.tolist()) self.assertListEqual(outputs_assisted.tolist(), outputs_tti.tolist())
def test_model_kwarg_assisted_decoding_encoder_decoder(self): def test_model_kwarg_assisted_decoding_encoder_decoder(self):
"""
Tests that the following scenario is compatible with assisted generation:
1. encoder-decoder main model
2. encoder-decoder assistant model
3. both have a custom input
(e.g. Whisper)
"""
# PT-only test: TF doesn't support assisted decoding yet. # PT-only test: TF doesn't support assisted decoding yet.
# Bart subclass with a kwarg that distorts the output # Bart subclass with a kwarg that distorts the output
class FakeBart(BartForConditionalGeneration): class FakeBart(BartForConditionalGeneration):
def forward(self, input_ids, foo=False, **kwargs): def forward(self, input_ids, past_key_values, foo=False, **kwargs):
outs = super().forward(input_ids, **kwargs) outs = super().forward(input_ids, past_key_values=past_key_values, **kwargs)
if foo: if foo:
outs["logits"][:, :, :] = 0.0 outs["logits"][:, :, :] = 0.0
return outs return outs
def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None, **kwargs): def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None, **kwargs):
kwargs["encoder_outputs"] = encoder_outputs kwargs["encoder_outputs"] = encoder_outputs
inputs = super().prepare_inputs_for_generation(*args, **kwargs) inputs = super().prepare_inputs_for_generation(*args, **kwargs)
inputs["foo"] = foo
return inputs
model = FakeBart.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration").to(
torch_device
)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration")
text = "Hello world"
tokenized_inputs = tokenizer([text], return_tensors="pt")
input_ids = tokenized_inputs.input_ids.to(torch_device)
# Traditional way of generating text
outputs_normal = model.generate(input_ids)
self.assertEqual(outputs_normal.shape, (1, 20))
# Should be different with foo
outputs_foo = model.generate(
input_ids,
foo=True,
)
with self.assertRaises(AssertionError):
self.assertListEqual(outputs_foo.tolist(), outputs_normal.tolist())
# Assistant model
assistant = AutoModelForSeq2SeqLM.from_pretrained(
"hf-internal-testing/tiny-random-BartForConditionalGeneration"
).to(torch_device)
# If assisted generation passes model_kwargs correctly, should be same as previous
outputs_assisted = model.generate(
input_ids,
foo=True,
assistant_model=assistant,
)
self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist())
# Check that passing encoder_outputs directly also works as expected
encoder_outputs = assistant.get_encoder()(input_ids)
outputs_assisted = model.generate(
foo=True,
assistant_model=assistant,
encoder_outputs=encoder_outputs,
assistant_encoder_outputs=encoder_outputs,
)
self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist())
def test_assisted_decoding_encoder_decoder_shared_encoder(self):
# PT-only test: TF doesn't support assisted decoding yet.
# Bart subclass with a kwarg called foo that distorts the output
class FakeBart(BartForConditionalGeneration):
def forward(self, input_ids, foo=False, **kwargs):
outs = super().forward(input_ids, **kwargs)
if foo:
outs["logits"][:, :, :] = 0.0
return outs
def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None, **kwargs):
kwargs["encoder_outputs"] = encoder_outputs
inputs = super().prepare_inputs_for_generation(*args, **kwargs)
inputs["foo"] = foo inputs["foo"] = foo
return inputs return inputs
@@ -3229,7 +3170,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
self.assertListEqual(outputs_foo.tolist(), outputs_normal.tolist()) self.assertListEqual(outputs_foo.tolist(), outputs_normal.tolist())
# Assistant model # Assistant model
assistant = BartForCausalLM.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration").to( assistant = FakeBart.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration").to(
torch_device torch_device
) )
@@ -3241,6 +3182,85 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
) )
self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist()) self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist())
# Check that passing encoder_outputs directly also works as expected
encoder_outputs = assistant.get_encoder()(input_ids)
outputs_assisted = model.generate(
foo=True,
assistant_model=assistant,
encoder_outputs=encoder_outputs,
assistant_encoder_outputs=encoder_outputs,
)
self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist())
def test_assisted_decoding_encoder_decoder_shared_encoder(self):
"""
Tests that the following scenario is compatible with assisted generation:
1. encoder-decoder main model
2. decoder-only assistant model
3. both have a custom input
(e.g. DistilWhisper)
"""
# PT-only test: TF doesn't support assisted decoding yet.
# Bart subclass with a kwarg called foo that distorts the output
class FakeBartSeq2Seq(BartForConditionalGeneration):
def forward(self, input_ids, foo=False, **kwargs):
outs = super().forward(input_ids, **kwargs)
if foo:
outs["logits"][:, :, :] = 0.0
return outs
def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None, **kwargs):
kwargs["encoder_outputs"] = encoder_outputs
inputs = super().prepare_inputs_for_generation(*args, **kwargs)
inputs["foo"] = foo
return inputs
class FakeBartCausalLM(BartForCausalLM):
def forward(self, input_ids, attention_mask, past_key_values, foo=False, **kwargs):
outs = super().forward(input_ids, attention_mask, past_key_values=past_key_values, **kwargs)
if foo:
outs["logits"][:, :, :] = 0.0
return outs
def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None, **kwargs):
kwargs["encoder_outputs"] = encoder_outputs
inputs = super().prepare_inputs_for_generation(*args, **kwargs)
inputs["foo"] = foo
return inputs
model = FakeBartSeq2Seq.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration").to(
torch_device
)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration")
text = "Hello world"
tokenized_inputs = tokenizer([text], return_tensors="pt")
input_ids = tokenized_inputs.input_ids.to(torch_device)
# Traditional way of generating text
outputs_normal = model.generate(input_ids)
self.assertEqual(outputs_normal.shape, (1, 20))
# Should be different with foo
outputs_foo = model.generate(input_ids, foo=True)
with self.assertRaises(AssertionError):
self.assertListEqual(outputs_foo.tolist(), outputs_normal.tolist())
# Assistant model
assistant = FakeBartCausalLM.from_pretrained(
"hf-internal-testing/tiny-random-BartForConditionalGeneration"
).to(torch_device)
# If assisted generation passes model_kwargs correctly, should be same as previous
outputs_assisted = model.generate(
input_ids,
foo=True,
assistant_model=assistant,
)
self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist())
# Check that passing encoder_outputs directly also works as expected # Check that passing encoder_outputs directly also works as expected
encoder_outputs = model.get_encoder()(input_ids) encoder_outputs = model.get_encoder()(input_ids)