feat: Whisper prompting (#22496)
* initial working additions * clean and rename, add cond stripping initial prompt to decode * cleanup, edit create_initial_prompt_ids, add tests * repo consistency, flip order of conditional * fix error, move the processor fn to the tokenizer * repo consistency, update test ids to corresponding tokenizer * use convert_tokens_to_ids not get_vocab... * use actual conditional in generate * make sytle * initial address comments * initial working add new params to pipeline * first draft of sequential generation for condition_on_previous_text * add/update tests, make compatible with timestamps * make compatible with diff. input kwargs and max length * add None check * add temperature check * flip temp check operand * refocusing to prev pr scope * remove the params too * make style * edits, move max length incorporating prompt to whisper * address comments * remove asr pipeline prompt decoding, fix indexing * address comments (more tests, validate prompt) * un-comment out tests (from debug) * remove old comment * address comments * fix typo * remove timestamp token from test * make style * cleanup * copy method to fast tokenizer, set max_new_tokens for test * prompt_ids type just pt * address Amy's comments * make style
This commit is contained in:
@@ -34,7 +34,12 @@ from ...modeling_outputs import (
|
|||||||
SequenceClassifierOutput,
|
SequenceClassifierOutput,
|
||||||
)
|
)
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_utils import PreTrainedModel
|
||||||
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
|
from ...utils import (
|
||||||
|
add_start_docstrings,
|
||||||
|
add_start_docstrings_to_model_forward,
|
||||||
|
logging,
|
||||||
|
replace_return_docstrings,
|
||||||
|
)
|
||||||
from .configuration_whisper import WhisperConfig
|
from .configuration_whisper import WhisperConfig
|
||||||
from .tokenization_whisper import TASK_IDS, TO_LANGUAGE_CODE
|
from .tokenization_whisper import TASK_IDS, TO_LANGUAGE_CODE
|
||||||
|
|
||||||
@@ -1464,6 +1469,7 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
|
|||||||
task=None,
|
task=None,
|
||||||
language=None,
|
language=None,
|
||||||
is_multilingual=None,
|
is_multilingual=None,
|
||||||
|
prompt_ids: Optional[torch.Tensor] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -1521,6 +1527,11 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
|
|||||||
find all the possible language tokens in the `model.generation_config.lang_to_id` dictionary.
|
find all the possible language tokens in the `model.generation_config.lang_to_id` dictionary.
|
||||||
is_multilingual (`bool`, *optional*):
|
is_multilingual (`bool`, *optional*):
|
||||||
Whether or not the model is multilingual.
|
Whether or not the model is multilingual.
|
||||||
|
prompt_ids (`torch.Tensor`, *optional*):
|
||||||
|
Rank-1 tensor of token IDs created by passing text to [`~WhisperProcessor.get_prompt_ids`] that is
|
||||||
|
provided as a prompt to each chunk. This can be used to provide or "prompt-engineer" a context for
|
||||||
|
transcription, e.g. custom vocabularies or proper nouns to make it more likely to predict those words
|
||||||
|
correctly. It cannot be used in conjunction with `decoder_start_token_id` as it overwrites this value.
|
||||||
kwargs:
|
kwargs:
|
||||||
Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
|
Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
|
||||||
forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
|
forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
|
||||||
@@ -1567,8 +1578,21 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
|
|||||||
if task is not None:
|
if task is not None:
|
||||||
generation_config.task = task
|
generation_config.task = task
|
||||||
|
|
||||||
forced_decoder_ids = []
|
forced_decoder_ids = None
|
||||||
if task is not None or language is not None:
|
|
||||||
|
# Legacy code for backward compatibility
|
||||||
|
if hasattr(self.config, "forced_decoder_ids") and self.config.forced_decoder_ids is not None:
|
||||||
|
forced_decoder_ids = self.config.forced_decoder_ids
|
||||||
|
elif (
|
||||||
|
hasattr(self.generation_config, "forced_decoder_ids")
|
||||||
|
and self.generation_config.forced_decoder_ids is not None
|
||||||
|
):
|
||||||
|
forced_decoder_ids = self.generation_config.forced_decoder_ids
|
||||||
|
else:
|
||||||
|
forced_decoder_ids = kwargs.get("forced_decoder_ids", None)
|
||||||
|
|
||||||
|
if task is not None or language is not None or (forced_decoder_ids is None and prompt_ids is not None):
|
||||||
|
forced_decoder_ids = []
|
||||||
if hasattr(generation_config, "language"):
|
if hasattr(generation_config, "language"):
|
||||||
if generation_config.language in generation_config.lang_to_id.keys():
|
if generation_config.language in generation_config.lang_to_id.keys():
|
||||||
language_token = generation_config.language
|
language_token = generation_config.language
|
||||||
@@ -1593,27 +1617,48 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"The `{generation_config.task}`task is not supported. The task should be one of `{TASK_IDS}`"
|
f"The `{generation_config.task}`task is not supported. The task should be one of `{TASK_IDS}`"
|
||||||
)
|
)
|
||||||
else:
|
elif hasattr(generation_config, "task_to_id"):
|
||||||
forced_decoder_ids.append((2, generation_config.task_to_id["transcribe"])) # defaults to transcribe
|
forced_decoder_ids.append((2, generation_config.task_to_id["transcribe"])) # defaults to transcribe
|
||||||
if hasattr(generation_config, "no_timestamps_token_id") and not generation_config.return_timestamps:
|
if hasattr(generation_config, "no_timestamps_token_id") and not generation_config.return_timestamps:
|
||||||
idx = forced_decoder_ids[-1][0] + 1 if forced_decoder_ids else 1
|
idx = forced_decoder_ids[-1][0] + 1 if forced_decoder_ids else 1
|
||||||
forced_decoder_ids.append((idx, generation_config.no_timestamps_token_id))
|
forced_decoder_ids.append((idx, generation_config.no_timestamps_token_id))
|
||||||
|
|
||||||
# Legacy code for backward compatibility
|
if forced_decoder_ids is not None:
|
||||||
elif hasattr(self.config, "forced_decoder_ids") and self.config.forced_decoder_ids is not None:
|
generation_config.forced_decoder_ids = forced_decoder_ids
|
||||||
forced_decoder_ids = self.config.forced_decoder_ids
|
|
||||||
elif (
|
if prompt_ids is not None:
|
||||||
hasattr(self.generation_config, "forced_decoder_ids")
|
if kwargs.get("decoder_start_token_id") is not None:
|
||||||
and self.generation_config.forced_decoder_ids is not None
|
raise ValueError(
|
||||||
):
|
"When specifying `prompt_ids`, you cannot also specify `decoder_start_token_id` as it gets overwritten."
|
||||||
forced_decoder_ids = self.generation_config.forced_decoder_ids
|
)
|
||||||
|
prompt_ids = prompt_ids.tolist()
|
||||||
|
decoder_start_token_id, *text_prompt_ids = prompt_ids
|
||||||
|
# Set the decoder_start_token_id to <|startofprev|>
|
||||||
|
kwargs.update({"decoder_start_token_id": decoder_start_token_id})
|
||||||
|
|
||||||
|
# Update the max generation length to include the prompt
|
||||||
|
specified_max_length = kwargs.pop("max_new_tokens", None) or kwargs.pop("max_length", None)
|
||||||
|
default_max_length = generation_config.max_new_tokens or generation_config.max_length
|
||||||
|
non_prompt_max_length = specified_max_length or default_max_length
|
||||||
|
kwargs["max_new_tokens"] = non_prompt_max_length + len(text_prompt_ids)
|
||||||
|
|
||||||
|
# Reformat the forced_decoder_ids to incorporate the prompt
|
||||||
|
non_prompt_forced_decoder_ids = (
|
||||||
|
kwargs.pop("forced_decoder_ids", None) or generation_config.forced_decoder_ids
|
||||||
|
)
|
||||||
|
forced_decoder_ids = [
|
||||||
|
# Slicing the text prompt ids in a manner consistent with the OpenAI implementation
|
||||||
|
# to accomodate context space for the prefix (see https://github.com/openai/whisper/blob/c09a7ae299c4c34c5839a76380ae407e7d785914/whisper/decoding.py#L599)
|
||||||
|
*text_prompt_ids[-self.config.max_length // 2 - 1 :],
|
||||||
|
generation_config.decoder_start_token_id,
|
||||||
|
*[token for _rank, token in non_prompt_forced_decoder_ids],
|
||||||
|
]
|
||||||
|
forced_decoder_ids = [(rank + 1, token) for rank, token in enumerate(forced_decoder_ids)]
|
||||||
|
generation_config.forced_decoder_ids = forced_decoder_ids
|
||||||
|
|
||||||
if generation_config.return_timestamps:
|
if generation_config.return_timestamps:
|
||||||
logits_processor = [WhisperTimeStampLogitsProcessor(generation_config)]
|
logits_processor = [WhisperTimeStampLogitsProcessor(generation_config)]
|
||||||
|
|
||||||
if len(forced_decoder_ids) > 0:
|
|
||||||
generation_config.forced_decoder_ids = forced_decoder_ids
|
|
||||||
|
|
||||||
return super().generate(
|
return super().generate(
|
||||||
inputs,
|
inputs,
|
||||||
generation_config,
|
generation_config,
|
||||||
|
|||||||
@@ -16,6 +16,7 @@
|
|||||||
Speech processor class for Whisper
|
Speech processor class for Whisper
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
from ...processing_utils import ProcessorMixin
|
from ...processing_utils import ProcessorMixin
|
||||||
|
|
||||||
|
|
||||||
@@ -91,3 +92,6 @@ class WhisperProcessor(ProcessorMixin):
|
|||||||
the docstring of this method for more information.
|
the docstring of this method for more information.
|
||||||
"""
|
"""
|
||||||
return self.tokenizer.decode(*args, **kwargs)
|
return self.tokenizer.decode(*args, **kwargs)
|
||||||
|
|
||||||
|
def get_prompt_ids(self, text: str, return_tensors="np"):
|
||||||
|
return self.tokenizer.get_prompt_ids(text, return_tensors=return_tensors)
|
||||||
|
|||||||
@@ -606,6 +606,11 @@ class WhisperTokenizer(PreTrainedTokenizer):
|
|||||||
) -> str:
|
) -> str:
|
||||||
self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False)
|
self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False)
|
||||||
|
|
||||||
|
if skip_special_tokens:
|
||||||
|
prompt_token_id = self.convert_tokens_to_ids("<|startofprev|>")
|
||||||
|
decoder_start_token_id = self.convert_tokens_to_ids("<|startoftranscript|>")
|
||||||
|
token_ids = self._strip_prompt(token_ids, prompt_token_id, decoder_start_token_id)
|
||||||
|
|
||||||
filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
|
filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
|
||||||
|
|
||||||
# To avoid mixing byte-level and unicode for byte-level BPT
|
# To avoid mixing byte-level and unicode for byte-level BPT
|
||||||
@@ -714,6 +719,31 @@ class WhisperTokenizer(PreTrainedTokenizer):
|
|||||||
time_precision=time_precision,
|
time_precision=time_precision,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_prompt_ids(self, text: str, return_tensors="np"):
|
||||||
|
"""Converts prompt text to IDs that can be passed to [`~WhisperForConditionalGeneration.generate`]."""
|
||||||
|
batch_encoding = self("<|startofprev|>", text.strip(), add_prefix_space=True, add_special_tokens=False)
|
||||||
|
|
||||||
|
# Check for special tokens
|
||||||
|
prompt_text_ids = batch_encoding["input_ids"][1:]
|
||||||
|
special_token_id = next((x for x in prompt_text_ids if x >= self.all_special_ids[0]), None)
|
||||||
|
if special_token_id is not None:
|
||||||
|
token = self.convert_ids_to_tokens(special_token_id)
|
||||||
|
raise ValueError(f"Encountered text in the prompt corresponding to disallowed special token: {token}.")
|
||||||
|
|
||||||
|
batch_encoding.convert_to_tensors(tensor_type=return_tensors)
|
||||||
|
return batch_encoding["input_ids"]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _strip_prompt(token_ids: List[int], prompt_token_id: int, decoder_start_token_id: int):
|
||||||
|
has_prompt = isinstance(token_ids, list) and token_ids and token_ids[0] == prompt_token_id
|
||||||
|
if has_prompt:
|
||||||
|
if decoder_start_token_id in token_ids:
|
||||||
|
return token_ids[token_ids.index(decoder_start_token_id) :]
|
||||||
|
else:
|
||||||
|
return []
|
||||||
|
|
||||||
|
return token_ids
|
||||||
|
|
||||||
|
|
||||||
def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language, time_precision):
|
def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language, time_precision):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -312,6 +312,11 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
|
|||||||
return text
|
return text
|
||||||
|
|
||||||
def _decode(self, *args, normalize: bool = False, **kwargs) -> str:
|
def _decode(self, *args, normalize: bool = False, **kwargs) -> str:
|
||||||
|
if kwargs["skip_special_tokens"]:
|
||||||
|
prompt_token_id = self.convert_tokens_to_ids("<|startofprev|>")
|
||||||
|
decoder_start_token_id = self.convert_tokens_to_ids("<|startoftranscript|>")
|
||||||
|
kwargs["token_ids"] = self._strip_prompt(kwargs["token_ids"], prompt_token_id, decoder_start_token_id)
|
||||||
|
|
||||||
text = super()._decode(*args, **kwargs)
|
text = super()._decode(*args, **kwargs)
|
||||||
|
|
||||||
if normalize:
|
if normalize:
|
||||||
@@ -485,3 +490,30 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
|
|||||||
return_language=return_language,
|
return_language=return_language,
|
||||||
time_precision=time_precision,
|
time_precision=time_precision,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer.get_prompt_ids
|
||||||
|
def get_prompt_ids(self, text: str, return_tensors="np"):
|
||||||
|
"""Converts prompt text to IDs that can be passed to [`~WhisperForConditionalGeneration.generate`]."""
|
||||||
|
batch_encoding = self("<|startofprev|>", text.strip(), add_prefix_space=True, add_special_tokens=False)
|
||||||
|
|
||||||
|
# Check for special tokens
|
||||||
|
prompt_text_ids = batch_encoding["input_ids"][1:]
|
||||||
|
special_token_id = next((x for x in prompt_text_ids if x >= self.all_special_ids[0]), None)
|
||||||
|
if special_token_id is not None:
|
||||||
|
token = self.convert_ids_to_tokens(special_token_id)
|
||||||
|
raise ValueError(f"Encountered text in the prompt corresponding to disallowed special token: {token}.")
|
||||||
|
|
||||||
|
batch_encoding.convert_to_tensors(tensor_type=return_tensors)
|
||||||
|
return batch_encoding["input_ids"]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
# Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._strip_prompt
|
||||||
|
def _strip_prompt(token_ids: List[int], prompt_token_id: int, decoder_start_token_id: int):
|
||||||
|
has_prompt = isinstance(token_ids, list) and token_ids and token_ids[0] == prompt_token_id
|
||||||
|
if has_prompt:
|
||||||
|
if decoder_start_token_id in token_ids:
|
||||||
|
return token_ids[token_ids.index(decoder_start_token_id) :]
|
||||||
|
else:
|
||||||
|
return []
|
||||||
|
|
||||||
|
return token_ids
|
||||||
|
|||||||
@@ -1013,6 +1013,48 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
|||||||
encoder_last_hidden_state = model(**input_dict).encoder_last_hidden_state
|
encoder_last_hidden_state = model(**input_dict).encoder_last_hidden_state
|
||||||
self.assertTrue(encoder_last_hidden_state.shape, (13, 30, 16))
|
self.assertTrue(encoder_last_hidden_state.shape, (13, 30, 16))
|
||||||
|
|
||||||
|
def test_generate_with_prompt_ids_and_task_and_language(self):
|
||||||
|
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
model = WhisperForConditionalGeneration(config).eval().to(torch_device)
|
||||||
|
input_features = input_dict["input_features"]
|
||||||
|
prompt_ids = np.arange(5)
|
||||||
|
language = "<|de|>"
|
||||||
|
task = "translate"
|
||||||
|
lang_id = 6
|
||||||
|
task_id = 7
|
||||||
|
model.generation_config.__setattr__("lang_to_id", {language: lang_id})
|
||||||
|
model.generation_config.__setattr__("task_to_id", {task: task_id})
|
||||||
|
|
||||||
|
output = model.generate(input_features, max_new_tokens=5, task=task, language=language, prompt_ids=prompt_ids)
|
||||||
|
|
||||||
|
expected_output_start = [
|
||||||
|
*prompt_ids.tolist(),
|
||||||
|
model.generation_config.decoder_start_token_id,
|
||||||
|
lang_id,
|
||||||
|
task_id,
|
||||||
|
]
|
||||||
|
for row in output.tolist():
|
||||||
|
self.assertListEqual(row[: len(expected_output_start)], expected_output_start)
|
||||||
|
|
||||||
|
def test_generate_with_prompt_ids_and_forced_decoder_ids(self):
|
||||||
|
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
model = WhisperForConditionalGeneration(config).eval().to(torch_device)
|
||||||
|
input_features = input_dict["input_features"]
|
||||||
|
prompt_ids = np.asarray(range(5))
|
||||||
|
forced_decoder_ids = [(1, 6), (2, 7), (3, 8)]
|
||||||
|
|
||||||
|
output = model.generate(
|
||||||
|
input_features, max_new_tokens=5, forced_decoder_ids=forced_decoder_ids, prompt_ids=prompt_ids
|
||||||
|
)
|
||||||
|
|
||||||
|
expected_output_start = [
|
||||||
|
*prompt_ids.tolist(),
|
||||||
|
model.generation_config.decoder_start_token_id,
|
||||||
|
*[token for _rank, token in forced_decoder_ids],
|
||||||
|
]
|
||||||
|
for row in output.tolist():
|
||||||
|
self.assertListEqual(row[: len(expected_output_start)], expected_output_start)
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@require_torchaudio
|
@require_torchaudio
|
||||||
@@ -1429,6 +1471,60 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
|||||||
# fmt: on
|
# fmt: on
|
||||||
self.assertTrue(torch.allclose(logits[0][0, 0, :30].cpu(), EXPECTED_LOGITS, atol=1e-4))
|
self.assertTrue(torch.allclose(logits[0][0, 0, :30].cpu(), EXPECTED_LOGITS, atol=1e-4))
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_generate_with_prompt_ids(self):
|
||||||
|
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
|
||||||
|
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
|
||||||
|
model.to(torch_device)
|
||||||
|
input_speech = self._load_datasamples(4)[-1:]
|
||||||
|
input_features = processor(input_speech, return_tensors="pt").input_features
|
||||||
|
|
||||||
|
output_without_prompt = model.generate(input_features)
|
||||||
|
prompt_ids = processor.get_prompt_ids("Leighton")
|
||||||
|
output_with_prompt = model.generate(input_features, prompt_ids=prompt_ids)
|
||||||
|
|
||||||
|
expected_without_prompt = "<|startoftranscript|><|en|><|transcribe|><|notimestamps|> He has grave doubts whether Sir Frederick Layton's work is really Greek after all and can discover in it but little of Rocky Ithaca.<|endoftext|>"
|
||||||
|
expected_with_prompt = "<|startofprev|> Leighton<|startoftranscript|><|en|><|transcribe|><|notimestamps|> He has grave doubts whether Sir Frederick Leighton's work is really Greek after all and can discover in it but little of Rocky Ithaca.<|endoftext|>"
|
||||||
|
self.assertEqual(processor.decode(output_without_prompt[0]), expected_without_prompt)
|
||||||
|
self.assertEqual(processor.decode(output_with_prompt[0]), expected_with_prompt)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_generate_with_prompt_ids_and_forced_decoder_ids(self):
|
||||||
|
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
|
||||||
|
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
|
||||||
|
model.to(torch_device)
|
||||||
|
input_speech = self._load_datasamples(1)
|
||||||
|
input_features = processor(input_speech, return_tensors="pt").input_features
|
||||||
|
task = "translate"
|
||||||
|
language = "de"
|
||||||
|
expected_tokens = [f"<|{task}|>", f"<|{language}|>"]
|
||||||
|
prompt = "test prompt"
|
||||||
|
prompt_ids = processor.get_prompt_ids(prompt)
|
||||||
|
|
||||||
|
output = model.generate(input_features, task=task, language=language, prompt_ids=prompt_ids)
|
||||||
|
text = processor.decode(output[0])
|
||||||
|
|
||||||
|
self.assertTrue(prompt in text)
|
||||||
|
self.assertTrue(all([token in text for token in expected_tokens]))
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_generate_with_prompt_ids_and_no_non_prompt_forced_decoder_ids(self):
|
||||||
|
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
|
||||||
|
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
|
||||||
|
model.to(torch_device)
|
||||||
|
input_speech = self._load_datasamples(1)
|
||||||
|
input_features = processor(input_speech, return_tensors="pt").input_features
|
||||||
|
prompt = "test prompt"
|
||||||
|
prompt_ids = processor.get_prompt_ids(prompt)
|
||||||
|
|
||||||
|
model.generation_config.forced_decoder_ids = None
|
||||||
|
model.config.forced_decoder_ids = None
|
||||||
|
|
||||||
|
output = model.generate(input_features, prompt_ids=prompt_ids, return_timestamps=True)
|
||||||
|
text = processor.decode(output[0])
|
||||||
|
|
||||||
|
self.assertTrue(prompt in text)
|
||||||
|
|
||||||
|
|
||||||
def prepare_whisper_encoder_inputs_dict(config, input_features, head_mask=None):
|
def prepare_whisper_encoder_inputs_dict(config, input_features, head_mask=None):
|
||||||
if head_mask is None:
|
if head_mask is None:
|
||||||
|
|||||||
@@ -16,6 +16,8 @@ import shutil
|
|||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from transformers import WhisperTokenizer, is_speech_available
|
from transformers import WhisperTokenizer, is_speech_available
|
||||||
from transformers.testing_utils import require_sentencepiece, require_torch, require_torchaudio
|
from transformers.testing_utils import require_sentencepiece, require_torch, require_torchaudio
|
||||||
|
|
||||||
@@ -146,3 +148,32 @@ class WhisperProcessorTest(unittest.TestCase):
|
|||||||
|
|
||||||
expected_ids = [TRANSCRIBE, NOTIMESTAMPS]
|
expected_ids = [TRANSCRIBE, NOTIMESTAMPS]
|
||||||
self.assertListEqual([ids[-1] for ids in forced_decoder_ids], expected_ids)
|
self.assertListEqual([ids[-1] for ids in forced_decoder_ids], expected_ids)
|
||||||
|
|
||||||
|
def test_get_prompt_ids(self):
|
||||||
|
processor = WhisperProcessor(tokenizer=self.get_tokenizer(), feature_extractor=self.get_feature_extractor())
|
||||||
|
prompt_ids = processor.get_prompt_ids("Mr. Quilter")
|
||||||
|
decoded_prompt = processor.tokenizer.decode(prompt_ids)
|
||||||
|
|
||||||
|
self.assertListEqual(prompt_ids.tolist(), [50360, 1770, 13, 2264, 346, 353])
|
||||||
|
self.assertEqual(decoded_prompt, "<|startofprev|> Mr. Quilter")
|
||||||
|
|
||||||
|
def test_empty_get_prompt_ids(self):
|
||||||
|
processor = WhisperProcessor(tokenizer=self.get_tokenizer(), feature_extractor=self.get_feature_extractor())
|
||||||
|
prompt_ids = processor.get_prompt_ids("")
|
||||||
|
decoded_prompt = processor.tokenizer.decode(prompt_ids)
|
||||||
|
|
||||||
|
self.assertListEqual(prompt_ids.tolist(), [50360, 220])
|
||||||
|
self.assertEqual(decoded_prompt, "<|startofprev|> ")
|
||||||
|
|
||||||
|
def test_get_prompt_ids_with_special_tokens(self):
|
||||||
|
processor = WhisperProcessor(tokenizer=self.get_tokenizer(), feature_extractor=self.get_feature_extractor())
|
||||||
|
|
||||||
|
def _test_prompt_error_raised_helper(prompt, special_token):
|
||||||
|
with pytest.raises(ValueError) as excinfo:
|
||||||
|
processor.get_prompt_ids(prompt)
|
||||||
|
expected = f"Encountered text in the prompt corresponding to disallowed special token: {special_token}."
|
||||||
|
self.assertEqual(expected, str(excinfo.value))
|
||||||
|
|
||||||
|
_test_prompt_error_raised_helper("<|startofprev|> test", "<|startofprev|>")
|
||||||
|
_test_prompt_error_raised_helper("test <|notimestamps|>", "<|notimestamps|>")
|
||||||
|
_test_prompt_error_raised_helper("test <|zh|> test <|transcribe|>", "<|zh|>")
|
||||||
|
|||||||
@@ -194,6 +194,25 @@ class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
merge = _find_longest_common_sequence([seq1, seq2, seq3])
|
merge = _find_longest_common_sequence([seq1, seq2, seq3])
|
||||||
self.assertEqual(merge, [1, 2, 3, 4, 5, 6, 7, 8])
|
self.assertEqual(merge, [1, 2, 3, 4, 5, 6, 7, 8])
|
||||||
|
|
||||||
|
def test_skip_special_tokens_skips_prompt_ids(self):
|
||||||
|
tokenizer = self.get_tokenizer()
|
||||||
|
rust_tokenizer = self.get_rust_tokenizer()
|
||||||
|
# fmt: off
|
||||||
|
encoded_input = [
|
||||||
|
50361, 2221, 13, 2326, 388, 391, 50258, 50259, 50359,
|
||||||
|
50363, 1282, 264, 2674, 9156, 295, 1523, 11, 2221, 13,
|
||||||
|
2326, 388, 391, 13657, 365, 2681, 21296, 17711, 13, 50257,
|
||||||
|
]
|
||||||
|
# fmt: on
|
||||||
|
expected_with_special_tokens = "<|startofprev|> Mr. Quilter<|startoftranscript|><|en|><|transcribe|><|notimestamps|> On the general principles of art, Mr. Quilter writes with equal lucidity.<|endoftext|>"
|
||||||
|
expected_without_special_tokens = " On the general principles of art, Mr. Quilter writes with equal lucidity."
|
||||||
|
self.assertEqual(tokenizer.decode(encoded_input, skip_special_tokens=False), expected_with_special_tokens)
|
||||||
|
self.assertEqual(tokenizer.decode(encoded_input, skip_special_tokens=True), expected_without_special_tokens)
|
||||||
|
self.assertEqual(rust_tokenizer.decode(encoded_input, skip_special_tokens=False), expected_with_special_tokens)
|
||||||
|
self.assertEqual(
|
||||||
|
rust_tokenizer.decode(encoded_input, skip_special_tokens=True), expected_without_special_tokens
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
|
class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
|
||||||
checkpoint_name = "openai/whisper-small.en"
|
checkpoint_name = "openai/whisper-small.en"
|
||||||
|
|||||||
Reference in New Issue
Block a user