[Whisper test] Fix some failing tests (#33450)

* Fix failing tensor placement in Whisper

* Fix long-form generation tests

* Pass return_timestamps=True to more generate calls in the tests

* make fixup

* [run_slow] whisper

* [run_slow] whisper
This commit is contained in:
Yoach Lacombe
2024-09-16 19:05:17 +02:00
committed by GitHub
parent c2d05897bf
commit 98adf24883

View File

@@ -1683,9 +1683,9 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
input_features = input_dict["input_features"]
labels_length = config.max_target_positions
labels = torch.ones(1, labels_length, dtype=torch.int64)
labels = torch.ones(1, labels_length, dtype=torch.int64).to(torch_device)
model = model_class(config)
model = model_class(config).to(torch_device)
model(input_features=input_features, labels=labels)
def test_labels_sequence_max_length_correct_after_changing_config(self):
@@ -1697,9 +1697,9 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
config.max_target_positions += 100
labels_length = config.max_target_positions
labels = torch.ones(1, labels_length, dtype=torch.int64)
labels = torch.ones(1, labels_length, dtype=torch.int64).to(torch_device)
model = model_class(config)
model = model_class(config).to(torch_device)
model(input_features=input_features, labels=labels)
def test_labels_sequence_max_length_error(self):
@@ -1709,9 +1709,9 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
input_features = input_dict["input_features"]
labels_length = config.max_target_positions + 1
labels = torch.ones(1, labels_length, dtype=torch.int64)
labels = torch.ones(1, labels_length, dtype=torch.int64).to(torch_device)
model = model_class(config)
model = model_class(config).to(torch_device)
with self.assertRaises(ValueError):
model(input_features=input_features, labels=labels)
@@ -1719,11 +1719,11 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_generative_model_classes:
model = model_class(config)
model = model_class(config).to(torch_device)
input_features = input_dict["input_features"]
labels_length = config.max_target_positions + 1
labels = torch.ones(1, labels_length, dtype=torch.int64)
labels = torch.ones(1, labels_length, dtype=torch.int64).to(torch_device)
new_max_length = config.max_target_positions + 100
model.config.max_length = new_max_length
@@ -2385,7 +2385,9 @@ class WhisperModelIntegrationTests(unittest.TestCase):
)
inputs = inputs.to(torch_device)
generate_outputs = model.generate(**inputs, return_segments=True, return_token_timestamps=True)
generate_outputs = model.generate(
**inputs, return_segments=True, return_token_timestamps=True, return_timestamps=True
)
token_timestamps_shape = [
[segment["token_timestamps"].shape for segment in segment_list]
@@ -2550,14 +2552,14 @@ class WhisperModelIntegrationTests(unittest.TestCase):
).input_features.to(torch_device)
# task defaults to transcribe
sequences = model.generate(input_features)
sequences = model.generate(input_features, return_timestamps=True)
transcription = processor.batch_decode(sequences)[0]
assert transcription == " मिर्ची में कितने विबिन्द प्रजातियां हैं? मिर्ची में कितने विबिन्द प्रजातियां हैं?"
# set task to translate
sequences = model.generate(input_features, task="translate")
sequences = model.generate(input_features, task="translate", return_timestamps=True)
transcription = processor.batch_decode(sequences)[0]
assert (
@@ -3264,6 +3266,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
"num_beams": 5,
"language": "fr",
"task": "transcribe",
"return_timestamps": True,
}
torch.manual_seed(0)