[Whisper test] Fix some failing tests (#33450)
* Fix failing tensor placement in Whisper
* fix long form generation tests
* more return_timestamps=True
* make fixup
* [run_slow] whisper
* [run_slow] whisper
This commit is contained in:
@@ -1683,9 +1683,9 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
|||||||
input_features = input_dict["input_features"]
|
input_features = input_dict["input_features"]
|
||||||
|
|
||||||
labels_length = config.max_target_positions
|
labels_length = config.max_target_positions
|
||||||
labels = torch.ones(1, labels_length, dtype=torch.int64)
|
labels = torch.ones(1, labels_length, dtype=torch.int64).to(torch_device)
|
||||||
|
|
||||||
model = model_class(config)
|
model = model_class(config).to(torch_device)
|
||||||
model(input_features=input_features, labels=labels)
|
model(input_features=input_features, labels=labels)
|
||||||
|
|
||||||
def test_labels_sequence_max_length_correct_after_changing_config(self):
|
def test_labels_sequence_max_length_correct_after_changing_config(self):
|
||||||
@@ -1697,9 +1697,9 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
|||||||
config.max_target_positions += 100
|
config.max_target_positions += 100
|
||||||
|
|
||||||
labels_length = config.max_target_positions
|
labels_length = config.max_target_positions
|
||||||
labels = torch.ones(1, labels_length, dtype=torch.int64)
|
labels = torch.ones(1, labels_length, dtype=torch.int64).to(torch_device)
|
||||||
|
|
||||||
model = model_class(config)
|
model = model_class(config).to(torch_device)
|
||||||
model(input_features=input_features, labels=labels)
|
model(input_features=input_features, labels=labels)
|
||||||
|
|
||||||
def test_labels_sequence_max_length_error(self):
|
def test_labels_sequence_max_length_error(self):
|
||||||
@@ -1709,9 +1709,9 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
|||||||
input_features = input_dict["input_features"]
|
input_features = input_dict["input_features"]
|
||||||
|
|
||||||
labels_length = config.max_target_positions + 1
|
labels_length = config.max_target_positions + 1
|
||||||
labels = torch.ones(1, labels_length, dtype=torch.int64)
|
labels = torch.ones(1, labels_length, dtype=torch.int64).to(torch_device)
|
||||||
|
|
||||||
model = model_class(config)
|
model = model_class(config).to(torch_device)
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
model(input_features=input_features, labels=labels)
|
model(input_features=input_features, labels=labels)
|
||||||
|
|
||||||
@@ -1719,11 +1719,11 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
|||||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
model = model_class(config)
|
model = model_class(config).to(torch_device)
|
||||||
input_features = input_dict["input_features"]
|
input_features = input_dict["input_features"]
|
||||||
|
|
||||||
labels_length = config.max_target_positions + 1
|
labels_length = config.max_target_positions + 1
|
||||||
labels = torch.ones(1, labels_length, dtype=torch.int64)
|
labels = torch.ones(1, labels_length, dtype=torch.int64).to(torch_device)
|
||||||
|
|
||||||
new_max_length = config.max_target_positions + 100
|
new_max_length = config.max_target_positions + 100
|
||||||
model.config.max_length = new_max_length
|
model.config.max_length = new_max_length
|
||||||
@@ -2385,7 +2385,9 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
inputs = inputs.to(torch_device)
|
inputs = inputs.to(torch_device)
|
||||||
generate_outputs = model.generate(**inputs, return_segments=True, return_token_timestamps=True)
|
generate_outputs = model.generate(
|
||||||
|
**inputs, return_segments=True, return_token_timestamps=True, return_timestamps=True
|
||||||
|
)
|
||||||
|
|
||||||
token_timestamps_shape = [
|
token_timestamps_shape = [
|
||||||
[segment["token_timestamps"].shape for segment in segment_list]
|
[segment["token_timestamps"].shape for segment in segment_list]
|
||||||
@@ -2550,14 +2552,14 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
|||||||
).input_features.to(torch_device)
|
).input_features.to(torch_device)
|
||||||
|
|
||||||
# task defaults to transcribe
|
# task defaults to transcribe
|
||||||
sequences = model.generate(input_features)
|
sequences = model.generate(input_features, return_timestamps=True)
|
||||||
|
|
||||||
transcription = processor.batch_decode(sequences)[0]
|
transcription = processor.batch_decode(sequences)[0]
|
||||||
|
|
||||||
assert transcription == " मिर्ची में कितने विबिन्द प्रजातियां हैं? मिर्ची में कितने विबिन्द प्रजातियां हैं?"
|
assert transcription == " मिर्ची में कितने विबिन्द प्रजातियां हैं? मिर्ची में कितने विबिन्द प्रजातियां हैं?"
|
||||||
|
|
||||||
# set task to translate
|
# set task to translate
|
||||||
sequences = model.generate(input_features, task="translate")
|
sequences = model.generate(input_features, task="translate", return_timestamps=True)
|
||||||
transcription = processor.batch_decode(sequences)[0]
|
transcription = processor.batch_decode(sequences)[0]
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
@@ -3264,6 +3266,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
|||||||
"num_beams": 5,
|
"num_beams": 5,
|
||||||
"language": "fr",
|
"language": "fr",
|
||||||
"task": "transcribe",
|
"task": "transcribe",
|
||||||
|
"return_timestamps": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
torch.manual_seed(0)
|
torch.manual_seed(0)
|
||||||
|
|||||||
Reference in New Issue
Block a user