[Whisper] another patch (#21324)
* another patch * fix timestamp test modeling * let it be negative when the token is None
This commit is contained in:
@@ -46,7 +46,6 @@ if is_torch_available():
|
||||
WhisperProcessor,
|
||||
set_seed,
|
||||
)
|
||||
from transformers.generation.logits_process import WhisperTimeStampLogitsProcessor
|
||||
from transformers.models.whisper.modeling_whisper import WhisperDecoder, WhisperEncoder
|
||||
|
||||
|
||||
@@ -1077,9 +1076,8 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
||||
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="pt").input_features.to(
|
||||
torch_device
|
||||
)
|
||||
model.config.forced_decoder_ids = [(1, 50259), (2, 50359), (3, 50364)]
|
||||
timestamp_processor = [WhisperTimeStampLogitsProcessor(len(model.config.forced_decoder_ids))]
|
||||
generated_ids = model.generate(input_features, max_length=448, logits_processor=timestamp_processor).to("cpu")
|
||||
|
||||
generated_ids = model.generate(input_features, max_length=448, return_timestamps=True).to("cpu")
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_OUTPUT = torch.tensor([50258, 50259, 50359, 50364, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 11, 293, 321, 366, 5404, 281, 2928, 702, 14943, 13, 50692, 50692, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50926, 50926, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904, 9256, 450, 10539, 51208, 51208, 949, 505, 11, 14138, 10117, 490, 3936, 293, 1080, 3542, 5160, 881, 26336, 281, 264, 1575, 13, 51552, 51552, 634, 575, 12525, 22618, 1968, 6144, 35617, 7354, 1292, 6, 589, 307, 534, 10281, 934, 439, 11, 293, 51836, 51836, 50257])
|
||||
|
||||
Reference in New Issue
Block a user