[Whisper] Pipeline: handle long form generation (#35750)
* handle long form generation * add warning * correct incorrect in place token change * update test to catch edge case * make style * update warning * add doc
This commit is contained in:
@@ -2031,11 +2031,13 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
||||
).input_features
|
||||
input_features = input_features.to(torch_device)
|
||||
|
||||
generated_ids = model.generate(input_features, max_length=448, return_timestamps=True).to("cpu")
|
||||
generated_ids = model.generate(
|
||||
input_features, max_length=448, return_timestamps=True, condition_on_prev_tokens=True
|
||||
).to("cpu")
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_OUTPUT = torch.tensor([
|
||||
50365, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 11, 293, 321, 366, 5404, 281, 2928, 702, 14943, 13, 50629, 50682, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50870, 50911, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904, 9256, 450, 10539, 949, 505, 11, 51245, 51287, 1034, 4680, 10117, 490, 3936, 293, 1080, 3542, 5160, 881, 26336, 281, 264, 1575, 13, 51494, 51523, 634, 575, 12525, 22618, 1968, 6144, 35617, 1456, 397, 266, 311, 589, 307, 534, 10281, 934, 439, 11, 51799, 51815, 50365, 293, 393, 4411, 50430
|
||||
[50365, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 11, 293, 321, 366, 5404, 281, 2928, 702, 14943, 13, 50629, 50682, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50870, 50911, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904, 9256, 450, 10539, 949, 505, 11, 51245, 51287, 1034, 4680, 10117, 490, 3936, 293, 1080, 3542, 5160, 881, 26336, 281, 264, 1575, 13, 51494, 51523, 634, 575, 12525, 22618, 1968, 6144, 35617, 1456, 397, 266, 311, 589, 307, 534, 10281, 934, 439, 11, 51799, 51815, 50365, 293, 393, 4411, 50431]
|
||||
])
|
||||
# fmt: on
|
||||
torch.testing.assert_close(generated_ids[0], EXPECTED_OUTPUT)
|
||||
@@ -2078,7 +2080,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
||||
},
|
||||
{
|
||||
"text": (" and can discover"),
|
||||
"timestamp": (28.68, 29.98),
|
||||
"timestamp": (28.68, 30.0),
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user