Add validation for maximum sequence length in modeling_whisper.py (#33196)
* Add validation for maximum sequence length in modeling_whisper.py Added a validation check to ensure that the sequence length of labels does not exceed the maximum allowed length of 448 tokens. If the sequence length exceeds this limit, a ValueError is raised with a descriptive error message. This change prevents the model from encountering errors or unexpected behavior due to excessively long sequences during training or fine-tuning, ensuring consistent input dimensions and improving overall robustness. * Change exception message in src/transformers/models/whisper/modeling_whisper.py The exception message is for whisper's label's sequence max length. Co-authored-by: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com> * Change 448 to config.max_target_positions in src/transformers/models/whisper/modeling_whisper.py It's for whisper's config.max_target_positions. Co-authored-by: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com> * Change method's documentation in src/transformers/models/whisper/modeling_whisper.py * Add test for maximum label's sequence length in test_modeling_whisper.py * Add self to modeling_whisper.py * Update test_modeling_whisper.py with respect to automatic validations * Update modeling_whisper.py with respect to ci/circleci: check_code_quality * Update test_modeling_whisper.py with respect to ci/circleci: check_code_quality * Update test_modeling_whisper.py with respect to ci/circleci: tests_generate * Update test_modeling_whisper.py with respect to ci/circleci: tests_generate * Update test_modeling_whisper.py with respect to ci/circleci: check_code_quality * Separate test_labels_sequence_max_length tests in test_modeling_whisper.py * Update test_modeling_whisper.py with respect to ci/circleci: check_code_quality * Remove assert from test_modeling_whisper.py * Add max_target_positions to WhisperModelTester in test_modeling_whisper.py * Update test_modeling_whisper.py with respect to ci/circleci: check_code_quality * Update 
test_modeling_whisper.py with respect to ci/circleci: tests_generate * Update test_modeling_whisper.py * Change test_labels_sequence_max_length_error_after_changing_config in test_modeling_whisper.py * Change self.config.max_target_positions to self.max_target_positions modeling_whisper.py * Add new tests in test_modeling_whisper.py * Update test_modeling_whisper.py --------- Co-authored-by: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
363301f221
commit
3314fe1760
@@ -1676,6 +1676,63 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
past_key_values=past_key_values,
|
||||
)
|
||||
|
||||
def test_labels_sequence_max_length_correct(self):
    """Run a forward pass with labels exactly `config.max_target_positions` long.

    There is no explicit assertion: the test passes as long as the forward
    call does not raise for any class in `self.all_generative_model_classes`.
    """
    config, inputs = self.model_tester.prepare_config_and_inputs_for_common()

    for generative_cls in self.all_generative_model_classes:
        features = inputs["input_features"]

        # One sequence of all-ones token ids sitting exactly on the configured limit.
        boundary_labels = torch.ones(1, config.max_target_positions, dtype=torch.int64)

        candidate = generative_cls(config)
        candidate(input_features=features, labels=boundary_labels)
|
||||
|
||||
def test_labels_sequence_max_length_correct_after_changing_config(self):
    """Run a forward pass at the label-length limit after enlarging that limit.

    `config.max_target_positions` is increased by 100 before any model is
    built; each class in `self.all_generative_model_classes` is then
    instantiated from that config and called with labels exactly as long as
    the enlarged limit. There is no explicit assertion: the test passes as
    long as no forward call raises.
    """
    config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()

    # Enlarge the limit once, outside the loop. `config` is shared by every
    # iteration, so bumping it per model class would make the limit grow by a
    # further 100 for each class instead of testing them all at the same value.
    config.max_target_positions += 100

    for model_class in self.all_generative_model_classes:
        input_features = input_dict["input_features"]

        # One sequence of all-ones token ids, exactly as long as the enlarged limit.
        labels_length = config.max_target_positions
        labels = torch.ones(1, labels_length, dtype=torch.int64)

        model = model_class(config)
        model(input_features=input_features, labels=labels)
|
||||
|
||||
def test_labels_sequence_max_length_error(self):
    """Expect `ValueError` when labels are one token longer than the limit.

    For every class in `self.all_generative_model_classes`, a model is built
    from the tester's config and called with labels of length
    `config.max_target_positions + 1`.
    """
    config, inputs = self.model_tester.prepare_config_and_inputs_for_common()

    for generative_cls in self.all_generative_model_classes:
        features = inputs["input_features"]

        # One sequence of all-ones token ids, a single position past the limit.
        oversized_labels = torch.ones(1, config.max_target_positions + 1, dtype=torch.int64)

        # Construction happens outside the assertion so that only the forward
        # call is allowed to produce the expected error.
        candidate = generative_cls(config)
        with self.assertRaises(ValueError):
            candidate(input_features=features, labels=oversized_labels)
|
||||
|
||||
def test_labels_sequence_max_length_error_after_changing_config(self):
    """Expect `ValueError` for over-long labels even after the config limit is raised.

    The model is built first. Afterwards `max_length` on `model.config` and
    `model.generation_config`, and `config.max_target_positions`, are all set
    to the original limit plus 100. Labels one token longer than the
    *original* limit must still make the forward call raise.
    NOTE(review): presumably the model keeps the limit it saw at construction
    time — confirm against modeling_whisper.py.
    """
    config, inputs = self.model_tester.prepare_config_and_inputs_for_common()

    for generative_cls in self.all_generative_model_classes:
        candidate = generative_cls(config)
        features = inputs["input_features"]

        # Read the limit before it is modified below; both the label length
        # and the raised value are derived from this pre-change number.
        limit_at_construction = config.max_target_positions
        oversized_labels = torch.ones(1, limit_at_construction + 1, dtype=torch.int64)

        raised_limit = limit_at_construction + 100
        candidate.config.max_length = raised_limit
        candidate.generation_config.max_length = raised_limit
        config.max_target_positions = raised_limit

        with self.assertRaises(ValueError):
            candidate(input_features=features, labels=oversized_labels)
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_torchaudio
|
||||
|
||||
Reference in New Issue
Block a user