[WhisperForCausalLM] Add WhisperForCausalLM for speculative decoding (#27195)
* finish * add tests * fix all tests * [Assistant Decoding] Add test * fix more * better * finish * Apply suggestions from code review Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * finish --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
f9b4bea0a6
commit
391d14e810
@@ -88,6 +88,11 @@ The original code can be found [here](https://github.com/openai/whisper).
|
||||
- forward
|
||||
- generate
|
||||
|
||||
## WhisperForCausalLM
|
||||
|
||||
[[autodoc]] WhisperForCausalLM
|
||||
- forward
|
||||
|
||||
## WhisperForAudioClassification
|
||||
|
||||
[[autodoc]] WhisperForAudioClassification
|
||||
|
||||
@@ -37,7 +37,7 @@ You can finetune other architectures for causal language modeling following the
|
||||
Choose one of the following architectures:
|
||||
|
||||
<!--This tip is automatically generated by `make fix-copies`, do not fill manually!-->
|
||||
[BART](../model_doc/bart), [BERT](../model_doc/bert), [Bert Generation](../model_doc/bert-generation), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BioGpt](../model_doc/biogpt), [Blenderbot](../model_doc/blenderbot), [BlenderbotSmall](../model_doc/blenderbot-small), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CodeLlama](../model_doc/code_llama), [CodeGen](../model_doc/codegen), [CPM-Ant](../model_doc/cpmant), [CTRL](../model_doc/ctrl), [Data2VecText](../model_doc/data2vec-text), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [Falcon](../model_doc/falcon), [Fuyu](../model_doc/fuyu), [GIT](../model_doc/git), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [GPT NeoX Japanese](../model_doc/gpt_neox_japanese), [GPT-J](../model_doc/gptj), [LLaMA](../model_doc/llama), [Marian](../model_doc/marian), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [Mistral](../model_doc/mistral), [MPT](../model_doc/mpt), [MusicGen](../model_doc/musicgen), [MVP](../model_doc/mvp), [OpenLlama](../model_doc/open-llama), [OpenAI GPT](../model_doc/openai-gpt), [OPT](../model_doc/opt), [Pegasus](../model_doc/pegasus), [Persimmon](../model_doc/persimmon), [PLBart](../model_doc/plbart), [ProphetNet](../model_doc/prophetnet), [QDQBert](../model_doc/qdqbert), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [RWKV](../model_doc/rwkv), [Speech2Text2](../model_doc/speech_to_text_2), [Transformer-XL](../model_doc/transfo-xl), [TrOCR](../model_doc/trocr), [XGLM](../model_doc/xglm), [XLM](../model_doc/xlm), [XLM-ProphetNet](../model_doc/xlm-prophetnet), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod)
|
||||
[BART](../model_doc/bart), [BERT](../model_doc/bert), [Bert Generation](../model_doc/bert-generation), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BioGpt](../model_doc/biogpt), [Blenderbot](../model_doc/blenderbot), [BlenderbotSmall](../model_doc/blenderbot-small), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CodeLlama](../model_doc/code_llama), [CodeGen](../model_doc/codegen), [CPM-Ant](../model_doc/cpmant), [CTRL](../model_doc/ctrl), [Data2VecText](../model_doc/data2vec-text), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [Falcon](../model_doc/falcon), [Fuyu](../model_doc/fuyu), [GIT](../model_doc/git), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [GPT NeoX Japanese](../model_doc/gpt_neox_japanese), [GPT-J](../model_doc/gptj), [LLaMA](../model_doc/llama), [Marian](../model_doc/marian), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [Mistral](../model_doc/mistral), [MPT](../model_doc/mpt), [MusicGen](../model_doc/musicgen), [MVP](../model_doc/mvp), [OpenLlama](../model_doc/open-llama), [OpenAI GPT](../model_doc/openai-gpt), [OPT](../model_doc/opt), [Pegasus](../model_doc/pegasus), [Persimmon](../model_doc/persimmon), [PLBart](../model_doc/plbart), [ProphetNet](../model_doc/prophetnet), [QDQBert](../model_doc/qdqbert), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [RWKV](../model_doc/rwkv), [Speech2Text2](../model_doc/speech_to_text_2), [Transformer-XL](../model_doc/transfo-xl), [TrOCR](../model_doc/trocr), [Whisper](../model_doc/whisper), [XGLM](../model_doc/xglm), [XLM](../model_doc/xlm), [XLM-ProphetNet](../model_doc/xlm-prophetnet), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod)
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -3090,6 +3090,7 @@ else:
|
||||
[
|
||||
"WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
"WhisperForAudioClassification",
|
||||
"WhisperForCausalLM",
|
||||
"WhisperForConditionalGeneration",
|
||||
"WhisperModel",
|
||||
"WhisperPreTrainedModel",
|
||||
@@ -6845,6 +6846,7 @@ if TYPE_CHECKING:
|
||||
from .models.whisper import (
|
||||
WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
WhisperForAudioClassification,
|
||||
WhisperForCausalLM,
|
||||
WhisperForConditionalGeneration,
|
||||
WhisperModel,
|
||||
WhisperPreTrainedModel,
|
||||
|
||||
@@ -1626,6 +1626,10 @@ class GenerationMixin:
|
||||
if not model_kwargs["use_cache"]:
|
||||
raise ValueError("assisted generate requires `use_cache=True`")
|
||||
|
||||
assistant_accepts_encoder_outputs = "encoder_outputs" in set(
|
||||
inspect.signature(assistant_model.forward).parameters.keys()
|
||||
)
|
||||
|
||||
# 11. If the assistant model is an encoder-decoder, prepare its encoder outputs
|
||||
if assistant_model.config.is_encoder_decoder and "assistant_encoder_outputs" not in model_kwargs:
|
||||
assistant_model_kwargs = copy.deepcopy(model_kwargs)
|
||||
@@ -1637,6 +1641,17 @@ class GenerationMixin:
|
||||
)
|
||||
model_kwargs["assistant_encoder_outputs"] = assistant_model_kwargs["encoder_outputs"]
|
||||
|
||||
if (
|
||||
not assistant_model.config.is_encoder_decoder
|
||||
and assistant_accepts_encoder_outputs
|
||||
and "encoder_outputs" in model_kwargs
|
||||
):
|
||||
# some assistants might be assymetric (many more enc layers than dec layers)
|
||||
# encoder-decoder models that share the exact same encoder as the teacher
|
||||
# in this case the assistant only needs to load the light-weight decoder,
|
||||
# but still requires `encoder_outputs` to be passed
|
||||
model_kwargs["assistant_encoder_outputs"] = model_kwargs["encoder_outputs"]
|
||||
|
||||
# 12. run assisted generate
|
||||
return self.assisted_decoding(
|
||||
input_ids,
|
||||
@@ -4368,6 +4383,11 @@ class GenerationMixin:
|
||||
else:
|
||||
num_assistant_tokens = assistant_model.generation_config.num_assistant_tokens
|
||||
|
||||
# check if assistant model accepts encoder_outputs
|
||||
assistant_accepts_encoder_outputs = "encoder_outputs" in set(
|
||||
inspect.signature(assistant_model.forward).parameters.keys()
|
||||
)
|
||||
|
||||
# init values
|
||||
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
|
||||
logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
|
||||
@@ -4454,9 +4474,13 @@ class GenerationMixin:
|
||||
encoder_outputs=model_kwargs["assistant_encoder_outputs"],
|
||||
)
|
||||
else:
|
||||
encoder_kwargs = {}
|
||||
|
||||
if assistant_accepts_encoder_outputs and "assistant_encoder_outputs" in model_kwargs:
|
||||
encoder_kwargs["encoder_outputs"] = model_kwargs["assistant_encoder_outputs"]
|
||||
|
||||
assistant_model_outputs = assistant_model(
|
||||
assist_inputs,
|
||||
past_key_values=model_kwargs["assistant_past_key_values"],
|
||||
assist_inputs, past_key_values=model_kwargs["assistant_past_key_values"], **encoder_kwargs
|
||||
)
|
||||
else:
|
||||
if assistant_model.config.is_encoder_decoder:
|
||||
@@ -4465,7 +4489,12 @@ class GenerationMixin:
|
||||
encoder_outputs=model_kwargs["assistant_encoder_outputs"],
|
||||
)
|
||||
else:
|
||||
assistant_model_outputs = assistant_model(candidate_input_ids)
|
||||
encoder_kwargs = {}
|
||||
|
||||
if assistant_accepts_encoder_outputs and "assistant_encoder_outputs" in model_kwargs:
|
||||
encoder_kwargs["encoder_outputs"] = model_kwargs["assistant_encoder_outputs"]
|
||||
|
||||
assistant_model_outputs = assistant_model(candidate_input_ids, **encoder_kwargs)
|
||||
|
||||
# 1.2. greedily select the next candidate token
|
||||
model_kwargs["assistant_past_key_values"] = assistant_model_outputs.past_key_values
|
||||
|
||||
@@ -438,6 +438,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
|
||||
("speech_to_text_2", "Speech2Text2ForCausalLM"),
|
||||
("transfo-xl", "TransfoXLLMHeadModel"),
|
||||
("trocr", "TrOCRForCausalLM"),
|
||||
("whisper", "WhisperForCausalLM"),
|
||||
("xglm", "XGLMForCausalLM"),
|
||||
("xlm", "XLMWithLMHeadModel"),
|
||||
("xlm-prophetnet", "XLMProphetNetForCausalLM"),
|
||||
|
||||
@@ -46,6 +46,7 @@ except OptionalDependencyNotAvailable:
|
||||
else:
|
||||
_import_structure["modeling_whisper"] = [
|
||||
"WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
"WhisperForCausalLM",
|
||||
"WhisperForConditionalGeneration",
|
||||
"WhisperModel",
|
||||
"WhisperPreTrainedModel",
|
||||
@@ -102,6 +103,7 @@ if TYPE_CHECKING:
|
||||
from .modeling_whisper import (
|
||||
WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
WhisperForAudioClassification,
|
||||
WhisperForCausalLM,
|
||||
WhisperForConditionalGeneration,
|
||||
WhisperModel,
|
||||
WhisperPreTrainedModel,
|
||||
|
||||
@@ -29,6 +29,7 @@ from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutput,
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
CausalLMOutputWithCrossAttentions,
|
||||
Seq2SeqLMOutput,
|
||||
Seq2SeqModelOutput,
|
||||
SequenceClassifierOutput,
|
||||
@@ -945,6 +946,8 @@ class WhisperDecoder(WhisperPreTrainedModel):
|
||||
config: WhisperConfig
|
||||
"""
|
||||
|
||||
main_input_name = "input_ids"
|
||||
|
||||
def __init__(self, config: WhisperConfig):
|
||||
super().__init__(config)
|
||||
self.dropout = config.dropout
|
||||
@@ -1028,7 +1031,8 @@ class WhisperDecoder(WhisperPreTrainedModel):
|
||||
|
||||
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
|
||||
that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
|
||||
all `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of
|
||||
all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
||||
inputs_embeds (`torch.FloatTensor` of
|
||||
shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing
|
||||
`input_ids` you can choose to directly pass an embedded representation. This is useful if you want more
|
||||
control over how to convert `input_ids` indices into associated vectors than the model's internal
|
||||
@@ -1811,6 +1815,247 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
|
||||
return timestamps
|
||||
|
||||
|
||||
class WhisperDecoderWrapper(WhisperPreTrainedModel):
|
||||
"""
|
||||
This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
|
||||
used in combination with the [`EncoderDecoderModel`] framework.
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
config.is_encoder_decoder = False
|
||||
self.decoder = WhisperDecoder(config)
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.decoder.embed_tokens
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.decoder.embed_tokens = value
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return self.decoder(*args, **kwargs)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
Whisper decoder with with a language modeling head on top (linear layer with weights tied to the input embeddings).
|
||||
""",
|
||||
WHISPER_START_DOCSTRING,
|
||||
)
|
||||
class WhisperForCausalLM(WhisperPreTrainedModel):
|
||||
_tied_weights_keys = ["proj_out.weight"]
|
||||
main_input_name = "input_ids"
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
config.is_encoder_decoder = False
|
||||
self.model = WhisperDecoderWrapper(config)
|
||||
|
||||
self.proj_out = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.proj_out
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.proj_out = new_embeddings
|
||||
|
||||
def get_input_embeddings(self) -> nn.Module:
|
||||
return self.model.get_input_embeddings()
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.model.set_input_embeddings(value)
|
||||
|
||||
def set_decoder(self, decoder):
|
||||
self.model.decoder = decoder
|
||||
|
||||
def get_decoder(self):
|
||||
return self.model.decoder
|
||||
|
||||
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
cross_attn_head_mask: Optional[torch.Tensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
|
||||
r"""
|
||||
Args:
|
||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
|
||||
provide it. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
||||
[`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids)
|
||||
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
[What are attention masks?](../glossary#attention-mask)
|
||||
encoder_outputs (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
||||
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
|
||||
if the model is configured as a decoder.
|
||||
head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
|
||||
Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
|
||||
Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
|
||||
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
|
||||
shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional
|
||||
tensors are only required when the model is used as a decoder in a Sequence to Sequence model. Contains
|
||||
pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
||||
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. If
|
||||
`past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
|
||||
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
|
||||
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
||||
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
||||
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
|
||||
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
||||
than the model's internal embedding lookup matrix.
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
use_cache (`bool`, *optional*):
|
||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
||||
(see `past_key_values`).
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||
returned tensors for more detail.
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
|
||||
for more detail.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import WhisperForCausalLM, WhisperForConditionalGeneration, WhisperProcessor
|
||||
>>> import torch
|
||||
>>> from datasets import load_dataset
|
||||
|
||||
>>> processor = WhisperProcessor.from_pretrained("openai/whisper-large-v2")
|
||||
>>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v2")
|
||||
|
||||
>>> assistant_model = WhisperForCausalLM.from_pretrained("distil-whisper/distil-large-v2")
|
||||
|
||||
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||
>>> sample = ds[0]["audio"]
|
||||
>>> input_features = processor(
|
||||
... sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt"
|
||||
... ).input_features
|
||||
|
||||
>>> predicted_ids = model.generate(input_features, assistant_model=assistant_model)
|
||||
|
||||
>>> # decode token ids to text
|
||||
>>> transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
|
||||
>>> transcription
|
||||
' Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.'
|
||||
```"""
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# If the user passed a tuple or `BaseModelOutput` for encoder_outputs, we extract only the hidden states
|
||||
if isinstance(encoder_outputs, (BaseModelOutput, tuple, list)):
|
||||
encoder_outputs = encoder_outputs[0]
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = self.model.decoder(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
encoder_hidden_states=encoder_outputs,
|
||||
head_mask=head_mask,
|
||||
cross_attn_head_mask=cross_attn_head_mask,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
logits = self.proj_out(outputs[0])
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
labels = labels.to(logits.device)
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithCrossAttentions(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
cross_attentions=outputs.cross_attentions,
|
||||
)
|
||||
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
input_ids,
|
||||
past_key_values=None,
|
||||
use_cache=None,
|
||||
encoder_outputs=None,
|
||||
attention_mask=None,
|
||||
**kwargs,
|
||||
):
|
||||
if past_key_values is not None:
|
||||
past_length = past_key_values[0][0].shape[2]
|
||||
|
||||
# Some generation methods already pass only the last input ID
|
||||
if input_ids.shape[1] > past_length:
|
||||
remove_prefix_length = past_length
|
||||
else:
|
||||
# Default to old behavior: keep only final ID
|
||||
remove_prefix_length = input_ids.shape[1] - 1
|
||||
|
||||
input_ids = input_ids[:, remove_prefix_length:]
|
||||
|
||||
return {
|
||||
"encoder_outputs": encoder_outputs,
|
||||
"past_key_values": past_key_values,
|
||||
"input_ids": input_ids,
|
||||
"use_cache": use_cache,
|
||||
"attention_mask": attention_mask,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _reorder_cache(past_key_values, beam_idx):
|
||||
reordered_past = ()
|
||||
for layer_past in past_key_values:
|
||||
reordered_past += (
|
||||
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
|
||||
)
|
||||
return reordered_past
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
Whisper Encoder Model with a sequence classification head on top (a linear layer over the pooled output) for tasks
|
||||
|
||||
@@ -8413,6 +8413,13 @@ class WhisperForAudioClassification(metaclass=DummyObject):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class WhisperForCausalLM(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class WhisperForConditionalGeneration(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
@@ -43,6 +43,7 @@ if is_torch_available():
|
||||
AutoModelForSpeechSeq2Seq,
|
||||
AutoModelForVision2Seq,
|
||||
AutoTokenizer,
|
||||
BartForCausalLM,
|
||||
BartForConditionalGeneration,
|
||||
BartTokenizer,
|
||||
GPT2LMHeadModel,
|
||||
@@ -3010,3 +3011,63 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
assistant_encoder_outputs=encoder_outputs,
|
||||
)
|
||||
self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist())
|
||||
|
||||
def test_assisted_decoding_encoder_decoder_shared_encoder(self):
|
||||
# PT-only test: TF doesn't support assisted decoding yet.
|
||||
# Bart subclass with a kwarg called foo that distorts the output
|
||||
class FakeBart(BartForConditionalGeneration):
|
||||
def forward(self, input_ids, foo=False, **kwargs):
|
||||
outs = super().forward(input_ids, **kwargs)
|
||||
|
||||
if foo:
|
||||
outs["logits"][:, :, :] = 0.0
|
||||
|
||||
return outs
|
||||
|
||||
def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None, **kwargs):
|
||||
kwargs["encoder_outputs"] = encoder_outputs
|
||||
inputs = super().prepare_inputs_for_generation(*args, **kwargs)
|
||||
|
||||
inputs["foo"] = foo
|
||||
return inputs
|
||||
|
||||
model = FakeBart.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration").to(
|
||||
torch_device
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration")
|
||||
|
||||
text = "Hello world"
|
||||
tokenized_inputs = tokenizer([text], return_tensors="pt")
|
||||
input_ids = tokenized_inputs.input_ids.to(torch_device)
|
||||
|
||||
# Traditional way of generating text
|
||||
outputs_normal = model.generate(input_ids)
|
||||
self.assertEqual(outputs_normal.shape, (1, 20))
|
||||
|
||||
# Should be different with foo
|
||||
outputs_foo = model.generate(input_ids, foo=True)
|
||||
with self.assertRaises(AssertionError):
|
||||
self.assertListEqual(outputs_foo.tolist(), outputs_normal.tolist())
|
||||
|
||||
# Assistant model
|
||||
assistant = BartForCausalLM.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration").to(
|
||||
torch_device
|
||||
)
|
||||
|
||||
# If assisted generation passes model_kwargs correctly, should be same as previous
|
||||
outputs_assisted = model.generate(
|
||||
input_ids,
|
||||
foo=True,
|
||||
assistant_model=assistant,
|
||||
)
|
||||
self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist())
|
||||
|
||||
# Check that passing encoder_outputs directly also works as expected
|
||||
encoder_outputs = model.get_encoder()(input_ids)
|
||||
|
||||
outputs_assisted = model.generate(
|
||||
foo=True,
|
||||
assistant_model=assistant,
|
||||
encoder_outputs=encoder_outputs,
|
||||
)
|
||||
self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist())
|
||||
|
||||
@@ -51,6 +51,7 @@ if is_torch_available():
|
||||
from transformers import (
|
||||
WhisperFeatureExtractor,
|
||||
WhisperForAudioClassification,
|
||||
WhisperForCausalLM,
|
||||
WhisperForConditionalGeneration,
|
||||
WhisperModel,
|
||||
WhisperProcessor,
|
||||
@@ -1990,3 +1991,246 @@ class WhisperEncoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.
|
||||
|
||||
self.assertEqual(fx_keys, pt_keys)
|
||||
self.check_pt_flax_outputs(fx_outputs, pt_outputs_loaded, model_class)
|
||||
|
||||
|
||||
class WhisperStandaloneDecoderModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=2,
|
||||
is_training=True,
|
||||
use_labels=False,
|
||||
vocab_size=200,
|
||||
hidden_size=16,
|
||||
num_hidden_layers=2,
|
||||
num_attention_heads=4,
|
||||
input_channels=1,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=20,
|
||||
max_source_positions=30,
|
||||
max_target_positions=40,
|
||||
bos_token_id=98,
|
||||
eos_token_id=98,
|
||||
pad_token_id=0,
|
||||
num_mel_bins=80,
|
||||
decoder_start_token_id=85,
|
||||
num_conv_layers=1,
|
||||
suppress_tokens=None,
|
||||
begin_suppress_tokens=None,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.is_training = is_training
|
||||
self.use_labels = use_labels
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.input_channels = input_channels
|
||||
self.hidden_act = hidden_act
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
self.num_mel_bins = num_mel_bins
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.max_source_positions = max_source_positions
|
||||
self.max_target_positions = max_target_positions
|
||||
self.eos_token_id = eos_token_id
|
||||
self.pad_token_id = pad_token_id
|
||||
self.bos_token_id = bos_token_id
|
||||
self.decoder_start_token_id = decoder_start_token_id
|
||||
self.num_conv_layers = num_conv_layers
|
||||
self.suppress_tokens = suppress_tokens
|
||||
self.begin_suppress_tokens = begin_suppress_tokens
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_features = floats_tensor([self.batch_size, self.num_mel_bins, self.seq_length], self.vocab_size)
|
||||
|
||||
decoder_input_ids = torch.tensor(
|
||||
self.batch_size * [[self.decoder_start_token_id, 3, 3, 7, 2]], device=torch_device
|
||||
)
|
||||
|
||||
config = self.get_config()
|
||||
config.is_encoder_decoder = False
|
||||
inputs_dict = prepare_whisper_inputs_dict(
|
||||
config,
|
||||
attention_mask=None,
|
||||
input_features=input_features,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
)
|
||||
|
||||
inputs_dict.pop("input_features")
|
||||
inputs_dict.pop("head_mask")
|
||||
inputs_dict.pop("decoder_head_mask")
|
||||
inputs_dict.pop("cross_attn_head_mask")
|
||||
|
||||
inputs_dict["attention_mask"] = inputs_dict.pop("decoder_attention_mask")
|
||||
inputs_dict["input_ids"] = inputs_dict.pop("decoder_input_ids")
|
||||
return config, inputs_dict
|
||||
|
||||
@property
|
||||
def encoder_seq_length(self):
|
||||
return 5
|
||||
|
||||
@property
|
||||
def seq_length(self):
|
||||
return 5
|
||||
|
||||
def get_config(self):
|
||||
return WhisperConfig(
|
||||
vocab_size=self.vocab_size,
|
||||
d_model=self.hidden_size,
|
||||
encoder_layers=self.num_hidden_layers,
|
||||
decoder_layers=self.num_hidden_layers,
|
||||
encoder_attention_heads=self.num_attention_heads,
|
||||
decoder_attention_heads=self.num_attention_heads,
|
||||
input_channels=self.input_channels,
|
||||
dropout=self.hidden_dropout_prob,
|
||||
attention_dropout=self.attention_probs_dropout_prob,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
max_source_positions=self.max_source_positions,
|
||||
max_target_positions=self.max_target_positions,
|
||||
eos_token_id=self.eos_token_id,
|
||||
bos_token_id=self.bos_token_id,
|
||||
pad_token_id=self.pad_token_id,
|
||||
decoder_ffn_dim=self.hidden_size,
|
||||
encoder_ffn_dim=self.hidden_size,
|
||||
decoder_start_token_id=self.decoder_start_token_id,
|
||||
suppress_tokens=self.suppress_tokens,
|
||||
begin_suppress_tokens=self.begin_suppress_tokens,
|
||||
)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config, inputs_dict = self.prepare_config_and_inputs()
|
||||
|
||||
inputs_dict["input_ids"][:, -1] = self.pad_token_id
|
||||
|
||||
return config, inputs_dict
|
||||
|
||||
def prepare_config_and_inputs_for_decoder(self):
|
||||
config, input_features = self.prepare_config_and_inputs()
|
||||
input_ids = input_features["input_ids"]
|
||||
encoder_hidden_states = floats_tensor([self.batch_size, self.decoder_seq_length, self.hidden_size])
|
||||
|
||||
return (config, input_ids, encoder_hidden_states)
|
||||
|
||||
def create_and_check_decoder_model_past(self, config, input_ids):
|
||||
config.use_cache = True
|
||||
model = WhisperDecoder(config=config).to(torch_device).eval()
|
||||
# first forward pass
|
||||
outputs = model(input_ids, use_cache=True)
|
||||
outputs_use_cache_conf = model(input_ids)
|
||||
outputs_no_past = model(input_ids, use_cache=False)
|
||||
|
||||
self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
|
||||
self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)
|
||||
|
||||
past_key_values = outputs["past_key_values"]
|
||||
|
||||
# create hypothetical next token and extent to next_input_ids
|
||||
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
|
||||
|
||||
# append to next input_ids and
|
||||
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
|
||||
|
||||
output_from_no_past = model(next_input_ids)["last_hidden_state"]
|
||||
output_from_past = model(next_tokens, past_key_values=past_key_values)["last_hidden_state"]
|
||||
|
||||
# select random slice
|
||||
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
|
||||
output_from_no_past_slice = output_from_no_past[:, next_input_ids.shape[-1] - 1, random_slice_idx].detach()
|
||||
output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
|
||||
|
||||
# test that outputs are equal for slice
|
||||
assert torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)
|
||||
|
||||
def create_and_check_decoder_model_attention_mask_past(self, config, input_ids):
|
||||
model = WhisperDecoder(config=config).to(torch_device).eval()
|
||||
|
||||
# create attention mask
|
||||
attn_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
|
||||
|
||||
half_seq_length = input_ids.shape[-1] // 2
|
||||
attn_mask[:, half_seq_length:] = 0
|
||||
|
||||
# first forward pass
|
||||
past_key_values = model(input_ids, attention_mask=attn_mask, use_cache=True)["past_key_values"]
|
||||
|
||||
# create hypothetical next token and extent to next_input_ids
|
||||
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
|
||||
|
||||
# change a random masked slice from input_ids
|
||||
random_seq_idx_to_change = ids_tensor((1,), half_seq_length).item() + 1
|
||||
random_other_next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size).squeeze(-1)
|
||||
input_ids[:, -random_seq_idx_to_change] = random_other_next_tokens
|
||||
|
||||
# append to next input_ids and attn_mask
|
||||
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
|
||||
attn_mask = torch.cat(
|
||||
[attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)],
|
||||
dim=1,
|
||||
)
|
||||
|
||||
# get two different outputs
|
||||
output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"]
|
||||
output_from_past = model(next_tokens, attention_mask=attn_mask, past_key_values=past_key_values)[
|
||||
"last_hidden_state"
|
||||
]
|
||||
|
||||
# select random slice
|
||||
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
|
||||
output_from_no_past_slice = output_from_no_past[:, next_input_ids.shape[-1] - 1, random_slice_idx].detach()
|
||||
output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
|
||||
|
||||
# test that outputs are equal for slice
|
||||
assert torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)
|
||||
|
||||
|
||||
@require_torch
|
||||
class WhisperStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (WhisperDecoder, WhisperForCausalLM) if is_torch_available() else ()
|
||||
all_generative_model_classes = (WhisperForCausalLM,) if is_torch_available() else ()
|
||||
fx_comptatible = False
|
||||
test_pruning = False
|
||||
is_encoder_decoder = False
|
||||
test_missing_keys = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = WhisperStandaloneDecoderModelTester(self, is_training=False)
|
||||
self.config_tester = ConfigTester(self, config_class=WhisperConfig)
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
def test_decoder_model_past(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
config, inputs_dict = config_and_inputs
|
||||
|
||||
self.model_tester.create_and_check_decoder_model_past(config=config, input_ids=inputs_dict["input_ids"])
|
||||
|
||||
def test_decoder_model_attn_mask_past(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
config, inputs_dict = config_and_inputs
|
||||
|
||||
self.model_tester.create_and_check_decoder_model_attention_mask_past(
|
||||
config=config, input_ids=inputs_dict["input_ids"]
|
||||
)
|
||||
|
||||
@unittest.skip("Generate needs input ids")
|
||||
def test_generate_without_input_ids(self):
|
||||
# generate only works with input ids for whisper
|
||||
pass
|
||||
|
||||
@unittest.skip("Decoder can't keep attention grads")
|
||||
def test_retain_grad_hidden_states_attentions(self):
|
||||
# decoder cannot keep gradients
|
||||
return
|
||||
|
||||
@unittest.skip("The model doesn't support fast init from base")
|
||||
def test_save_load_fast_init_from_base(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("The model doesn't support left padding") # and it's not used enough to be worth fixing :)
|
||||
def test_left_padding_compatibility(self):
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user