[Whisper test] Fix some failing tests (#33450)
* Fix failing tensor placement in Whisper * fix long form generation tests * more return_timestamps=True * make fixup * [run_slow] whisper * [run_slow] whisper
This commit is contained in:
@@ -1683,9 +1683,9 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
input_features = input_dict["input_features"]
|
||||
|
||||
labels_length = config.max_target_positions
|
||||
labels = torch.ones(1, labels_length, dtype=torch.int64)
|
||||
labels = torch.ones(1, labels_length, dtype=torch.int64).to(torch_device)
|
||||
|
||||
model = model_class(config)
|
||||
model = model_class(config).to(torch_device)
|
||||
model(input_features=input_features, labels=labels)
|
||||
|
||||
def test_labels_sequence_max_length_correct_after_changing_config(self):
|
||||
@@ -1697,9 +1697,9 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
config.max_target_positions += 100
|
||||
|
||||
labels_length = config.max_target_positions
|
||||
labels = torch.ones(1, labels_length, dtype=torch.int64)
|
||||
labels = torch.ones(1, labels_length, dtype=torch.int64).to(torch_device)
|
||||
|
||||
model = model_class(config)
|
||||
model = model_class(config).to(torch_device)
|
||||
model(input_features=input_features, labels=labels)
|
||||
|
||||
def test_labels_sequence_max_length_error(self):
|
||||
@@ -1709,9 +1709,9 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
input_features = input_dict["input_features"]
|
||||
|
||||
labels_length = config.max_target_positions + 1
|
||||
labels = torch.ones(1, labels_length, dtype=torch.int64)
|
||||
labels = torch.ones(1, labels_length, dtype=torch.int64).to(torch_device)
|
||||
|
||||
model = model_class(config)
|
||||
model = model_class(config).to(torch_device)
|
||||
with self.assertRaises(ValueError):
|
||||
model(input_features=input_features, labels=labels)
|
||||
|
||||
@@ -1719,11 +1719,11 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
model = model_class(config)
|
||||
model = model_class(config).to(torch_device)
|
||||
input_features = input_dict["input_features"]
|
||||
|
||||
labels_length = config.max_target_positions + 1
|
||||
labels = torch.ones(1, labels_length, dtype=torch.int64)
|
||||
labels = torch.ones(1, labels_length, dtype=torch.int64).to(torch_device)
|
||||
|
||||
new_max_length = config.max_target_positions + 100
|
||||
model.config.max_length = new_max_length
|
||||
@@ -2385,7 +2385,9 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
||||
)
|
||||
|
||||
inputs = inputs.to(torch_device)
|
||||
generate_outputs = model.generate(**inputs, return_segments=True, return_token_timestamps=True)
|
||||
generate_outputs = model.generate(
|
||||
**inputs, return_segments=True, return_token_timestamps=True, return_timestamps=True
|
||||
)
|
||||
|
||||
token_timestamps_shape = [
|
||||
[segment["token_timestamps"].shape for segment in segment_list]
|
||||
@@ -2550,14 +2552,14 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
||||
).input_features.to(torch_device)
|
||||
|
||||
# task defaults to transcribe
|
||||
sequences = model.generate(input_features)
|
||||
sequences = model.generate(input_features, return_timestamps=True)
|
||||
|
||||
transcription = processor.batch_decode(sequences)[0]
|
||||
|
||||
assert transcription == " मिर्ची में कितने विबिन्द प्रजातियां हैं? मिर्ची में कितने विबिन्द प्रजातियां हैं?"
|
||||
|
||||
# set task to translate
|
||||
sequences = model.generate(input_features, task="translate")
|
||||
sequences = model.generate(input_features, task="translate", return_timestamps=True)
|
||||
transcription = processor.batch_decode(sequences)[0]
|
||||
|
||||
assert (
|
||||
@@ -3264,6 +3266,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
||||
"num_beams": 5,
|
||||
"language": "fr",
|
||||
"task": "transcribe",
|
||||
"return_timestamps": True,
|
||||
}
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
Reference in New Issue
Block a user