From 391d14e8105cf877797e95a92534db89077ccb7e Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 1 Nov 2023 16:01:53 +0100 Subject: [PATCH] [WhisperForCausalLM] Add WhisperForCausalLM for speculative decoding (#27195) * finish * add tests * fix all tests * [Assistant Decoding] Add test * fix more * better * finish * Apply suggestions from code review Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * finish --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- docs/source/en/model_doc/whisper.md | 5 + docs/source/en/tasks/language_modeling.md | 2 +- src/transformers/__init__.py | 2 + src/transformers/generation/utils.py | 35 ++- src/transformers/models/auto/modeling_auto.py | 1 + src/transformers/models/whisper/__init__.py | 2 + .../models/whisper/modeling_whisper.py | 247 +++++++++++++++++- src/transformers/utils/dummy_pt_objects.py | 7 + tests/generation/test_utils.py | 61 +++++ tests/models/whisper/test_modeling_whisper.py | 244 +++++++++++++++++ 10 files changed, 601 insertions(+), 5 deletions(-) diff --git a/docs/source/en/model_doc/whisper.md b/docs/source/en/model_doc/whisper.md index 382246e43d..2f1cfc5e22 100644 --- a/docs/source/en/model_doc/whisper.md +++ b/docs/source/en/model_doc/whisper.md @@ -88,6 +88,11 @@ The original code can be found [here](https://github.com/openai/whisper). - forward - generate +## WhisperForCausalLM + +[[autodoc]] WhisperForCausalLM + - forward + ## WhisperForAudioClassification [[autodoc]] WhisperForAudioClassification diff --git a/docs/source/en/tasks/language_modeling.md b/docs/source/en/tasks/language_modeling.md index c509899882..9c35b7293d 100644 --- a/docs/source/en/tasks/language_modeling.md +++ b/docs/source/en/tasks/language_modeling.md @@ -37,7 +37,7 @@ You can finetune other architectures for causal language modeling following the Choose one of the following architectures: -[BART](../model_doc/bart), [BERT](../model_doc/bert), [Bert Generation](../model_doc/bert-generation), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BioGpt](../model_doc/biogpt), [Blenderbot](../model_doc/blenderbot), [BlenderbotSmall](../model_doc/blenderbot-small), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CodeLlama](../model_doc/code_llama), [CodeGen](../model_doc/codegen), [CPM-Ant](../model_doc/cpmant), [CTRL](../model_doc/ctrl), [Data2VecText](../model_doc/data2vec-text), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [Falcon](../model_doc/falcon), [Fuyu](../model_doc/fuyu), [GIT](../model_doc/git), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [GPT NeoX Japanese](../model_doc/gpt_neox_japanese), [GPT-J](../model_doc/gptj), [LLaMA](../model_doc/llama), [Marian](../model_doc/marian), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [Mistral](../model_doc/mistral), [MPT](../model_doc/mpt), [MusicGen](../model_doc/musicgen), [MVP](../model_doc/mvp), [OpenLlama](../model_doc/open-llama), [OpenAI GPT](../model_doc/openai-gpt), [OPT](../model_doc/opt), [Pegasus](../model_doc/pegasus), [Persimmon](../model_doc/persimmon), [PLBart](../model_doc/plbart), [ProphetNet](../model_doc/prophetnet), [QDQBert](../model_doc/qdqbert), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [RWKV](../model_doc/rwkv), [Speech2Text2](../model_doc/speech_to_text_2), [Transformer-XL](../model_doc/transfo-xl), [TrOCR](../model_doc/trocr), [XGLM](../model_doc/xglm), [XLM](../model_doc/xlm), [XLM-ProphetNet](../model_doc/xlm-prophetnet), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod) +[BART](../model_doc/bart), [BERT](../model_doc/bert), [Bert Generation](../model_doc/bert-generation), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BioGpt](../model_doc/biogpt), [Blenderbot](../model_doc/blenderbot), [BlenderbotSmall](../model_doc/blenderbot-small), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CodeLlama](../model_doc/code_llama), [CodeGen](../model_doc/codegen), [CPM-Ant](../model_doc/cpmant), [CTRL](../model_doc/ctrl), [Data2VecText](../model_doc/data2vec-text), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [Falcon](../model_doc/falcon), [Fuyu](../model_doc/fuyu), [GIT](../model_doc/git), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [GPT NeoX Japanese](../model_doc/gpt_neox_japanese), [GPT-J](../model_doc/gptj), [LLaMA](../model_doc/llama), [Marian](../model_doc/marian), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [Mistral](../model_doc/mistral), [MPT](../model_doc/mpt), [MusicGen](../model_doc/musicgen), [MVP](../model_doc/mvp), [OpenLlama](../model_doc/open-llama), [OpenAI GPT](../model_doc/openai-gpt), [OPT](../model_doc/opt), [Pegasus](../model_doc/pegasus), [Persimmon](../model_doc/persimmon), [PLBart](../model_doc/plbart), [ProphetNet](../model_doc/prophetnet), [QDQBert](../model_doc/qdqbert), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [RWKV](../model_doc/rwkv), [Speech2Text2](../model_doc/speech_to_text_2), [Transformer-XL](../model_doc/transfo-xl), [TrOCR](../model_doc/trocr), [Whisper](../model_doc/whisper), [XGLM](../model_doc/xglm), [XLM](../model_doc/xlm), [XLM-ProphetNet](../model_doc/xlm-prophetnet), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index a784cff7c8..01127a5651 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -3090,6 +3090,7 @@ else: [ "WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST", "WhisperForAudioClassification", + "WhisperForCausalLM", "WhisperForConditionalGeneration", "WhisperModel", "WhisperPreTrainedModel", @@ -6845,6 +6846,7 @@ if TYPE_CHECKING: from .models.whisper import ( WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST, WhisperForAudioClassification, + WhisperForCausalLM, WhisperForConditionalGeneration, WhisperModel, WhisperPreTrainedModel, diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 1c412f8185..df4239d054 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1626,6 +1626,10 @@ class GenerationMixin: if not model_kwargs["use_cache"]: raise ValueError("assisted generate requires `use_cache=True`") + assistant_accepts_encoder_outputs = "encoder_outputs" in set( + inspect.signature(assistant_model.forward).parameters.keys() + ) + # 11. If the assistant model is an encoder-decoder, prepare its encoder outputs if assistant_model.config.is_encoder_decoder and "assistant_encoder_outputs" not in model_kwargs: assistant_model_kwargs = copy.deepcopy(model_kwargs) @@ -1637,6 +1641,17 @@ class GenerationMixin: ) model_kwargs["assistant_encoder_outputs"] = assistant_model_kwargs["encoder_outputs"] + if ( + not assistant_model.config.is_encoder_decoder + and assistant_accepts_encoder_outputs + and "encoder_outputs" in model_kwargs + ): + # some assistants might be assymetric (many more enc layers than dec layers) + # encoder-decoder models that share the exact same encoder as the teacher + # in this case the assistant only needs to load the light-weight decoder, + # but still requires `encoder_outputs` to be passed + model_kwargs["assistant_encoder_outputs"] = model_kwargs["encoder_outputs"] + # 12. run assisted generate return self.assisted_decoding( input_ids, @@ -4368,6 +4383,11 @@ class GenerationMixin: else: num_assistant_tokens = assistant_model.generation_config.num_assistant_tokens + # check if assistant model accepts encoder_outputs + assistant_accepts_encoder_outputs = "encoder_outputs" in set( + inspect.signature(assistant_model.forward).parameters.keys() + ) + # init values logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList() @@ -4454,9 +4474,13 @@ class GenerationMixin: encoder_outputs=model_kwargs["assistant_encoder_outputs"], ) else: + encoder_kwargs = {} + + if assistant_accepts_encoder_outputs and "assistant_encoder_outputs" in model_kwargs: + encoder_kwargs["encoder_outputs"] = model_kwargs["assistant_encoder_outputs"] + assistant_model_outputs = assistant_model( - assist_inputs, - past_key_values=model_kwargs["assistant_past_key_values"], + assist_inputs, past_key_values=model_kwargs["assistant_past_key_values"], **encoder_kwargs ) else: if assistant_model.config.is_encoder_decoder: @@ -4465,7 +4489,12 @@ class GenerationMixin: encoder_outputs=model_kwargs["assistant_encoder_outputs"], ) else: - assistant_model_outputs = assistant_model(candidate_input_ids) + encoder_kwargs = {} + + if assistant_accepts_encoder_outputs and "assistant_encoder_outputs" in model_kwargs: + encoder_kwargs["encoder_outputs"] = model_kwargs["assistant_encoder_outputs"] + + assistant_model_outputs = assistant_model(candidate_input_ids, **encoder_kwargs) # 1.2. greedily select the next candidate token model_kwargs["assistant_past_key_values"] = assistant_model_outputs.past_key_values diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 3c622c8158..5387809ca4 100755 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -438,6 +438,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict( ("speech_to_text_2", "Speech2Text2ForCausalLM"), ("transfo-xl", "TransfoXLLMHeadModel"), ("trocr", "TrOCRForCausalLM"), + ("whisper", "WhisperForCausalLM"), ("xglm", "XGLMForCausalLM"), ("xlm", "XLMWithLMHeadModel"), ("xlm-prophetnet", "XLMProphetNetForCausalLM"), diff --git a/src/transformers/models/whisper/__init__.py b/src/transformers/models/whisper/__init__.py index cd962478e3..d87828da69 100644 --- a/src/transformers/models/whisper/__init__.py +++ b/src/transformers/models/whisper/__init__.py @@ -46,6 +46,7 @@ except OptionalDependencyNotAvailable: else: _import_structure["modeling_whisper"] = [ "WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST", + "WhisperForCausalLM", "WhisperForConditionalGeneration", "WhisperModel", "WhisperPreTrainedModel", @@ -102,6 +103,7 @@ if TYPE_CHECKING: from .modeling_whisper import ( WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST, WhisperForAudioClassification, + WhisperForCausalLM, WhisperForConditionalGeneration, WhisperModel, WhisperPreTrainedModel, diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 8df937e3e6..48f47fe12d 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -29,6 +29,7 @@ from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, Seq2SeqLMOutput, Seq2SeqModelOutput, SequenceClassifierOutput, @@ -945,6 +946,8 @@ class WhisperDecoder(WhisperPreTrainedModel): config: WhisperConfig """ + main_input_name = "input_ids" + def __init__(self, config: WhisperConfig): super().__init__(config) self.dropout = config.dropout @@ -1028,7 +1031,8 @@ class WhisperDecoder(WhisperPreTrainedModel): If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of - all `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the model's internal @@ -1811,6 +1815,247 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel): return timestamps +class WhisperDecoderWrapper(WhisperPreTrainedModel): + """ + This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is + used in combination with the [`EncoderDecoderModel`] framework. + """ + + def __init__(self, config): + super().__init__(config) + config.is_encoder_decoder = False + self.decoder = WhisperDecoder(config) + + def get_input_embeddings(self): + return self.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.decoder.embed_tokens = value + + def forward(self, *args, **kwargs): + return self.decoder(*args, **kwargs) + + +@add_start_docstrings( + """ + Whisper decoder with with a language modeling head on top (linear layer with weights tied to the input embeddings). + """, + WHISPER_START_DOCSTRING, +) +class WhisperForCausalLM(WhisperPreTrainedModel): + _tied_weights_keys = ["proj_out.weight"] + main_input_name = "input_ids" + + def __init__(self, config): + super().__init__(config) + config.is_encoder_decoder = False + self.model = WhisperDecoderWrapper(config) + + self.proj_out = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.proj_out + + def set_output_embeddings(self, new_embeddings): + self.proj_out = new_embeddings + + def get_input_embeddings(self) -> nn.Module: + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def set_decoder(self, decoder): + self.model.decoder = decoder + + def get_decoder(self): + return self.model.decoder + + @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None, + head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + [What are attention masks?](../glossary#attention-mask) + encoder_outputs (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + if the model is configured as a decoder. + head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of + shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional + tensors are only required when the model is used as a decoder in a Sequence to Sequence model. Contains + pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. If + `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + + Returns: + + Example: + + ```python + >>> from transformers import WhisperForCausalLM, WhisperForConditionalGeneration, WhisperProcessor + >>> import torch + >>> from datasets import load_dataset + + >>> processor = WhisperProcessor.from_pretrained("openai/whisper-large-v2") + >>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v2") + + >>> assistant_model = WhisperForCausalLM.from_pretrained("distil-whisper/distil-large-v2") + + >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + >>> sample = ds[0]["audio"] + >>> input_features = processor( + ... sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt" + ... ).input_features + + >>> predicted_ids = model.generate(input_features, assistant_model=assistant_model) + + >>> # decode token ids to text + >>> transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True) + >>> transcription + ' Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.' + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # If the user passed a tuple or `BaseModelOutput` for encoder_outputs, we extract only the hidden states + if isinstance(encoder_outputs, (BaseModelOutput, tuple, list)): + encoder_outputs = encoder_outputs[0] + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model.decoder( + input_ids=input_ids, + attention_mask=attention_mask, + encoder_hidden_states=encoder_outputs, + head_mask=head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + logits = self.proj_out(outputs[0]) + + loss = None + if labels is not None: + labels = labels.to(logits.device) + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + use_cache=None, + encoder_outputs=None, + attention_mask=None, + **kwargs, + ): + if past_key_values is not None: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] + + return { + "encoder_outputs": encoder_outputs, + "past_key_values": past_key_values, + "input_ids": input_ids, + "use_cache": use_cache, + "attention_mask": attention_mask, + } + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + ) + return reordered_past + + @add_start_docstrings( """ Whisper Encoder Model with a sequence classification head on top (a linear layer over the pooled output) for tasks diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 1310312519..8ce211b21f 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -8413,6 +8413,13 @@ class WhisperForAudioClassification(metaclass=DummyObject): requires_backends(self, ["torch"]) +class WhisperForCausalLM(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class WhisperForConditionalGeneration(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 42b67be91a..6468973d67 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -43,6 +43,7 @@ if is_torch_available(): AutoModelForSpeechSeq2Seq, AutoModelForVision2Seq, AutoTokenizer, + BartForCausalLM, BartForConditionalGeneration, BartTokenizer, GPT2LMHeadModel, @@ -3010,3 +3011,63 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi assistant_encoder_outputs=encoder_outputs, ) self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist()) + + def test_assisted_decoding_encoder_decoder_shared_encoder(self): + # PT-only test: TF doesn't support assisted decoding yet. + # Bart subclass with a kwarg called foo that distorts the output + class FakeBart(BartForConditionalGeneration): + def forward(self, input_ids, foo=False, **kwargs): + outs = super().forward(input_ids, **kwargs) + + if foo: + outs["logits"][:, :, :] = 0.0 + + return outs + + def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None, **kwargs): + kwargs["encoder_outputs"] = encoder_outputs + inputs = super().prepare_inputs_for_generation(*args, **kwargs) + + inputs["foo"] = foo + return inputs + + model = FakeBart.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration").to( + torch_device + ) + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration") + + text = "Hello world" + tokenized_inputs = tokenizer([text], return_tensors="pt") + input_ids = tokenized_inputs.input_ids.to(torch_device) + + # Traditional way of generating text + outputs_normal = model.generate(input_ids) + self.assertEqual(outputs_normal.shape, (1, 20)) + + # Should be different with foo + outputs_foo = model.generate(input_ids, foo=True) + with self.assertRaises(AssertionError): + self.assertListEqual(outputs_foo.tolist(), outputs_normal.tolist()) + + # Assistant model + assistant = BartForCausalLM.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration").to( + torch_device + ) + + # If assisted generation passes model_kwargs correctly, should be same as previous + outputs_assisted = model.generate( + input_ids, + foo=True, + assistant_model=assistant, + ) + self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist()) + + # Check that passing encoder_outputs directly also works as expected + encoder_outputs = model.get_encoder()(input_ids) + + outputs_assisted = model.generate( + foo=True, + assistant_model=assistant, + encoder_outputs=encoder_outputs, + ) + self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist()) diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index 9bb8353608..bc1a7bd218 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -51,6 +51,7 @@ if is_torch_available(): from transformers import ( WhisperFeatureExtractor, WhisperForAudioClassification, + WhisperForCausalLM, WhisperForConditionalGeneration, WhisperModel, WhisperProcessor, @@ -1990,3 +1991,246 @@ class WhisperEncoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest. self.assertEqual(fx_keys, pt_keys) self.check_pt_flax_outputs(fx_outputs, pt_outputs_loaded, model_class) + + +class WhisperStandaloneDecoderModelTester: + def __init__( + self, + parent, + batch_size=2, + is_training=True, + use_labels=False, + vocab_size=200, + hidden_size=16, + num_hidden_layers=2, + num_attention_heads=4, + input_channels=1, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=20, + max_source_positions=30, + max_target_positions=40, + bos_token_id=98, + eos_token_id=98, + pad_token_id=0, + num_mel_bins=80, + decoder_start_token_id=85, + num_conv_layers=1, + suppress_tokens=None, + begin_suppress_tokens=None, + ): + self.parent = parent + self.batch_size = batch_size + self.is_training = is_training + self.use_labels = use_labels + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.input_channels = input_channels + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.num_mel_bins = num_mel_bins + self.max_position_embeddings = max_position_embeddings + self.max_source_positions = max_source_positions + self.max_target_positions = max_target_positions + self.eos_token_id = eos_token_id + self.pad_token_id = pad_token_id + self.bos_token_id = bos_token_id + self.decoder_start_token_id = decoder_start_token_id + self.num_conv_layers = num_conv_layers + self.suppress_tokens = suppress_tokens + self.begin_suppress_tokens = begin_suppress_tokens + + def prepare_config_and_inputs(self): + input_features = floats_tensor([self.batch_size, self.num_mel_bins, self.seq_length], self.vocab_size) + + decoder_input_ids = torch.tensor( + self.batch_size * [[self.decoder_start_token_id, 3, 3, 7, 2]], device=torch_device + ) + + config = self.get_config() + config.is_encoder_decoder = False + inputs_dict = prepare_whisper_inputs_dict( + config, + attention_mask=None, + input_features=input_features, + decoder_input_ids=decoder_input_ids, + ) + + inputs_dict.pop("input_features") + inputs_dict.pop("head_mask") + inputs_dict.pop("decoder_head_mask") + inputs_dict.pop("cross_attn_head_mask") + + inputs_dict["attention_mask"] = inputs_dict.pop("decoder_attention_mask") + inputs_dict["input_ids"] = inputs_dict.pop("decoder_input_ids") + return config, inputs_dict + + @property + def encoder_seq_length(self): + return 5 + + @property + def seq_length(self): + return 5 + + def get_config(self): + return WhisperConfig( + vocab_size=self.vocab_size, + d_model=self.hidden_size, + encoder_layers=self.num_hidden_layers, + decoder_layers=self.num_hidden_layers, + encoder_attention_heads=self.num_attention_heads, + decoder_attention_heads=self.num_attention_heads, + input_channels=self.input_channels, + dropout=self.hidden_dropout_prob, + attention_dropout=self.attention_probs_dropout_prob, + max_position_embeddings=self.max_position_embeddings, + max_source_positions=self.max_source_positions, + max_target_positions=self.max_target_positions, + eos_token_id=self.eos_token_id, + bos_token_id=self.bos_token_id, + pad_token_id=self.pad_token_id, + decoder_ffn_dim=self.hidden_size, + encoder_ffn_dim=self.hidden_size, + decoder_start_token_id=self.decoder_start_token_id, + suppress_tokens=self.suppress_tokens, + begin_suppress_tokens=self.begin_suppress_tokens, + ) + + def prepare_config_and_inputs_for_common(self): + config, inputs_dict = self.prepare_config_and_inputs() + + inputs_dict["input_ids"][:, -1] = self.pad_token_id + + return config, inputs_dict + + def prepare_config_and_inputs_for_decoder(self): + config, input_features = self.prepare_config_and_inputs() + input_ids = input_features["input_ids"] + encoder_hidden_states = floats_tensor([self.batch_size, self.decoder_seq_length, self.hidden_size]) + + return (config, input_ids, encoder_hidden_states) + + def create_and_check_decoder_model_past(self, config, input_ids): + config.use_cache = True + model = WhisperDecoder(config=config).to(torch_device).eval() + # first forward pass + outputs = model(input_ids, use_cache=True) + outputs_use_cache_conf = model(input_ids) + outputs_no_past = model(input_ids, use_cache=False) + + self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf)) + self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1) + + past_key_values = outputs["past_key_values"] + + # create hypothetical next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) + + # append to next input_ids and + next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) + + output_from_no_past = model(next_input_ids)["last_hidden_state"] + output_from_past = model(next_tokens, past_key_values=past_key_values)["last_hidden_state"] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, next_input_ids.shape[-1] - 1, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach() + + # test that outputs are equal for slice + assert torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3) + + def create_and_check_decoder_model_attention_mask_past(self, config, input_ids): + model = WhisperDecoder(config=config).to(torch_device).eval() + + # create attention mask + attn_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device) + + half_seq_length = input_ids.shape[-1] // 2 + attn_mask[:, half_seq_length:] = 0 + + # first forward pass + past_key_values = model(input_ids, attention_mask=attn_mask, use_cache=True)["past_key_values"] + + # create hypothetical next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) + + # change a random masked slice from input_ids + random_seq_idx_to_change = ids_tensor((1,), half_seq_length).item() + 1 + random_other_next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size).squeeze(-1) + input_ids[:, -random_seq_idx_to_change] = random_other_next_tokens + + # append to next input_ids and attn_mask + next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) + attn_mask = torch.cat( + [attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)], + dim=1, + ) + + # get two different outputs + output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"] + output_from_past = model(next_tokens, attention_mask=attn_mask, past_key_values=past_key_values)[ + "last_hidden_state" + ] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, next_input_ids.shape[-1] - 1, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach() + + # test that outputs are equal for slice + assert torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3) + + +@require_torch +class WhisperStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): + all_model_classes = (WhisperDecoder, WhisperForCausalLM) if is_torch_available() else () + all_generative_model_classes = (WhisperForCausalLM,) if is_torch_available() else () + fx_comptatible = False + test_pruning = False + is_encoder_decoder = False + test_missing_keys = False + + def setUp(self): + self.model_tester = WhisperStandaloneDecoderModelTester(self, is_training=False) + self.config_tester = ConfigTester(self, config_class=WhisperConfig) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_decoder_model_past(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + config, inputs_dict = config_and_inputs + + self.model_tester.create_and_check_decoder_model_past(config=config, input_ids=inputs_dict["input_ids"]) + + def test_decoder_model_attn_mask_past(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + config, inputs_dict = config_and_inputs + + self.model_tester.create_and_check_decoder_model_attention_mask_past( + config=config, input_ids=inputs_dict["input_ids"] + ) + + @unittest.skip("Generate needs input ids") + def test_generate_without_input_ids(self): + # generate only works with input ids for whisper + pass + + @unittest.skip("Decoder can't keep attention grads") + def test_retain_grad_hidden_states_attentions(self): + # decoder cannot keep gradients + return + + @unittest.skip("The model doesn't support fast init from base") + def test_save_load_fast_init_from_base(self): + pass + + @unittest.skip("The model doesn't support left padding") # and it's not used enough to be worth fixing :) + def test_left_padding_compatibility(self): + pass