Add validation for maximum sequence length in modeling_whisper.py (#33196)

* Add validation for maximum sequence length in modeling_whisper.py

Added a validation check to ensure that the sequence length of labels does not exceed the maximum allowed length of 448 tokens. If the sequence length exceeds this limit, a ValueError is raised with a descriptive error message.

This change prevents the model from encountering errors or unexpected behavior due to excessively long sequences during training or fine-tuning, ensuring consistent input dimensions and improving overall robustness.

* Change exception message in src/transformers/models/whisper/modeling_whisper.py

The exception message now refers to the maximum sequence length of Whisper's labels.

Co-authored-by: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com>

* Change 448 to config.max_target_positions in src/transformers/models/whisper/modeling_whisper.py

The limit now comes from Whisper's config.max_target_positions instead of the hard-coded 448.

Co-authored-by: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com>

* Change method's documentation in src/transformers/models/whisper/modeling_whisper.py

* Add test for the labels' maximum sequence length in test_modeling_whisper.py

* Add self to modeling_whisper.py

* Update test_modeling_whisper.py with respect to automatic validations

* Update modeling_whisper.py with respect to ci/circleci: check_code_quality

* Update test_modeling_whisper.py with respect to ci/circleci: check_code_quality

* Update test_modeling_whisper.py with respect to ci/circleci: tests_generate

* Update test_modeling_whisper.py with respect to ci/circleci: tests_generate

* Update test_modeling_whisper.py with respect to ci/circleci: check_code_quality

* Separate test_labels_sequence_max_length tests in test_modeling_whisper.py

* Update test_modeling_whisper.py with respect to ci/circleci: check_code_quality

* Remove assert from test_modeling_whisper.py

* Add max_target_positions to WhisperModelTester in test_modeling_whisper.py

* Update test_modeling_whisper.py with respect to ci/circleci: check_code_quality

* Update test_modeling_whisper.py with respect to ci/circleci: tests_generate

* Update test_modeling_whisper.py

* Change test_labels_sequence_max_length_error_after_changing_config in test_modeling_whisper.py

* Change self.config.max_target_positions to self.max_target_positions in modeling_whisper.py

* Add new tests in test_modeling_whisper.py

* Update test_modeling_whisper.py

---------

Co-authored-by: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com>
This commit is contained in:
Amir Mohammad Fakhimi
2024-09-06 15:39:49 +03:30
committed by GitHub
parent 363301f221
commit 3314fe1760
2 changed files with 63 additions and 1 deletions

View File

@@ -1671,6 +1671,7 @@ class WhisperForConditionalGeneration(WhisperGenerationMixin, WhisperPreTrainedM
super().__init__(config) super().__init__(config)
self.model = WhisperModel(config) self.model = WhisperModel(config)
self.proj_out = nn.Linear(config.d_model, config.vocab_size, bias=False) self.proj_out = nn.Linear(config.d_model, config.vocab_size, bias=False)
self.max_target_positions = config.max_target_positions
# Initialize weights and apply final processing # Initialize weights and apply final processing
self.post_init() self.post_init()
@@ -1723,7 +1724,7 @@ class WhisperForConditionalGeneration(WhisperGenerationMixin, WhisperPreTrainedM
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` Labels for computing the language modeling loss. Indices should either be in `[0, ..., config.vocab_size]`
or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is
only computed for the tokens with labels in `[0, ..., config.vocab_size]`. only computed for the tokens with labels in `[0, ..., config.vocab_size]`. `sequence_length` should be smaller than or equal to `config.max_target_positions`.
Returns: Returns:
@@ -1751,6 +1752,10 @@ class WhisperForConditionalGeneration(WhisperGenerationMixin, WhisperPreTrainedM
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if labels is not None: if labels is not None:
if labels.shape[1] > self.max_target_positions:
raise ValueError(
f"Labels' sequence length {labels.shape[1]} cannot exceed the maximum allowed length of {self.max_target_positions} tokens."
)
if decoder_input_ids is None and decoder_inputs_embeds is None: if decoder_input_ids is None and decoder_inputs_embeds is None:
decoder_input_ids = shift_tokens_right( decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id labels, self.config.pad_token_id, self.config.decoder_start_token_id

View File

@@ -1676,6 +1676,63 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
past_key_values=past_key_values, past_key_values=past_key_values,
) )
def test_labels_sequence_max_length_correct(self):
    """Check that labels of exactly ``config.max_target_positions`` tokens are accepted.

    For every generative model class, a model is built from the tester's config and
    called with labels whose sequence length equals the configured maximum. The test
    passes when the forward call completes without raising.
    """
    config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()

    for model_class in self.all_generative_model_classes:
        input_features = input_dict["input_features"]
        # Keep the model, the features and the labels on one device. The features
        # come from the model tester and may already live on an accelerator, while a
        # freshly built model and `torch.ones` default to the CPU.
        # NOTE(review): assumes `input_features` is a torch.Tensor - confirm.
        device = input_features.device

        # Longest label sequence that is still allowed (boundary case).
        labels_length = config.max_target_positions
        labels = torch.ones(1, labels_length, dtype=torch.int64, device=device)

        model = model_class(config).to(device)
        model(input_features=input_features, labels=labels)
def test_labels_sequence_max_length_correct_after_changing_config(self):
    """Check that a limit raised in the config *before* model creation is honoured.

    ``config.max_target_positions`` is increased by 100, then each generative model
    class is built from that config and called with labels whose sequence length
    equals the new maximum. The test passes when the forward call does not raise.
    """
    config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()

    # Raise the limit exactly once, before any model is built. Doing this inside the
    # loop would add another 100 for every model class and make the tested length
    # depend on how many classes are iterated.
    config.max_target_positions += 100

    for model_class in self.all_generative_model_classes:
        input_features = input_dict["input_features"]
        # Keep the model, the features and the labels on one device. The features
        # come from the model tester and may already live on an accelerator, while a
        # freshly built model and `torch.ones` default to the CPU.
        # NOTE(review): assumes `input_features` is a torch.Tensor - confirm.
        device = input_features.device

        # Longest label sequence allowed under the enlarged limit.
        labels_length = config.max_target_positions
        labels = torch.ones(1, labels_length, dtype=torch.int64, device=device)

        model = model_class(config).to(device)
        model(input_features=input_features, labels=labels)
def test_labels_sequence_max_length_error(self):
    """Check that labels one token longer than ``config.max_target_positions`` are rejected.

    Each generative model class is built from the tester's config and called with
    over-long labels; the call is expected to raise ``ValueError``.
    """
    config, inputs = self.model_tester.prepare_config_and_inputs_for_common()

    for generative_cls in self.all_generative_model_classes:
        features = inputs["input_features"]

        # One token past the configured maximum.
        oversized_labels = torch.ones((1, config.max_target_positions + 1), dtype=torch.int64)

        whisper = generative_cls(config)
        with self.assertRaises(ValueError):
            whisper(input_features=features, labels=oversized_labels)
def test_labels_sequence_max_length_error_after_changing_config(self):
    """Check that raising length limits *after* model creation does not relax the check.

    The model is built first. Labels are then made one token longer than the limit
    it was built with, and the limits on the model's config, on its generation
    config and on the original config object are all raised by 100. The forward call
    is still expected to raise ``ValueError``.
    """
    config, inputs = self.model_tester.prepare_config_and_inputs_for_common()

    for generative_cls in self.all_generative_model_classes:
        # Build the model before any limit is changed; the order matters here.
        whisper = generative_cls(config)
        features = inputs["input_features"]

        limit_at_build = config.max_target_positions
        oversized_labels = torch.ones((1, limit_at_build + 1), dtype=torch.int64)

        # Raise every limit reachable from the test after the model exists.
        raised_limit = limit_at_build + 100
        whisper.config.max_length = raised_limit
        whisper.generation_config.max_length = raised_limit
        config.max_target_positions = raised_limit

        with self.assertRaises(ValueError):
            whisper(input_features=features, labels=oversized_labels)
@require_torch @require_torch
@require_torchaudio @require_torchaudio