[WHISPER] Small patch (#21307)
* add small patch * update tests, forced decoder ids is not prioritary against generation config * fix two new tests
This commit is contained in:
@@ -699,8 +699,9 @@ def _test_large_generation(in_queue, out_queue, timeout):
|
||||
input_speech = _load_datasamples(1)
|
||||
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features
|
||||
|
||||
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe")
|
||||
generated_ids = model.generate(input_features, do_sample=False, max_length=20)
|
||||
generated_ids = model.generate(
|
||||
input_features, do_sample=False, max_length=20, language="<|en|>", task="transcribe"
|
||||
)
|
||||
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||
|
||||
EXPECTED_TRANSCRIPT = " Mr. Quilter is the apostle of the middle classes and we are glad"
|
||||
@@ -728,26 +729,25 @@ def _test_large_generation_multilingual(in_queue, out_queue, timeout):
|
||||
input_speech = next(iter(ds))["audio"]["array"]
|
||||
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features
|
||||
|
||||
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="transcribe")
|
||||
generated_ids = model.generate(input_features, do_sample=False, max_length=20)
|
||||
generated_ids = model.generate(
|
||||
input_features, do_sample=False, max_length=20, language="<|ja|>", task="transcribe"
|
||||
)
|
||||
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||
|
||||
EXPECTED_TRANSCRIPT = "木村さんに電話を貸してもらいました"
|
||||
unittest.TestCase().assertEqual(transcript, EXPECTED_TRANSCRIPT)
|
||||
|
||||
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe")
|
||||
generated_ids = model.generate(
|
||||
input_features,
|
||||
do_sample=False,
|
||||
max_length=20,
|
||||
input_features, do_sample=False, max_length=20, language="<|en|>", task="transcribe"
|
||||
)
|
||||
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||
|
||||
EXPECTED_TRANSCRIPT = " Kimura-san called me."
|
||||
unittest.TestCase().assertEqual(transcript, EXPECTED_TRANSCRIPT)
|
||||
|
||||
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="translate")
|
||||
generated_ids = model.generate(input_features, do_sample=False, max_length=20)
|
||||
generated_ids = model.generate(
|
||||
input_features, do_sample=False, max_length=20, language="<|ja|>", task="translate"
|
||||
)
|
||||
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||
|
||||
EXPECTED_TRANSCRIPT = " I borrowed a phone from Kimura san"
|
||||
@@ -779,10 +779,10 @@ def _test_large_batched_generation(in_queue, out_queue, timeout):
|
||||
# fmt: off
|
||||
EXPECTED_LOGITS = tf.convert_to_tensor(
|
||||
[
|
||||
[50258, 50358, 50363, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 293, 321, 366, 5404, 281],
|
||||
[50258, 50358, 50363, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50257, 50257],
|
||||
[50258, 50358, 50363, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904, 9256],
|
||||
[50258, 50358, 50363, 634, 575, 12525, 22618, 1968, 6144, 35617, 20084, 1756, 311, 589, 307, 534, 10281, 934, 439, 11]
|
||||
[50258, 50259, 50358, 50363, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 293, 321, 366, 5404],
|
||||
[50258, 50259, 50358, 50363, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50257],
|
||||
[50258, 50259, 50358, 50363, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904],
|
||||
[50258, 50259, 50358, 50363, 634, 575, 12525, 22618, 1968, 6144, 35617, 20084, 1756, 311, 589, 307, 534, 10281, 934, 439]
|
||||
]
|
||||
)
|
||||
# fmt: on
|
||||
@@ -791,10 +791,10 @@ def _test_large_batched_generation(in_queue, out_queue, timeout):
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_TRANSCRIPT = [
|
||||
' Mr. Quilter is the apostle of the middle classes and we are glad to',
|
||||
" Mr. Quilter is the apostle of the middle classes and we are glad",
|
||||
" Nor is Mr. Quilter's manner less interesting than his matter.",
|
||||
" He tells us that at this festive season of the year, with Christmas and roast beef",
|
||||
" He has grave doubts whether Sir Frederick Layton's work is really Greek after all,"
|
||||
" He tells us that at this festive season of the year, with Christmas and roast",
|
||||
" He has grave doubts whether Sir Frederick Layton's work is really Greek after all",
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
|
||||
@@ -945,11 +945,8 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
||||
torch_device
|
||||
)
|
||||
|
||||
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe")
|
||||
generated_ids = model.generate(
|
||||
input_features,
|
||||
do_sample=False,
|
||||
max_length=20,
|
||||
input_features, do_sample=False, max_length=20, language="<|en|>", task="transcribe"
|
||||
)
|
||||
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||
|
||||
@@ -971,26 +968,25 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
||||
torch_device
|
||||
)
|
||||
|
||||
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="transcribe")
|
||||
generated_ids = model.generate(input_features, do_sample=False, max_length=20)
|
||||
generated_ids = model.generate(
|
||||
input_features, do_sample=False, max_length=20, language="<|ja|>", task="transcribe"
|
||||
)
|
||||
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||
|
||||
EXPECTED_TRANSCRIPT = "木村さんに電話を貸してもらいました"
|
||||
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
|
||||
|
||||
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe")
|
||||
generated_ids = model.generate(
|
||||
input_features,
|
||||
do_sample=False,
|
||||
max_length=20,
|
||||
input_features, do_sample=False, max_length=20, language="<|en|>", task="transcribe"
|
||||
)
|
||||
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||
|
||||
EXPECTED_TRANSCRIPT = " Kimura-san called me."
|
||||
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
|
||||
|
||||
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="translate")
|
||||
generated_ids = model.generate(input_features, do_sample=False, max_length=20)
|
||||
generated_ids = model.generate(
|
||||
input_features, do_sample=False, max_length=20, language="<|ja|>", task="translate"
|
||||
)
|
||||
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||
|
||||
EXPECTED_TRANSCRIPT = " I borrowed a phone from Kimura san"
|
||||
@@ -1009,10 +1005,10 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
||||
# fmt: off
|
||||
EXPECTED_LOGITS = torch.tensor(
|
||||
[
|
||||
[50258, 50358, 50363, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 293, 321, 366, 5404, 281],
|
||||
[50258, 50358, 50363, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50257, 50257],
|
||||
[50258, 50358, 50363, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904, 9256],
|
||||
[50258, 50358, 50363, 634, 575, 12525, 22618, 1968, 6144, 35617, 20084, 1756, 311, 589, 307, 534, 10281, 934, 439, 11]
|
||||
[50258, 50259, 50358, 50363, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 293, 321, 366, 5404],
|
||||
[50258, 50259, 50358, 50363, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50257],
|
||||
[50258, 50259, 50358, 50363, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904],
|
||||
[50258, 50259, 50358, 50363, 634, 575, 12525, 22618, 1968, 6144, 35617, 20084, 1756, 311, 589, 307, 534, 10281, 934, 439]
|
||||
]
|
||||
)
|
||||
# fmt: on
|
||||
@@ -1021,10 +1017,10 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_TRANSCRIPT = [
|
||||
" Mr. Quilter is the apostle of the middle classes and we are glad to",
|
||||
" Mr. Quilter is the apostle of the middle classes and we are glad",
|
||||
" Nor is Mr. Quilter's manner less interesting than his matter.",
|
||||
" He tells us that at this festive season of the year, with Christmas and roast beef",
|
||||
" He has grave doubts whether Sir Frederick Layton's work is really Greek after all,",
|
||||
" He tells us that at this festive season of the year, with Christmas and roast",
|
||||
" He has grave doubts whether Sir Frederick Layton's work is really Greek after all",
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
|
||||
Reference in New Issue
Block a user