From 5603fad2479ad22ca4689f6a4dbf56ef2f1f0973 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 16 Nov 2023 14:50:39 +0100 Subject: [PATCH] Revert "add attention_mask and position_ids in assisted model" (#27523) * Revert "add attention_mask and position_ids in assisted model (#26892)" This reverts commit 184f60dcec6f7f664687a9e211e8d2216052b05d. * more debug --- src/transformers/generation/utils.py | 90 +++++++++-------- tests/models/whisper/test_modeling_whisper.py | 97 +++++++++++++++++++ 2 files changed, 141 insertions(+), 46 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 10ffffc37c..14e4b10129 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -4504,6 +4504,11 @@ class GenerationMixin: else: num_assistant_tokens = assistant_model.generation_config.num_assistant_tokens + # check if assistant model accepts encoder_outputs + assistant_accepts_encoder_outputs = "encoder_outputs" in set( + inspect.signature(assistant_model.forward).parameters.keys() + ) + # init values logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList() @@ -4546,6 +4551,15 @@ class GenerationMixin: # other auxiliary variables max_len = stopping_criteria[0].max_length + assistant_kv_indexing = ( + 1 + if "bloom" in assistant_model.__class__.__name__.lower() + or ( + assistant_model.config.architectures is not None + and "bloom" in assistant_model.config.architectures[0].lower() + ) + else 0 + ) this_peer_finished = False # used by synced_gpus only while True: @@ -4566,28 +4580,42 @@ class GenerationMixin: # `.generate()` call if we decide to add `past_key_values` as a possible output of generate, as we # need access to the assistant cache to secure strong speedups. candidate_input_ids = input_ids - assistant_attention_mask = model_kwargs.get("attention_mask", None) - assistant_decoder_attention_mask = model_kwargs.get("decoder_attention_mask", None) - assistant_encoder_outputs = (model_kwargs.get("assistant_encoder_outputs", None),) for _ in range(int(num_assistant_tokens)): # 1.1. use the assistant model to obtain the next candidate logits - assistant_inputs = assistant_model.prepare_inputs_for_generation( - candidate_input_ids, - attention_mask=assistant_attention_mask, - decoder_attention_mask=assistant_decoder_attention_mask, - encoder_outputs=assistant_encoder_outputs, - past_key_values=model_kwargs.get("assistant_past_key_values", None), - ) - if assistant_inputs.get("past_key_values", None) is not None: + if "assistant_past_key_values" in model_kwargs: + prev_seq_len = model_kwargs["assistant_past_key_values"][0][assistant_kv_indexing].shape[-2] + # `new_token_len` can be 1 or 2 (next token in assistant + last token picked by the larger model) + new_token_len = candidate_input_ids.shape[1] - prev_seq_len + assist_inputs = candidate_input_ids[:, -new_token_len:] + # TODO (joao): make it compatible with models that use unconventional fwd pass logic, like blip2 if assistant_model.config.is_encoder_decoder: - input_ids_len = assistant_inputs["decoder_input_ids"].shape[-1] + assistant_model_outputs = assistant_model( + decoder_input_ids=assist_inputs, + past_key_values=model_kwargs["assistant_past_key_values"], + encoder_outputs=model_kwargs["assistant_encoder_outputs"], + ) else: - input_ids_len = assistant_inputs["input_ids"].shape[-1] + encoder_kwargs = {} - if input_ids_len not in (1, 2): - raise ValueError("The length of the input ids in assistant inputs should be 1 or 2") + if assistant_accepts_encoder_outputs and "assistant_encoder_outputs" in model_kwargs: + encoder_kwargs["encoder_outputs"] = model_kwargs["assistant_encoder_outputs"] - assistant_model_outputs = assistant_model(**assistant_inputs) + assistant_model_outputs = assistant_model( + assist_inputs, past_key_values=model_kwargs["assistant_past_key_values"], **encoder_kwargs + ) + else: + if assistant_model.config.is_encoder_decoder: + assistant_model_outputs = assistant_model( + decoder_input_ids=candidate_input_ids, + encoder_outputs=model_kwargs["assistant_encoder_outputs"], + ) + else: + encoder_kwargs = {} + + if assistant_accepts_encoder_outputs and "assistant_encoder_outputs" in model_kwargs: + encoder_kwargs["encoder_outputs"] = model_kwargs["assistant_encoder_outputs"] + + assistant_model_outputs = assistant_model(candidate_input_ids, **encoder_kwargs) # 1.2. greedily select the next candidate token model_kwargs["assistant_past_key_values"] = assistant_model_outputs.past_key_values @@ -4595,31 +4623,8 @@ class GenerationMixin: assistant_model_outputs.logits[:, -1, :] = logits_processor( candidate_input_ids, assistant_model_outputs.logits[:, -1, :] ) - new_token = assistant_model_outputs.logits[:, -1, :].argmax(dim=-1) candidate_input_ids = torch.cat((candidate_input_ids, new_token[:, None]), dim=-1) - if assistant_model.config.is_encoder_decoder and assistant_decoder_attention_mask is not None: - assistant_decoder_attention_mask = torch.cat( - ( - assistant_decoder_attention_mask, - torch.ones( - [1, 1], - dtype=assistant_decoder_attention_mask.dtype, - device=assistant_decoder_attention_mask.device, - ), - ), - dim=-1, - ) - elif not assistant_model.config.is_encoder_decoder and assistant_attention_mask is not None: - assistant_attention_mask = torch.cat( - ( - assistant_attention_mask, - torch.ones( - [1, 1], dtype=assistant_attention_mask.dtype, device=assistant_attention_mask.device - ), - ), - dim=-1, - ) # 1.3. stop assistant generation on EOS if eos_token_id_tensor is not None: @@ -4755,13 +4760,6 @@ class GenerationMixin: outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder ) - # Update attention_mask for the assistant's next round of generations - if n_matches > 0 and model_kwargs.get("attention_mask", None) is not None: - attention_mask = model_kwargs["attention_mask"] - model_kwargs["attention_mask"] = torch.cat( - [attention_mask, attention_mask.new_ones((attention_mask.shape[0], n_matches))], dim=-1 - ) - # if eos_token was found in one sentence, set sentence to finished if eos_token_id_tensor is not None: unfinished_sequences = unfinished_sequences.mul( diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index 22290bab66..c7d6fb6926 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -18,6 +18,7 @@ import copy import inspect import os import tempfile +import time import unittest import numpy as np @@ -1736,6 +1737,102 @@ class WhisperModelIntegrationTests(unittest.TestCase): self.assertTrue(prompt in text) + @slow + @require_torch_gpu + def test_speculative_decoding_distil(self): + torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32 + model_id = "openai/whisper-large-v2" + model = WhisperForConditionalGeneration.from_pretrained( + model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True + ) + model.to(torch_device) + + processor = WhisperProcessor.from_pretrained(model_id) + + assistant_model_id = "distil-whisper/distil-large-v2" + assistant_model = WhisperForCausalLM.from_pretrained( + assistant_model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True + ) + assistant_model.to(torch_device) + + dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + sample = dataset[0]["audio"] + + input_features = processor(sample["array"], return_tensors="pt").input_features.to("cuda").to(torch.float16) + + # warm up assisted decoding + _ = model.generate(input_features, assistant_model=assistant_model) + # warm up non-assisted decoding + _ = model.generate(input_features) + + # assisted decoding + start_time = time.time() + tokens = model.generate(input_features, assistant_model=assistant_model) + total_time_assist = time.time() - start_time + + transcription_ass = processor.batch_decode(tokens, skip_special_tokens=True) + + # non-assisted decoding + start_time = time.time() + tokens = model.generate(input_features) + total_time_non_assist = time.time() - start_time + + transcription_non_ass = processor.batch_decode(tokens, skip_special_tokens=True) + + assert transcription_ass == transcription_non_ass + assert transcription_ass == [ + " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel." + ] + assert total_time_non_assist > total_time_assist, "Make sure that assistant decoding is faster" + + @slow + @require_torch_gpu + def test_speculative_decoding_non_distil(self): + torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32 + model_id = "openai/whisper-large-v2" + model = WhisperForConditionalGeneration.from_pretrained( + model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True + ) + model.to(torch_device) + + processor = WhisperProcessor.from_pretrained(model_id) + + assistant_model_id = "openai/whisper-tiny" + assistant_model = WhisperForConditionalGeneration.from_pretrained( + assistant_model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True + ) + assistant_model.to(torch_device) + + dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + sample = dataset[0]["audio"] + + input_features = processor(sample["array"], return_tensors="pt").input_features.to("cuda").to(torch.float16) + + # warm up assisted decoding + _ = model.generate(input_features, assistant_model=assistant_model) + # warm up non-assisted decoding + _ = model.generate(input_features) + + # assisted decoding + start_time = time.time() + tokens = model.generate(input_features, assistant_model=assistant_model) + total_time_assist = time.time() - start_time + + transcription_ass = processor.batch_decode(tokens, skip_special_tokens=True) + + # non-assisted decoding + start_time = time.time() + tokens = model.generate(input_features) + total_time_non_assist = time.time() - start_time + + transcription_non_ass = processor.batch_decode(tokens, skip_special_tokens=True) + + assert transcription_ass == transcription_non_ass + assert transcription_ass == [ + " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel." + ] + assert total_time_non_assist > total_time_assist, "Make sure that assistant decoding is faster" + def prepare_whisper_encoder_inputs_dict(config, input_features, head_mask=None): if head_mask is None: