Revert "add attention_mask and position_ids in assisted model" (#27523)
* Revert "add attention_mask and position_ids in assisted model (#26892)"
This reverts commit 184f60dcec.
* more debug
This commit is contained in:
committed by
GitHub
parent
4989e73e2f
commit
5603fad247
@@ -4504,6 +4504,11 @@ class GenerationMixin:
|
|||||||
else:
|
else:
|
||||||
num_assistant_tokens = assistant_model.generation_config.num_assistant_tokens
|
num_assistant_tokens = assistant_model.generation_config.num_assistant_tokens
|
||||||
|
|
||||||
|
# check if assistant model accepts encoder_outputs
|
||||||
|
assistant_accepts_encoder_outputs = "encoder_outputs" in set(
|
||||||
|
inspect.signature(assistant_model.forward).parameters.keys()
|
||||||
|
)
|
||||||
|
|
||||||
# init values
|
# init values
|
||||||
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
|
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
|
||||||
logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
|
logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
|
||||||
@@ -4546,6 +4551,15 @@ class GenerationMixin:
|
|||||||
|
|
||||||
# other auxiliary variables
|
# other auxiliary variables
|
||||||
max_len = stopping_criteria[0].max_length
|
max_len = stopping_criteria[0].max_length
|
||||||
|
assistant_kv_indexing = (
|
||||||
|
1
|
||||||
|
if "bloom" in assistant_model.__class__.__name__.lower()
|
||||||
|
or (
|
||||||
|
assistant_model.config.architectures is not None
|
||||||
|
and "bloom" in assistant_model.config.architectures[0].lower()
|
||||||
|
)
|
||||||
|
else 0
|
||||||
|
)
|
||||||
|
|
||||||
this_peer_finished = False # used by synced_gpus only
|
this_peer_finished = False # used by synced_gpus only
|
||||||
while True:
|
while True:
|
||||||
@@ -4566,28 +4580,42 @@ class GenerationMixin:
|
|||||||
# `.generate()` call if we decide to add `past_key_values` as a possible output of generate, as we
|
# `.generate()` call if we decide to add `past_key_values` as a possible output of generate, as we
|
||||||
# need access to the assistant cache to secure strong speedups.
|
# need access to the assistant cache to secure strong speedups.
|
||||||
candidate_input_ids = input_ids
|
candidate_input_ids = input_ids
|
||||||
assistant_attention_mask = model_kwargs.get("attention_mask", None)
|
|
||||||
assistant_decoder_attention_mask = model_kwargs.get("decoder_attention_mask", None)
|
|
||||||
assistant_encoder_outputs = (model_kwargs.get("assistant_encoder_outputs", None),)
|
|
||||||
for _ in range(int(num_assistant_tokens)):
|
for _ in range(int(num_assistant_tokens)):
|
||||||
# 1.1. use the assistant model to obtain the next candidate logits
|
# 1.1. use the assistant model to obtain the next candidate logits
|
||||||
assistant_inputs = assistant_model.prepare_inputs_for_generation(
|
if "assistant_past_key_values" in model_kwargs:
|
||||||
candidate_input_ids,
|
prev_seq_len = model_kwargs["assistant_past_key_values"][0][assistant_kv_indexing].shape[-2]
|
||||||
attention_mask=assistant_attention_mask,
|
# `new_token_len` can be 1 or 2 (next token in assistant + last token picked by the larger model)
|
||||||
decoder_attention_mask=assistant_decoder_attention_mask,
|
new_token_len = candidate_input_ids.shape[1] - prev_seq_len
|
||||||
encoder_outputs=assistant_encoder_outputs,
|
assist_inputs = candidate_input_ids[:, -new_token_len:]
|
||||||
past_key_values=model_kwargs.get("assistant_past_key_values", None),
|
# TODO (joao): make it compatible with models that use unconventional fwd pass logic, like blip2
|
||||||
)
|
|
||||||
if assistant_inputs.get("past_key_values", None) is not None:
|
|
||||||
if assistant_model.config.is_encoder_decoder:
|
if assistant_model.config.is_encoder_decoder:
|
||||||
input_ids_len = assistant_inputs["decoder_input_ids"].shape[-1]
|
assistant_model_outputs = assistant_model(
|
||||||
|
decoder_input_ids=assist_inputs,
|
||||||
|
past_key_values=model_kwargs["assistant_past_key_values"],
|
||||||
|
encoder_outputs=model_kwargs["assistant_encoder_outputs"],
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
input_ids_len = assistant_inputs["input_ids"].shape[-1]
|
encoder_kwargs = {}
|
||||||
|
|
||||||
if input_ids_len not in (1, 2):
|
if assistant_accepts_encoder_outputs and "assistant_encoder_outputs" in model_kwargs:
|
||||||
raise ValueError("The length of the input ids in assistant inputs should be 1 or 2")
|
encoder_kwargs["encoder_outputs"] = model_kwargs["assistant_encoder_outputs"]
|
||||||
|
|
||||||
assistant_model_outputs = assistant_model(**assistant_inputs)
|
assistant_model_outputs = assistant_model(
|
||||||
|
assist_inputs, past_key_values=model_kwargs["assistant_past_key_values"], **encoder_kwargs
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if assistant_model.config.is_encoder_decoder:
|
||||||
|
assistant_model_outputs = assistant_model(
|
||||||
|
decoder_input_ids=candidate_input_ids,
|
||||||
|
encoder_outputs=model_kwargs["assistant_encoder_outputs"],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
encoder_kwargs = {}
|
||||||
|
|
||||||
|
if assistant_accepts_encoder_outputs and "assistant_encoder_outputs" in model_kwargs:
|
||||||
|
encoder_kwargs["encoder_outputs"] = model_kwargs["assistant_encoder_outputs"]
|
||||||
|
|
||||||
|
assistant_model_outputs = assistant_model(candidate_input_ids, **encoder_kwargs)
|
||||||
|
|
||||||
# 1.2. greedily select the next candidate token
|
# 1.2. greedily select the next candidate token
|
||||||
model_kwargs["assistant_past_key_values"] = assistant_model_outputs.past_key_values
|
model_kwargs["assistant_past_key_values"] = assistant_model_outputs.past_key_values
|
||||||
@@ -4595,31 +4623,8 @@ class GenerationMixin:
|
|||||||
assistant_model_outputs.logits[:, -1, :] = logits_processor(
|
assistant_model_outputs.logits[:, -1, :] = logits_processor(
|
||||||
candidate_input_ids, assistant_model_outputs.logits[:, -1, :]
|
candidate_input_ids, assistant_model_outputs.logits[:, -1, :]
|
||||||
)
|
)
|
||||||
|
|
||||||
new_token = assistant_model_outputs.logits[:, -1, :].argmax(dim=-1)
|
new_token = assistant_model_outputs.logits[:, -1, :].argmax(dim=-1)
|
||||||
candidate_input_ids = torch.cat((candidate_input_ids, new_token[:, None]), dim=-1)
|
candidate_input_ids = torch.cat((candidate_input_ids, new_token[:, None]), dim=-1)
|
||||||
if assistant_model.config.is_encoder_decoder and assistant_decoder_attention_mask is not None:
|
|
||||||
assistant_decoder_attention_mask = torch.cat(
|
|
||||||
(
|
|
||||||
assistant_decoder_attention_mask,
|
|
||||||
torch.ones(
|
|
||||||
[1, 1],
|
|
||||||
dtype=assistant_decoder_attention_mask.dtype,
|
|
||||||
device=assistant_decoder_attention_mask.device,
|
|
||||||
),
|
|
||||||
),
|
|
||||||
dim=-1,
|
|
||||||
)
|
|
||||||
elif not assistant_model.config.is_encoder_decoder and assistant_attention_mask is not None:
|
|
||||||
assistant_attention_mask = torch.cat(
|
|
||||||
(
|
|
||||||
assistant_attention_mask,
|
|
||||||
torch.ones(
|
|
||||||
[1, 1], dtype=assistant_attention_mask.dtype, device=assistant_attention_mask.device
|
|
||||||
),
|
|
||||||
),
|
|
||||||
dim=-1,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 1.3. stop assistant generation on EOS
|
# 1.3. stop assistant generation on EOS
|
||||||
if eos_token_id_tensor is not None:
|
if eos_token_id_tensor is not None:
|
||||||
@@ -4755,13 +4760,6 @@ class GenerationMixin:
|
|||||||
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
|
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update attention_mask for the assistant's next round of generations
|
|
||||||
if n_matches > 0 and model_kwargs.get("attention_mask", None) is not None:
|
|
||||||
attention_mask = model_kwargs["attention_mask"]
|
|
||||||
model_kwargs["attention_mask"] = torch.cat(
|
|
||||||
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], n_matches))], dim=-1
|
|
||||||
)
|
|
||||||
|
|
||||||
# if eos_token was found in one sentence, set sentence to finished
|
# if eos_token was found in one sentence, set sentence to finished
|
||||||
if eos_token_id_tensor is not None:
|
if eos_token_id_tensor is not None:
|
||||||
unfinished_sequences = unfinished_sequences.mul(
|
unfinished_sequences = unfinished_sequences.mul(
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ import copy
|
|||||||
import inspect
|
import inspect
|
||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
|
import time
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -1736,6 +1737,102 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
|||||||
|
|
||||||
self.assertTrue(prompt in text)
|
self.assertTrue(prompt in text)
|
||||||
|
|
||||||
|
@slow
@require_torch_gpu
def test_speculative_decoding_distil(self):
    """Assisted generation with `distil-whisper/distil-large-v2` must match plain generation and be faster.

    Loads `openai/whisper-large-v2` as the main model and a `WhisperForCausalLM` assistant, transcribes the
    first sample of the dummy LibriSpeech validation split with and without the assistant, and checks that:

    * both transcriptions are identical,
    * the transcription equals the expected reference sentence,
    * the assisted run takes less time than the non-assisted run.
    """
    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    model_id = "openai/whisper-large-v2"
    model = WhisperForConditionalGeneration.from_pretrained(
        model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
    )
    model.to(torch_device)

    processor = WhisperProcessor.from_pretrained(model_id)

    assistant_model_id = "distil-whisper/distil-large-v2"
    assistant_model = WhisperForCausalLM.from_pretrained(
        assistant_model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
    )
    assistant_model.to(torch_device)

    dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
    sample = dataset[0]["audio"]

    # Put the features on the same device and dtype as the models instead of hard-coding "cuda"/float16,
    # so the inputs can never disagree with where/how the models were loaded above.
    input_features = processor(sample["array"], return_tensors="pt").input_features
    input_features = input_features.to(torch_device).to(torch_dtype)

    # warm up assisted decoding
    _ = model.generate(input_features, assistant_model=assistant_model)
    # warm up non-assisted decoding
    _ = model.generate(input_features)

    # assisted decoding (perf_counter is monotonic, so the measured interval cannot be skewed by clock changes)
    start_time = time.perf_counter()
    tokens = model.generate(input_features, assistant_model=assistant_model)
    total_time_assist = time.perf_counter() - start_time

    transcription_ass = processor.batch_decode(tokens, skip_special_tokens=True)

    # non-assisted decoding
    start_time = time.perf_counter()
    tokens = model.generate(input_features)
    total_time_non_assist = time.perf_counter() - start_time

    transcription_non_ass = processor.batch_decode(tokens, skip_special_tokens=True)

    # unittest assertions give a readable diff on failure and are not stripped under `python -O`
    self.assertEqual(transcription_ass, transcription_non_ass)
    self.assertEqual(
        transcription_ass,
        [" Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel."],
    )
    self.assertGreater(total_time_non_assist, total_time_assist, "Make sure that assistant decoding is faster")
|
||||||
|
|
||||||
|
@slow
@require_torch_gpu
def test_speculative_decoding_non_distil(self):
    """Assisted generation with `openai/whisper-tiny` must match plain generation and be faster.

    Loads `openai/whisper-large-v2` as the main model and a `WhisperForConditionalGeneration` assistant,
    transcribes the first sample of the dummy LibriSpeech validation split with and without the assistant,
    and checks that:

    * both transcriptions are identical,
    * the transcription equals the expected reference sentence,
    * the assisted run takes less time than the non-assisted run.
    """
    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    model_id = "openai/whisper-large-v2"
    model = WhisperForConditionalGeneration.from_pretrained(
        model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
    )
    model.to(torch_device)

    processor = WhisperProcessor.from_pretrained(model_id)

    assistant_model_id = "openai/whisper-tiny"
    assistant_model = WhisperForConditionalGeneration.from_pretrained(
        assistant_model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
    )
    assistant_model.to(torch_device)

    dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
    sample = dataset[0]["audio"]

    # Put the features on the same device and dtype as the models instead of hard-coding "cuda"/float16,
    # so the inputs can never disagree with where/how the models were loaded above.
    input_features = processor(sample["array"], return_tensors="pt").input_features
    input_features = input_features.to(torch_device).to(torch_dtype)

    # warm up assisted decoding
    _ = model.generate(input_features, assistant_model=assistant_model)
    # warm up non-assisted decoding
    _ = model.generate(input_features)

    # assisted decoding (perf_counter is monotonic, so the measured interval cannot be skewed by clock changes)
    start_time = time.perf_counter()
    tokens = model.generate(input_features, assistant_model=assistant_model)
    total_time_assist = time.perf_counter() - start_time

    transcription_ass = processor.batch_decode(tokens, skip_special_tokens=True)

    # non-assisted decoding
    start_time = time.perf_counter()
    tokens = model.generate(input_features)
    total_time_non_assist = time.perf_counter() - start_time

    transcription_non_ass = processor.batch_decode(tokens, skip_special_tokens=True)

    # unittest assertions give a readable diff on failure and are not stripped under `python -O`
    self.assertEqual(transcription_ass, transcription_non_ass)
    self.assertEqual(
        transcription_ass,
        [" Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel."],
    )
    self.assertGreater(total_time_non_assist, total_time_assist, "Make sure that assistant decoding is faster")
|
||||||
|
|
||||||
|
|
||||||
def prepare_whisper_encoder_inputs_dict(config, input_features, head_mask=None):
|
def prepare_whisper_encoder_inputs_dict(config, input_features, head_mask=None):
|
||||||
if head_mask is None:
|
if head_mask is None:
|
||||||
|
|||||||
Reference in New Issue
Block a user