[WhisperForCausalLM] Add WhisperForCausalLM for speculative decoding (#27195)
* finish * add tests * fix all tests * [Assistant Decoding] Add test * fix more * better * finish * Apply suggestions from code review Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * finish --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
f9b4bea0a6
commit
391d14e810
@@ -88,6 +88,11 @@ The original code can be found [here](https://github.com/openai/whisper).
|
|||||||
- forward
|
- forward
|
||||||
- generate
|
- generate
|
||||||
|
|
||||||
|
## WhisperForCausalLM
|
||||||
|
|
||||||
|
[[autodoc]] WhisperForCausalLM
|
||||||
|
- forward
|
||||||
|
|
||||||
## WhisperForAudioClassification
|
## WhisperForAudioClassification
|
||||||
|
|
||||||
[[autodoc]] WhisperForAudioClassification
|
[[autodoc]] WhisperForAudioClassification
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ You can finetune other architectures for causal language modeling following the
|
|||||||
Choose one of the following architectures:
|
Choose one of the following architectures:
|
||||||
|
|
||||||
<!--This tip is automatically generated by `make fix-copies`, do not fill manually!-->
|
<!--This tip is automatically generated by `make fix-copies`, do not fill manually!-->
|
||||||
[BART](../model_doc/bart), [BERT](../model_doc/bert), [Bert Generation](../model_doc/bert-generation), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BioGpt](../model_doc/biogpt), [Blenderbot](../model_doc/blenderbot), [BlenderbotSmall](../model_doc/blenderbot-small), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CodeLlama](../model_doc/code_llama), [CodeGen](../model_doc/codegen), [CPM-Ant](../model_doc/cpmant), [CTRL](../model_doc/ctrl), [Data2VecText](../model_doc/data2vec-text), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [Falcon](../model_doc/falcon), [Fuyu](../model_doc/fuyu), [GIT](../model_doc/git), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [GPT NeoX Japanese](../model_doc/gpt_neox_japanese), [GPT-J](../model_doc/gptj), [LLaMA](../model_doc/llama), [Marian](../model_doc/marian), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [Mistral](../model_doc/mistral), [MPT](../model_doc/mpt), [MusicGen](../model_doc/musicgen), [MVP](../model_doc/mvp), [OpenLlama](../model_doc/open-llama), [OpenAI GPT](../model_doc/openai-gpt), [OPT](../model_doc/opt), [Pegasus](../model_doc/pegasus), [Persimmon](../model_doc/persimmon), [PLBart](../model_doc/plbart), [ProphetNet](../model_doc/prophetnet), [QDQBert](../model_doc/qdqbert), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [RWKV](../model_doc/rwkv), [Speech2Text2](../model_doc/speech_to_text_2), [Transformer-XL](../model_doc/transfo-xl), [TrOCR](../model_doc/trocr), [XGLM](../model_doc/xglm), [XLM](../model_doc/xlm), [XLM-ProphetNet](../model_doc/xlm-prophetnet), 
[XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod)
|
[BART](../model_doc/bart), [BERT](../model_doc/bert), [Bert Generation](../model_doc/bert-generation), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BioGpt](../model_doc/biogpt), [Blenderbot](../model_doc/blenderbot), [BlenderbotSmall](../model_doc/blenderbot-small), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CodeLlama](../model_doc/code_llama), [CodeGen](../model_doc/codegen), [CPM-Ant](../model_doc/cpmant), [CTRL](../model_doc/ctrl), [Data2VecText](../model_doc/data2vec-text), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [Falcon](../model_doc/falcon), [Fuyu](../model_doc/fuyu), [GIT](../model_doc/git), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [GPT NeoX Japanese](../model_doc/gpt_neox_japanese), [GPT-J](../model_doc/gptj), [LLaMA](../model_doc/llama), [Marian](../model_doc/marian), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [Mistral](../model_doc/mistral), [MPT](../model_doc/mpt), [MusicGen](../model_doc/musicgen), [MVP](../model_doc/mvp), [OpenLlama](../model_doc/open-llama), [OpenAI GPT](../model_doc/openai-gpt), [OPT](../model_doc/opt), [Pegasus](../model_doc/pegasus), [Persimmon](../model_doc/persimmon), [PLBart](../model_doc/plbart), [ProphetNet](../model_doc/prophetnet), [QDQBert](../model_doc/qdqbert), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [RWKV](../model_doc/rwkv), [Speech2Text2](../model_doc/speech_to_text_2), [Transformer-XL](../model_doc/transfo-xl), [TrOCR](../model_doc/trocr), [Whisper](../model_doc/whisper), [XGLM](../model_doc/xglm), [XLM](../model_doc/xlm), 
[XLM-ProphetNet](../model_doc/xlm-prophetnet), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -3090,6 +3090,7 @@ else:
|
|||||||
[
|
[
|
||||||
"WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST",
|
"WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||||
"WhisperForAudioClassification",
|
"WhisperForAudioClassification",
|
||||||
|
"WhisperForCausalLM",
|
||||||
"WhisperForConditionalGeneration",
|
"WhisperForConditionalGeneration",
|
||||||
"WhisperModel",
|
"WhisperModel",
|
||||||
"WhisperPreTrainedModel",
|
"WhisperPreTrainedModel",
|
||||||
@@ -6845,6 +6846,7 @@ if TYPE_CHECKING:
|
|||||||
from .models.whisper import (
|
from .models.whisper import (
|
||||||
WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST,
|
WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
WhisperForAudioClassification,
|
WhisperForAudioClassification,
|
||||||
|
WhisperForCausalLM,
|
||||||
WhisperForConditionalGeneration,
|
WhisperForConditionalGeneration,
|
||||||
WhisperModel,
|
WhisperModel,
|
||||||
WhisperPreTrainedModel,
|
WhisperPreTrainedModel,
|
||||||
|
|||||||
@@ -1626,6 +1626,10 @@ class GenerationMixin:
|
|||||||
if not model_kwargs["use_cache"]:
|
if not model_kwargs["use_cache"]:
|
||||||
raise ValueError("assisted generate requires `use_cache=True`")
|
raise ValueError("assisted generate requires `use_cache=True`")
|
||||||
|
|
||||||
|
assistant_accepts_encoder_outputs = "encoder_outputs" in set(
|
||||||
|
inspect.signature(assistant_model.forward).parameters.keys()
|
||||||
|
)
|
||||||
|
|
||||||
# 11. If the assistant model is an encoder-decoder, prepare its encoder outputs
|
# 11. If the assistant model is an encoder-decoder, prepare its encoder outputs
|
||||||
if assistant_model.config.is_encoder_decoder and "assistant_encoder_outputs" not in model_kwargs:
|
if assistant_model.config.is_encoder_decoder and "assistant_encoder_outputs" not in model_kwargs:
|
||||||
assistant_model_kwargs = copy.deepcopy(model_kwargs)
|
assistant_model_kwargs = copy.deepcopy(model_kwargs)
|
||||||
@@ -1637,6 +1641,17 @@ class GenerationMixin:
|
|||||||
)
|
)
|
||||||
model_kwargs["assistant_encoder_outputs"] = assistant_model_kwargs["encoder_outputs"]
|
model_kwargs["assistant_encoder_outputs"] = assistant_model_kwargs["encoder_outputs"]
|
||||||
|
|
||||||
|
if (
|
||||||
|
not assistant_model.config.is_encoder_decoder
|
||||||
|
and assistant_accepts_encoder_outputs
|
||||||
|
and "encoder_outputs" in model_kwargs
|
||||||
|
):
|
||||||
|
# some assistants might be asymmetric (many more enc layers than dec layers)
|
||||||
|
# encoder-decoder models that share the exact same encoder as the teacher
|
||||||
|
# in this case the assistant only needs to load the light-weight decoder,
|
||||||
|
# but still requires `encoder_outputs` to be passed
|
||||||
|
model_kwargs["assistant_encoder_outputs"] = model_kwargs["encoder_outputs"]
|
||||||
|
|
||||||
# 12. run assisted generate
|
# 12. run assisted generate
|
||||||
return self.assisted_decoding(
|
return self.assisted_decoding(
|
||||||
input_ids,
|
input_ids,
|
||||||
@@ -4368,6 +4383,11 @@ class GenerationMixin:
|
|||||||
else:
|
else:
|
||||||
num_assistant_tokens = assistant_model.generation_config.num_assistant_tokens
|
num_assistant_tokens = assistant_model.generation_config.num_assistant_tokens
|
||||||
|
|
||||||
|
# check if assistant model accepts encoder_outputs
|
||||||
|
assistant_accepts_encoder_outputs = "encoder_outputs" in set(
|
||||||
|
inspect.signature(assistant_model.forward).parameters.keys()
|
||||||
|
)
|
||||||
|
|
||||||
# init values
|
# init values
|
||||||
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
|
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
|
||||||
logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
|
logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
|
||||||
@@ -4454,9 +4474,13 @@ class GenerationMixin:
|
|||||||
encoder_outputs=model_kwargs["assistant_encoder_outputs"],
|
encoder_outputs=model_kwargs["assistant_encoder_outputs"],
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
encoder_kwargs = {}
|
||||||
|
|
||||||
|
if assistant_accepts_encoder_outputs and "assistant_encoder_outputs" in model_kwargs:
|
||||||
|
encoder_kwargs["encoder_outputs"] = model_kwargs["assistant_encoder_outputs"]
|
||||||
|
|
||||||
assistant_model_outputs = assistant_model(
|
assistant_model_outputs = assistant_model(
|
||||||
assist_inputs,
|
assist_inputs, past_key_values=model_kwargs["assistant_past_key_values"], **encoder_kwargs
|
||||||
past_key_values=model_kwargs["assistant_past_key_values"],
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if assistant_model.config.is_encoder_decoder:
|
if assistant_model.config.is_encoder_decoder:
|
||||||
@@ -4465,7 +4489,12 @@ class GenerationMixin:
|
|||||||
encoder_outputs=model_kwargs["assistant_encoder_outputs"],
|
encoder_outputs=model_kwargs["assistant_encoder_outputs"],
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
assistant_model_outputs = assistant_model(candidate_input_ids)
|
encoder_kwargs = {}
|
||||||
|
|
||||||
|
if assistant_accepts_encoder_outputs and "assistant_encoder_outputs" in model_kwargs:
|
||||||
|
encoder_kwargs["encoder_outputs"] = model_kwargs["assistant_encoder_outputs"]
|
||||||
|
|
||||||
|
assistant_model_outputs = assistant_model(candidate_input_ids, **encoder_kwargs)
|
||||||
|
|
||||||
# 1.2. greedily select the next candidate token
|
# 1.2. greedily select the next candidate token
|
||||||
model_kwargs["assistant_past_key_values"] = assistant_model_outputs.past_key_values
|
model_kwargs["assistant_past_key_values"] = assistant_model_outputs.past_key_values
|
||||||
|
|||||||
@@ -438,6 +438,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
|
|||||||
("speech_to_text_2", "Speech2Text2ForCausalLM"),
|
("speech_to_text_2", "Speech2Text2ForCausalLM"),
|
||||||
("transfo-xl", "TransfoXLLMHeadModel"),
|
("transfo-xl", "TransfoXLLMHeadModel"),
|
||||||
("trocr", "TrOCRForCausalLM"),
|
("trocr", "TrOCRForCausalLM"),
|
||||||
|
("whisper", "WhisperForCausalLM"),
|
||||||
("xglm", "XGLMForCausalLM"),
|
("xglm", "XGLMForCausalLM"),
|
||||||
("xlm", "XLMWithLMHeadModel"),
|
("xlm", "XLMWithLMHeadModel"),
|
||||||
("xlm-prophetnet", "XLMProphetNetForCausalLM"),
|
("xlm-prophetnet", "XLMProphetNetForCausalLM"),
|
||||||
|
|||||||
@@ -46,6 +46,7 @@ except OptionalDependencyNotAvailable:
|
|||||||
else:
|
else:
|
||||||
_import_structure["modeling_whisper"] = [
|
_import_structure["modeling_whisper"] = [
|
||||||
"WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST",
|
"WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||||
|
"WhisperForCausalLM",
|
||||||
"WhisperForConditionalGeneration",
|
"WhisperForConditionalGeneration",
|
||||||
"WhisperModel",
|
"WhisperModel",
|
||||||
"WhisperPreTrainedModel",
|
"WhisperPreTrainedModel",
|
||||||
@@ -102,6 +103,7 @@ if TYPE_CHECKING:
|
|||||||
from .modeling_whisper import (
|
from .modeling_whisper import (
|
||||||
WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST,
|
WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
WhisperForAudioClassification,
|
WhisperForAudioClassification,
|
||||||
|
WhisperForCausalLM,
|
||||||
WhisperForConditionalGeneration,
|
WhisperForConditionalGeneration,
|
||||||
WhisperModel,
|
WhisperModel,
|
||||||
WhisperPreTrainedModel,
|
WhisperPreTrainedModel,
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
|
|||||||
from ...modeling_outputs import (
|
from ...modeling_outputs import (
|
||||||
BaseModelOutput,
|
BaseModelOutput,
|
||||||
BaseModelOutputWithPastAndCrossAttentions,
|
BaseModelOutputWithPastAndCrossAttentions,
|
||||||
|
CausalLMOutputWithCrossAttentions,
|
||||||
Seq2SeqLMOutput,
|
Seq2SeqLMOutput,
|
||||||
Seq2SeqModelOutput,
|
Seq2SeqModelOutput,
|
||||||
SequenceClassifierOutput,
|
SequenceClassifierOutput,
|
||||||
@@ -945,6 +946,8 @@ class WhisperDecoder(WhisperPreTrainedModel):
|
|||||||
config: WhisperConfig
|
config: WhisperConfig
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
main_input_name = "input_ids"
|
||||||
|
|
||||||
def __init__(self, config: WhisperConfig):
|
def __init__(self, config: WhisperConfig):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.dropout = config.dropout
|
self.dropout = config.dropout
|
||||||
@@ -1028,7 +1031,8 @@ class WhisperDecoder(WhisperPreTrainedModel):
|
|||||||
|
|
||||||
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
|
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
|
||||||
that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
|
that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
|
||||||
all `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of
|
all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
||||||
|
inputs_embeds (`torch.FloatTensor` of
|
||||||
shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing
|
shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing
|
||||||
`input_ids` you can choose to directly pass an embedded representation. This is useful if you want more
|
`input_ids` you can choose to directly pass an embedded representation. This is useful if you want more
|
||||||
control over how to convert `input_ids` indices into associated vectors than the model's internal
|
control over how to convert `input_ids` indices into associated vectors than the model's internal
|
||||||
@@ -1811,6 +1815,247 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
|
|||||||
return timestamps
|
return timestamps
|
||||||
|
|
||||||
|
|
||||||
|
class WhisperDecoderWrapper(WhisperPreTrainedModel):
|
||||||
|
"""
|
||||||
|
This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
|
||||||
|
used in combination with the [`EncoderDecoderModel`] framework.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
config.is_encoder_decoder = False
|
||||||
|
self.decoder = WhisperDecoder(config)
|
||||||
|
|
||||||
|
def get_input_embeddings(self):
|
||||||
|
return self.decoder.embed_tokens
|
||||||
|
|
||||||
|
def set_input_embeddings(self, value):
|
||||||
|
self.decoder.embed_tokens = value
|
||||||
|
|
||||||
|
def forward(self, *args, **kwargs):
|
||||||
|
return self.decoder(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings(
|
||||||
|
"""
|
||||||
|
Whisper decoder with a language modeling head on top (linear layer with weights tied to the input embeddings).
|
||||||
|
""",
|
||||||
|
WHISPER_START_DOCSTRING,
|
||||||
|
)
|
||||||
|
class WhisperForCausalLM(WhisperPreTrainedModel):
|
||||||
|
_tied_weights_keys = ["proj_out.weight"]
|
||||||
|
main_input_name = "input_ids"
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
config.is_encoder_decoder = False
|
||||||
|
self.model = WhisperDecoderWrapper(config)
|
||||||
|
|
||||||
|
self.proj_out = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||||
|
|
||||||
|
# Initialize weights and apply final processing
|
||||||
|
self.post_init()
|
||||||
|
|
||||||
|
def get_output_embeddings(self):
|
||||||
|
return self.proj_out
|
||||||
|
|
||||||
|
def set_output_embeddings(self, new_embeddings):
|
||||||
|
self.proj_out = new_embeddings
|
||||||
|
|
||||||
|
def get_input_embeddings(self) -> nn.Module:
|
||||||
|
return self.model.get_input_embeddings()
|
||||||
|
|
||||||
|
def set_input_embeddings(self, value):
|
||||||
|
self.model.set_input_embeddings(value)
|
||||||
|
|
||||||
|
def set_decoder(self, decoder):
|
||||||
|
self.model.decoder = decoder
|
||||||
|
|
||||||
|
def get_decoder(self):
|
||||||
|
return self.model.decoder
|
||||||
|
|
||||||
|
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
|
||||||
|
head_mask: Optional[torch.Tensor] = None,
|
||||||
|
cross_attn_head_mask: Optional[torch.Tensor] = None,
|
||||||
|
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
labels: Optional[torch.LongTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
|
||||||
|
r"""
|
||||||
|
Args:
|
||||||
|
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||||
|
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
|
||||||
|
provide it. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
||||||
|
[`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids)
|
||||||
|
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||||
|
- 1 for tokens that are **not masked**,
|
||||||
|
- 0 for tokens that are **masked**.
|
||||||
|
[What are attention masks?](../glossary#attention-mask)
|
||||||
|
encoder_outputs (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
||||||
|
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
|
||||||
|
if the model is configured as a decoder.
|
||||||
|
head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
|
||||||
|
Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the head is **masked**.
|
||||||
|
cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
|
||||||
|
Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the head is **masked**.
|
||||||
|
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||||
|
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
|
||||||
|
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
|
||||||
|
shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional
|
||||||
|
tensors are only required when the model is used as a decoder in a Sequence to Sequence model. Contains
|
||||||
|
pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
||||||
|
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. If
|
||||||
|
`past_key_values` are used, the user can optionally input only the last `input_ids` (those that
|
||||||
|
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
|
||||||
|
`input_ids` of shape `(batch_size, sequence_length)`.
|
||||||
|
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
||||||
|
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
|
||||||
|
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
||||||
|
than the model's internal embedding lookup matrix.
|
||||||
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Labels for computing the language modeling loss. Indices should either be in `[0, ...,
|
||||||
|
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||||
|
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||||
|
use_cache (`bool`, *optional*):
|
||||||
|
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
||||||
|
(see `past_key_values`).
|
||||||
|
- 1 for tokens that are **not masked**,
|
||||||
|
- 0 for tokens that are **masked**.
|
||||||
|
output_attentions (`bool`, *optional*):
|
||||||
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||||
|
returned tensors for more detail.
|
||||||
|
output_hidden_states (`bool`, *optional*):
|
||||||
|
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
|
||||||
|
for more detail.
|
||||||
|
return_dict (`bool`, *optional*):
|
||||||
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import WhisperForCausalLM, WhisperForConditionalGeneration, WhisperProcessor
|
||||||
|
>>> import torch
|
||||||
|
>>> from datasets import load_dataset
|
||||||
|
|
||||||
|
>>> processor = WhisperProcessor.from_pretrained("openai/whisper-large-v2")
|
||||||
|
>>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v2")
|
||||||
|
|
||||||
|
>>> assistant_model = WhisperForCausalLM.from_pretrained("distil-whisper/distil-large-v2")
|
||||||
|
|
||||||
|
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||||
|
>>> sample = ds[0]["audio"]
|
||||||
|
>>> input_features = processor(
|
||||||
|
... sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt"
|
||||||
|
... ).input_features
|
||||||
|
|
||||||
|
>>> predicted_ids = model.generate(input_features, assistant_model=assistant_model)
|
||||||
|
|
||||||
|
>>> # decode token ids to text
|
||||||
|
>>> transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
|
||||||
|
>>> transcription
|
||||||
|
' Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.'
|
||||||
|
```"""
|
||||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
# If the user passed a tuple or `BaseModelOutput` for encoder_outputs, we extract only the hidden states
|
||||||
|
if isinstance(encoder_outputs, (BaseModelOutput, tuple, list)):
|
||||||
|
encoder_outputs = encoder_outputs[0]
|
||||||
|
|
||||||
|
# decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||||
|
outputs = self.model.decoder(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
encoder_hidden_states=encoder_outputs,
|
||||||
|
head_mask=head_mask,
|
||||||
|
cross_attn_head_mask=cross_attn_head_mask,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
logits = self.proj_out(outputs[0])
|
||||||
|
|
||||||
|
loss = None
|
||||||
|
if labels is not None:
|
||||||
|
labels = labels.to(logits.device)
|
||||||
|
loss_fct = CrossEntropyLoss()
|
||||||
|
loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
output = (logits,) + outputs[1:]
|
||||||
|
return (loss,) + output if loss is not None else output
|
||||||
|
|
||||||
|
return CausalLMOutputWithCrossAttentions(
|
||||||
|
loss=loss,
|
||||||
|
logits=logits,
|
||||||
|
past_key_values=outputs.past_key_values,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
cross_attentions=outputs.cross_attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
def prepare_inputs_for_generation(
|
||||||
|
self,
|
||||||
|
input_ids,
|
||||||
|
past_key_values=None,
|
||||||
|
use_cache=None,
|
||||||
|
encoder_outputs=None,
|
||||||
|
attention_mask=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
if past_key_values is not None:
|
||||||
|
past_length = past_key_values[0][0].shape[2]
|
||||||
|
|
||||||
|
# Some generation methods already pass only the last input ID
|
||||||
|
if input_ids.shape[1] > past_length:
|
||||||
|
remove_prefix_length = past_length
|
||||||
|
else:
|
||||||
|
# Default to old behavior: keep only final ID
|
||||||
|
remove_prefix_length = input_ids.shape[1] - 1
|
||||||
|
|
||||||
|
input_ids = input_ids[:, remove_prefix_length:]
|
||||||
|
|
||||||
|
return {
|
||||||
|
"encoder_outputs": encoder_outputs,
|
||||||
|
"past_key_values": past_key_values,
|
||||||
|
"input_ids": input_ids,
|
||||||
|
"use_cache": use_cache,
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _reorder_cache(past_key_values, beam_idx):
|
||||||
|
reordered_past = ()
|
||||||
|
for layer_past in past_key_values:
|
||||||
|
reordered_past += (
|
||||||
|
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
|
||||||
|
)
|
||||||
|
return reordered_past
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
|
@add_start_docstrings(
|
||||||
"""
|
"""
|
||||||
Whisper Encoder Model with a sequence classification head on top (a linear layer over the pooled output) for tasks
|
Whisper Encoder Model with a sequence classification head on top (a linear layer over the pooled output) for tasks
|
||||||
|
|||||||
@@ -8413,6 +8413,13 @@ class WhisperForAudioClassification(metaclass=DummyObject):
|
|||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class WhisperForCausalLM(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
class WhisperForConditionalGeneration(metaclass=DummyObject):
|
class WhisperForConditionalGeneration(metaclass=DummyObject):
|
||||||
_backends = ["torch"]
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
|||||||
@@ -43,6 +43,7 @@ if is_torch_available():
|
|||||||
AutoModelForSpeechSeq2Seq,
|
AutoModelForSpeechSeq2Seq,
|
||||||
AutoModelForVision2Seq,
|
AutoModelForVision2Seq,
|
||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
|
BartForCausalLM,
|
||||||
BartForConditionalGeneration,
|
BartForConditionalGeneration,
|
||||||
BartTokenizer,
|
BartTokenizer,
|
||||||
GPT2LMHeadModel,
|
GPT2LMHeadModel,
|
||||||
@@ -3010,3 +3011,63 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
|||||||
assistant_encoder_outputs=encoder_outputs,
|
assistant_encoder_outputs=encoder_outputs,
|
||||||
)
|
)
|
||||||
self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist())
|
self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist())
|
||||||
|
|
||||||
|
def test_assisted_decoding_encoder_decoder_shared_encoder(self):
|
||||||
|
# PT-only test: TF doesn't support assisted decoding yet.
|
||||||
|
# Bart subclass with a kwarg called foo that distorts the output
|
||||||
|
class FakeBart(BartForConditionalGeneration):
|
||||||
|
def forward(self, input_ids, foo=False, **kwargs):
|
||||||
|
outs = super().forward(input_ids, **kwargs)
|
||||||
|
|
||||||
|
if foo:
|
||||||
|
outs["logits"][:, :, :] = 0.0
|
||||||
|
|
||||||
|
return outs
|
||||||
|
|
||||||
|
def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None, **kwargs):
|
||||||
|
kwargs["encoder_outputs"] = encoder_outputs
|
||||||
|
inputs = super().prepare_inputs_for_generation(*args, **kwargs)
|
||||||
|
|
||||||
|
inputs["foo"] = foo
|
||||||
|
return inputs
|
||||||
|
|
||||||
|
model = FakeBart.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration").to(
|
||||||
|
torch_device
|
||||||
|
)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration")
|
||||||
|
|
||||||
|
text = "Hello world"
|
||||||
|
tokenized_inputs = tokenizer([text], return_tensors="pt")
|
||||||
|
input_ids = tokenized_inputs.input_ids.to(torch_device)
|
||||||
|
|
||||||
|
# Traditional way of generating text
|
||||||
|
outputs_normal = model.generate(input_ids)
|
||||||
|
self.assertEqual(outputs_normal.shape, (1, 20))
|
||||||
|
|
||||||
|
# Should be different with foo
|
||||||
|
outputs_foo = model.generate(input_ids, foo=True)
|
||||||
|
with self.assertRaises(AssertionError):
|
||||||
|
self.assertListEqual(outputs_foo.tolist(), outputs_normal.tolist())
|
||||||
|
|
||||||
|
# Assistant model
|
||||||
|
assistant = BartForCausalLM.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration").to(
|
||||||
|
torch_device
|
||||||
|
)
|
||||||
|
|
||||||
|
# If assisted generation passes model_kwargs correctly, should be same as previous
|
||||||
|
outputs_assisted = model.generate(
|
||||||
|
input_ids,
|
||||||
|
foo=True,
|
||||||
|
assistant_model=assistant,
|
||||||
|
)
|
||||||
|
self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist())
|
||||||
|
|
||||||
|
# Check that passing encoder_outputs directly also works as expected
|
||||||
|
encoder_outputs = model.get_encoder()(input_ids)
|
||||||
|
|
||||||
|
outputs_assisted = model.generate(
|
||||||
|
foo=True,
|
||||||
|
assistant_model=assistant,
|
||||||
|
encoder_outputs=encoder_outputs,
|
||||||
|
)
|
||||||
|
self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist())
|
||||||
|
|||||||
@@ -51,6 +51,7 @@ if is_torch_available():
|
|||||||
from transformers import (
|
from transformers import (
|
||||||
WhisperFeatureExtractor,
|
WhisperFeatureExtractor,
|
||||||
WhisperForAudioClassification,
|
WhisperForAudioClassification,
|
||||||
|
WhisperForCausalLM,
|
||||||
WhisperForConditionalGeneration,
|
WhisperForConditionalGeneration,
|
||||||
WhisperModel,
|
WhisperModel,
|
||||||
WhisperProcessor,
|
WhisperProcessor,
|
||||||
@@ -1990,3 +1991,246 @@ class WhisperEncoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.
|
|||||||
|
|
||||||
self.assertEqual(fx_keys, pt_keys)
|
self.assertEqual(fx_keys, pt_keys)
|
||||||
self.check_pt_flax_outputs(fx_outputs, pt_outputs_loaded, model_class)
|
self.check_pt_flax_outputs(fx_outputs, pt_outputs_loaded, model_class)
|
||||||
|
|
||||||
|
|
||||||
|
class WhisperStandaloneDecoderModelTester:
    """Build a tiny `WhisperConfig` and decoder-only inputs for the standalone decoder tests.

    The tester stores the hyper-parameters given to `__init__`, turns them into a
    config via `get_config`, and produces an inputs dict that only contains
    `input_ids` and `attention_mask` (see `prepare_config_and_inputs`).
    `parent` is the running test case; its `assert*` methods are used for checks.
    """

    def __init__(
        self,
        parent,
        batch_size=2,
        is_training=True,
        use_labels=False,
        vocab_size=200,
        hidden_size=16,
        num_hidden_layers=2,
        num_attention_heads=4,
        input_channels=1,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=20,
        max_source_positions=30,
        max_target_positions=40,
        bos_token_id=98,
        eos_token_id=98,
        pad_token_id=0,
        num_mel_bins=80,
        decoder_start_token_id=85,
        num_conv_layers=1,
        suppress_tokens=None,
        begin_suppress_tokens=None,
    ):
        """Store the test case (`parent`) and every hyper-parameter as an attribute of the same name."""
        self.parent = parent
        self.batch_size = batch_size
        self.is_training = is_training
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.input_channels = input_channels
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.num_mel_bins = num_mel_bins
        self.max_position_embeddings = max_position_embeddings
        self.max_source_positions = max_source_positions
        self.max_target_positions = max_target_positions
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.decoder_start_token_id = decoder_start_token_id
        self.num_conv_layers = num_conv_layers
        self.suppress_tokens = suppress_tokens
        self.begin_suppress_tokens = begin_suppress_tokens

    def prepare_config_and_inputs(self):
        """Return `(config, inputs_dict)` where `inputs_dict` holds only `attention_mask` and `input_ids`.

        The config has `is_encoder_decoder` set to `False`. The dict returned by
        `prepare_whisper_inputs_dict` is stripped of its feature / head-mask entries and
        its decoder entries are renamed to the decoder-only key names.
        """
        input_features = floats_tensor([self.batch_size, self.num_mel_bins, self.seq_length], self.vocab_size)

        # One fixed 5-token sequence, repeated for every item of the batch.
        decoder_input_ids = torch.tensor(
            self.batch_size * [[self.decoder_start_token_id, 3, 3, 7, 2]], device=torch_device
        )

        config = self.get_config()
        config.is_encoder_decoder = False
        inputs_dict = prepare_whisper_inputs_dict(
            config,
            attention_mask=None,
            input_features=input_features,
            decoder_input_ids=decoder_input_ids,
        )

        # Drop everything a decoder-only forward pass is not given in these tests.
        inputs_dict.pop("input_features")
        inputs_dict.pop("head_mask")
        inputs_dict.pop("decoder_head_mask")
        inputs_dict.pop("cross_attn_head_mask")

        # Rename the decoder_* entries to the names used by the standalone decoder.
        inputs_dict["attention_mask"] = inputs_dict.pop("decoder_attention_mask")
        inputs_dict["input_ids"] = inputs_dict.pop("decoder_input_ids")
        return config, inputs_dict

    @property
    def encoder_seq_length(self):
        """Fixed encoder sequence length used by the tests."""
        return 5

    @property
    def seq_length(self):
        """Fixed (decoder) sequence length used by the tests; matches the 5-token `decoder_input_ids`."""
        return 5

    def get_config(self):
        """Return a `WhisperConfig` built from the stored hyper-parameters.

        `hidden_size` is used for `d_model` as well as for both feed-forward dimensions, and
        `num_hidden_layers` / `num_attention_heads` are applied to encoder and decoder alike.
        """
        return WhisperConfig(
            vocab_size=self.vocab_size,
            d_model=self.hidden_size,
            encoder_layers=self.num_hidden_layers,
            decoder_layers=self.num_hidden_layers,
            encoder_attention_heads=self.num_attention_heads,
            decoder_attention_heads=self.num_attention_heads,
            input_channels=self.input_channels,
            dropout=self.hidden_dropout_prob,
            attention_dropout=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
            max_source_positions=self.max_source_positions,
            max_target_positions=self.max_target_positions,
            eos_token_id=self.eos_token_id,
            bos_token_id=self.bos_token_id,
            pad_token_id=self.pad_token_id,
            decoder_ffn_dim=self.hidden_size,
            encoder_ffn_dim=self.hidden_size,
            decoder_start_token_id=self.decoder_start_token_id,
            suppress_tokens=self.suppress_tokens,
            begin_suppress_tokens=self.begin_suppress_tokens,
        )

    def prepare_config_and_inputs_for_common(self):
        """Return `(config, inputs_dict)` with the last token of every sequence set to the pad token."""
        config, inputs_dict = self.prepare_config_and_inputs()

        inputs_dict["input_ids"][:, -1] = self.pad_token_id

        return config, inputs_dict

    def prepare_config_and_inputs_for_decoder(self):
        """Return `(config, input_ids, encoder_hidden_states)` with random encoder hidden states."""
        config, inputs_dict = self.prepare_config_and_inputs()
        input_ids = inputs_dict["input_ids"]
        # This tester defines no `decoder_seq_length`; the encoder hidden states are sized with
        # `encoder_seq_length` so the method no longer raises AttributeError.
        encoder_hidden_states = floats_tensor([self.batch_size, self.encoder_seq_length, self.hidden_size])

        return (config, input_ids, encoder_hidden_states)

    def create_and_check_decoder_model_past(self, config, input_ids):
        """Check that decoding one extra token with cached key/values matches a full forward pass."""
        config.use_cache = True
        model = WhisperDecoder(config=config).to(torch_device).eval()
        # first forward pass
        outputs = model(input_ids, use_cache=True)
        outputs_use_cache_conf = model(input_ids)
        outputs_no_past = model(input_ids, use_cache=False)

        # `use_cache=True` must add exactly one entry to the outputs compared to `use_cache=False`.
        self.parent.assertEqual(len(outputs), len(outputs_use_cache_conf))
        self.parent.assertEqual(len(outputs), len(outputs_no_past) + 1)

        past_key_values = outputs["past_key_values"]

        # create a hypothetical next token to extend input_ids with
        next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)

        # append the new token to input_ids
        next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)

        output_from_no_past = model(next_input_ids)["last_hidden_state"]
        output_from_past = model(next_tokens, past_key_values=past_key_values)["last_hidden_state"]

        # select random slice
        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
        output_from_no_past_slice = output_from_no_past[:, next_input_ids.shape[-1] - 1, random_slice_idx].detach()
        output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()

        # test that outputs are equal for slice (not a bare `assert`, which `python -O` removes)
        self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))

    def create_and_check_decoder_model_attention_mask_past(self, config, input_ids):
        """Check cached decoding against a full forward pass when the second half of the input is masked.

        Note that `input_ids` is modified in place: one masked position is overwritten with random tokens.
        """
        model = WhisperDecoder(config=config).to(torch_device).eval()

        # create attention mask: attend to the first half only
        attn_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)

        half_seq_length = input_ids.shape[-1] // 2
        attn_mask[:, half_seq_length:] = 0

        # first forward pass
        past_key_values = model(input_ids, attention_mask=attn_mask, use_cache=True)["past_key_values"]

        # create a hypothetical next token to extend input_ids with
        next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)

        # change a random masked slice from input_ids
        random_seq_idx_to_change = ids_tensor((1,), half_seq_length).item() + 1
        random_other_next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size).squeeze(-1)
        input_ids[:, -random_seq_idx_to_change] = random_other_next_tokens

        # append to next input_ids and attn_mask
        next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
        attn_mask = torch.cat(
            [attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)],
            dim=1,
        )

        # get two different outputs
        output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"]
        output_from_past = model(next_tokens, attention_mask=attn_mask, past_key_values=past_key_values)[
            "last_hidden_state"
        ]

        # select random slice
        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
        output_from_no_past_slice = output_from_no_past[:, next_input_ids.shape[-1] - 1, random_slice_idx].detach()
        output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()

        # test that outputs are equal for slice (not a bare `assert`, which `python -O` removes)
        self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch
class WhisperStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
    """Run the shared model and generation test-suites on the Whisper decoder without its encoder."""

    all_model_classes = (WhisperDecoder, WhisperForCausalLM) if is_torch_available() else ()
    all_generative_model_classes = (WhisperForCausalLM,) if is_torch_available() else ()
    # NOTE(review): previously misspelled `fx_comptatible`; the shared tester mixin is
    # expected to read `fx_compatible` - confirm against ModelTesterMixin.
    fx_compatible = False
    test_pruning = False
    is_encoder_decoder = False
    test_missing_keys = False

    def setUp(self):
        """Create the model tester (with `is_training=False`) and the config tester."""
        self.model_tester = WhisperStandaloneDecoderModelTester(self, is_training=False)
        self.config_tester = ConfigTester(self, config_class=WhisperConfig)

    def test_config(self):
        """Run the common configuration tests for `WhisperConfig`."""
        self.config_tester.run_common_tests()

    def test_decoder_model_past(self):
        """Cached decoding must match a full forward pass."""
        config, inputs_dict = self.model_tester.prepare_config_and_inputs()

        self.model_tester.create_and_check_decoder_model_past(config=config, input_ids=inputs_dict["input_ids"])

    def test_decoder_model_attn_mask_past(self):
        """Cached decoding must match a full forward pass when an attention mask is used."""
        config, inputs_dict = self.model_tester.prepare_config_and_inputs()

        self.model_tester.create_and_check_decoder_model_attention_mask_past(
            config=config, input_ids=inputs_dict["input_ids"]
        )

    @unittest.skip("Generate needs input ids")
    def test_generate_without_input_ids(self):
        # generate only works with input ids for whisper
        pass

    @unittest.skip("Decoder can't keep attention grads")
    def test_retain_grad_hidden_states_attentions(self):
        # decoder cannot keep gradients
        pass

    @unittest.skip("The model doesn't support fast init from base")
    def test_save_load_fast_init_from_base(self):
        pass

    @unittest.skip("The model doesn't support left padding")  # and it's not used enough to be worth fixing :)
    def test_left_padding_compatibility(self):
        pass
|
||||||
|
|||||||
Reference in New Issue
Block a user