Generate: assisted decoding now uses generate for the assistant (#28030)
generate refactor
This commit is contained in:
@@ -15,7 +15,7 @@
|
||||
|
||||
import copy
|
||||
import warnings
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -28,7 +28,7 @@ if TYPE_CHECKING:
|
||||
class CandidateGenerator:
|
||||
"""Abstract base class for all candidate generators that can be applied during assisted generation."""
|
||||
|
||||
def get_candidates(self, input_ids: torch.LongTensor) -> torch.LongTensor:
|
||||
def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
|
||||
"""
|
||||
Fetches the candidates to be tried for the current input.
|
||||
|
||||
@@ -37,8 +37,9 @@ class CandidateGenerator:
|
||||
Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
|
||||
|
||||
Return:
|
||||
`torch.LongTensor` of shape `(num_candidates, candidate_length)`: The candidate sequences to be assessed by
|
||||
the model.
|
||||
`torch.LongTensor` of shape `(batch_size, candidate_length)` containing the candidate sequences to be
|
||||
assessed by the model and, optionally, a `torch.FloatTensor` of shape `(batch_size, candidate_length,
|
||||
vocabulary_size)` containing the logits associated to each candidate.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
f"{self.__class__} is an abstract class. Only classes inheriting this class can call `get_candidates`."
|
||||
@@ -152,7 +153,7 @@ class AssistedCandidateGenerator(CandidateGenerator):
|
||||
)
|
||||
self.logits_processor = logits_processor
|
||||
|
||||
def get_candidates(self, input_ids: torch.LongTensor) -> torch.LongTensor:
|
||||
def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
|
||||
"""
|
||||
Fetches the candidates to be tried for the current input.
|
||||
|
||||
@@ -161,7 +162,9 @@ class AssistedCandidateGenerator(CandidateGenerator):
|
||||
Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
|
||||
|
||||
Return:
|
||||
`torch.LongTensor` of shape `(num_candidates, candidate_length)`: The candidate sequences to be tried.
|
||||
`torch.LongTensor` of shape `(batch_size, candidate_length)` containing the candidate sequences to be
|
||||
assessed by the model and a `torch.FloatTensor` of shape `(batch_size, candidate_length,
|
||||
vocabulary_size)` containing the logits associated to each candidate.
|
||||
"""
|
||||
# 1. If it is not the first round of candidate generation, prepare the inputs based on the input_ids length
|
||||
# (which implicitly contains the number of accepted candidates from the previous round)
|
||||
@@ -179,51 +182,24 @@ class AssistedCandidateGenerator(CandidateGenerator):
|
||||
)
|
||||
self.assistant_kwargs = _prepare_token_type_ids(self.assistant_kwargs, new_cur_len)
|
||||
|
||||
# 2. Forecast next N tokens using the assistant model. This `for` block can be replaced with a `.generate()`
|
||||
# call if we decide to add `past_key_values` as a possible output of generate, as we need access to the
|
||||
# assistant cache to secure strong speedups.
|
||||
candidate_input_ids = input_ids
|
||||
for _ in range(int(self.num_assistant_tokens)):
|
||||
# 2.1 prepare assistant model inputs
|
||||
assistant_inputs = self.assistant_model.prepare_inputs_for_generation(
|
||||
candidate_input_ids,
|
||||
**self.assistant_kwargs,
|
||||
)
|
||||
# 2. Forecast next N tokens using the assistant model.
|
||||
assistant_generation_kwargs = {
|
||||
self.input_ids_key: input_ids,
|
||||
"do_sample": False,
|
||||
"num_beams": 1,
|
||||
"max_new_tokens": int(self.num_assistant_tokens),
|
||||
"return_dict_in_generate": True,
|
||||
"output_scores": True,
|
||||
}
|
||||
assistant_output = self.assistant_model.generate(**assistant_generation_kwargs, **self.assistant_kwargs)
|
||||
|
||||
# 2.2. check if the input ids length is correct
|
||||
has_past_key_values = assistant_inputs.get("past_key_values", None) is not None
|
||||
if has_past_key_values and assistant_inputs[self.input_ids_key].shape[-1] not in (1, 2):
|
||||
raise ValueError("The length of the input ids in assistant inputs should be 1 or 2")
|
||||
# 3. Update variables for the next round of candidate generation
|
||||
self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values
|
||||
|
||||
# 2.3. use the assistant model to obtain the next candidate logits
|
||||
assistant_model_outputs = self.assistant_model(**assistant_inputs)
|
||||
|
||||
# 2.4. greedily select the next candidate token
|
||||
if len(self.logits_processor) > 0:
|
||||
assistant_model_outputs.logits[:, -1, :] = self.logits_processor(
|
||||
candidate_input_ids, assistant_model_outputs.logits[:, -1, :]
|
||||
)
|
||||
new_token = assistant_model_outputs.logits[:, -1, :].argmax(dim=-1)
|
||||
candidate_input_ids = torch.cat((candidate_input_ids, new_token[:, None]), dim=-1)
|
||||
|
||||
# 2.5. update assistant model inputs
|
||||
if self.assistant_kwargs.get(self.attention_key, None) is not None:
|
||||
mask = self.assistant_kwargs[self.attention_key]
|
||||
self.assistant_kwargs[self.attention_key] = torch.cat(
|
||||
[mask, mask.new_ones((mask.shape[0], 1))], dim=-1
|
||||
)
|
||||
self.assistant_kwargs["past_key_values"] = assistant_model_outputs.past_key_values
|
||||
|
||||
# 2.6. stop assistant generation on EOS
|
||||
if self.eos_token_id_tensor is not None:
|
||||
last_assistant_token_is_eos = new_token.tile(self.eos_token_id_tensor.shape[0], 1)
|
||||
last_assistant_token_is_eos = (
|
||||
~last_assistant_token_is_eos.ne(self.eos_token_id_tensor.unsqueeze(1)).prod(dim=0).bool()
|
||||
)
|
||||
if last_assistant_token_is_eos:
|
||||
break
|
||||
|
||||
return candidate_input_ids
|
||||
# 4. Prepare variables for output
|
||||
candidate_logits = torch.stack(assistant_output.scores, dim=1)
|
||||
candidate_ids = assistant_output.sequences
|
||||
return candidate_ids, candidate_logits
|
||||
|
||||
def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int):
|
||||
"""
|
||||
|
||||
@@ -4585,7 +4585,7 @@ class GenerationMixin:
|
||||
cur_len = input_ids.shape[-1]
|
||||
|
||||
# 1. Fetch candidate sequences from a `CandidateGenerator`
|
||||
candidate_input_ids = candidate_generator.get_candidates(input_ids)
|
||||
candidate_input_ids, candidate_logits = candidate_generator.get_candidates(input_ids)
|
||||
candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1]
|
||||
last_assistant_token_is_eos = (
|
||||
~candidate_input_ids[:, -1]
|
||||
|
||||
@@ -3128,85 +3128,26 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
self.assertListEqual(outputs_assisted.tolist(), outputs_tti.tolist())
|
||||
|
||||
def test_model_kwarg_assisted_decoding_encoder_decoder(self):
|
||||
"""
|
||||
Tests that the following scenario is compatible with assisted generation:
|
||||
1. encoder-decoder main model
|
||||
2. encoder-decoder assistant model
|
||||
3. both have a custom input
|
||||
(e.g. Whisper)
|
||||
"""
|
||||
|
||||
# PT-only test: TF doesn't support assisted decoding yet.
|
||||
# Bart subclass with a kwarg that distorts the output
|
||||
class FakeBart(BartForConditionalGeneration):
|
||||
def forward(self, input_ids, foo=False, **kwargs):
|
||||
outs = super().forward(input_ids, **kwargs)
|
||||
|
||||
def forward(self, input_ids, past_key_values, foo=False, **kwargs):
|
||||
outs = super().forward(input_ids, past_key_values=past_key_values, **kwargs)
|
||||
if foo:
|
||||
outs["logits"][:, :, :] = 0.0
|
||||
|
||||
return outs
|
||||
|
||||
def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None, **kwargs):
|
||||
kwargs["encoder_outputs"] = encoder_outputs
|
||||
inputs = super().prepare_inputs_for_generation(*args, **kwargs)
|
||||
|
||||
inputs["foo"] = foo
|
||||
return inputs
|
||||
|
||||
model = FakeBart.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration").to(
|
||||
torch_device
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration")
|
||||
|
||||
text = "Hello world"
|
||||
tokenized_inputs = tokenizer([text], return_tensors="pt")
|
||||
input_ids = tokenized_inputs.input_ids.to(torch_device)
|
||||
|
||||
# Traditional way of generating text
|
||||
outputs_normal = model.generate(input_ids)
|
||||
self.assertEqual(outputs_normal.shape, (1, 20))
|
||||
|
||||
# Should be different with foo
|
||||
outputs_foo = model.generate(
|
||||
input_ids,
|
||||
foo=True,
|
||||
)
|
||||
with self.assertRaises(AssertionError):
|
||||
self.assertListEqual(outputs_foo.tolist(), outputs_normal.tolist())
|
||||
|
||||
# Assistant model
|
||||
assistant = AutoModelForSeq2SeqLM.from_pretrained(
|
||||
"hf-internal-testing/tiny-random-BartForConditionalGeneration"
|
||||
).to(torch_device)
|
||||
|
||||
# If assisted generation passes model_kwargs correctly, should be same as previous
|
||||
outputs_assisted = model.generate(
|
||||
input_ids,
|
||||
foo=True,
|
||||
assistant_model=assistant,
|
||||
)
|
||||
self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist())
|
||||
|
||||
# Check that passing encoder_outputs directly also works as expected
|
||||
encoder_outputs = assistant.get_encoder()(input_ids)
|
||||
|
||||
outputs_assisted = model.generate(
|
||||
foo=True,
|
||||
assistant_model=assistant,
|
||||
encoder_outputs=encoder_outputs,
|
||||
assistant_encoder_outputs=encoder_outputs,
|
||||
)
|
||||
self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist())
|
||||
|
||||
def test_assisted_decoding_encoder_decoder_shared_encoder(self):
|
||||
# PT-only test: TF doesn't support assisted decoding yet.
|
||||
# Bart subclass with a kwarg called foo that distorts the output
|
||||
class FakeBart(BartForConditionalGeneration):
|
||||
def forward(self, input_ids, foo=False, **kwargs):
|
||||
outs = super().forward(input_ids, **kwargs)
|
||||
|
||||
if foo:
|
||||
outs["logits"][:, :, :] = 0.0
|
||||
|
||||
return outs
|
||||
|
||||
def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None, **kwargs):
|
||||
kwargs["encoder_outputs"] = encoder_outputs
|
||||
inputs = super().prepare_inputs_for_generation(*args, **kwargs)
|
||||
|
||||
inputs["foo"] = foo
|
||||
return inputs
|
||||
|
||||
@@ -3229,7 +3170,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
self.assertListEqual(outputs_foo.tolist(), outputs_normal.tolist())
|
||||
|
||||
# Assistant model
|
||||
assistant = BartForCausalLM.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration").to(
|
||||
assistant = FakeBart.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration").to(
|
||||
torch_device
|
||||
)
|
||||
|
||||
@@ -3241,6 +3182,85 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
)
|
||||
self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist())
|
||||
|
||||
# Check that passing encoder_outputs directly also works as expected
|
||||
encoder_outputs = assistant.get_encoder()(input_ids)
|
||||
|
||||
outputs_assisted = model.generate(
|
||||
foo=True,
|
||||
assistant_model=assistant,
|
||||
encoder_outputs=encoder_outputs,
|
||||
assistant_encoder_outputs=encoder_outputs,
|
||||
)
|
||||
self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist())
|
||||
|
||||
def test_assisted_decoding_encoder_decoder_shared_encoder(self):
|
||||
"""
|
||||
Tests that the following scenario is compatible with assisted generation:
|
||||
1. encoder-decoder main model
|
||||
2. decoder-only assistant model
|
||||
3. both have a custom input
|
||||
(e.g. DistilWhisper)
|
||||
"""
|
||||
|
||||
# PT-only test: TF doesn't support assisted decoding yet.
|
||||
# Bart subclass with a kwarg called foo that distorts the output
|
||||
class FakeBartSeq2Seq(BartForConditionalGeneration):
|
||||
def forward(self, input_ids, foo=False, **kwargs):
|
||||
outs = super().forward(input_ids, **kwargs)
|
||||
if foo:
|
||||
outs["logits"][:, :, :] = 0.0
|
||||
return outs
|
||||
|
||||
def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None, **kwargs):
|
||||
kwargs["encoder_outputs"] = encoder_outputs
|
||||
inputs = super().prepare_inputs_for_generation(*args, **kwargs)
|
||||
inputs["foo"] = foo
|
||||
return inputs
|
||||
|
||||
class FakeBartCausalLM(BartForCausalLM):
|
||||
def forward(self, input_ids, attention_mask, past_key_values, foo=False, **kwargs):
|
||||
outs = super().forward(input_ids, attention_mask, past_key_values=past_key_values, **kwargs)
|
||||
if foo:
|
||||
outs["logits"][:, :, :] = 0.0
|
||||
return outs
|
||||
|
||||
def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None, **kwargs):
|
||||
kwargs["encoder_outputs"] = encoder_outputs
|
||||
inputs = super().prepare_inputs_for_generation(*args, **kwargs)
|
||||
inputs["foo"] = foo
|
||||
return inputs
|
||||
|
||||
model = FakeBartSeq2Seq.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration").to(
|
||||
torch_device
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration")
|
||||
|
||||
text = "Hello world"
|
||||
tokenized_inputs = tokenizer([text], return_tensors="pt")
|
||||
input_ids = tokenized_inputs.input_ids.to(torch_device)
|
||||
|
||||
# Traditional way of generating text
|
||||
outputs_normal = model.generate(input_ids)
|
||||
self.assertEqual(outputs_normal.shape, (1, 20))
|
||||
|
||||
# Should be different with foo
|
||||
outputs_foo = model.generate(input_ids, foo=True)
|
||||
with self.assertRaises(AssertionError):
|
||||
self.assertListEqual(outputs_foo.tolist(), outputs_normal.tolist())
|
||||
|
||||
# Assistant model
|
||||
assistant = FakeBartCausalLM.from_pretrained(
|
||||
"hf-internal-testing/tiny-random-BartForConditionalGeneration"
|
||||
).to(torch_device)
|
||||
|
||||
# If assisted generation passes model_kwargs correctly, should be same as previous
|
||||
outputs_assisted = model.generate(
|
||||
input_ids,
|
||||
foo=True,
|
||||
assistant_model=assistant,
|
||||
)
|
||||
self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist())
|
||||
|
||||
# Check that passing encoder_outputs directly also works as expected
|
||||
encoder_outputs = model.get_encoder()(input_ids)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user