diff --git a/examples/pytorch/speech-recognition/run_speech_recognition_seq2seq.py b/examples/pytorch/speech-recognition/run_speech_recognition_seq2seq.py index c841e99df2..f27f700162 100755 --- a/examples/pytorch/speech-recognition/run_speech_recognition_seq2seq.py +++ b/examples/pytorch/speech-recognition/run_speech_recognition_seq2seq.py @@ -113,6 +113,12 @@ class ModelArguments: suppress_tokens: List[int] = field( default=None, metadata={"help": "A list of tokens that will be suppressed at generation."} ) + apply_spec_augment: bool = field( + default=False, + metadata={ + "help": "Whether to apply *SpecAugment* data augmentation to the input features. This is currently only relevant for Wav2Vec2, HuBERT, WavLM and Whisper models." + }, + ) @dataclass @@ -127,10 +133,6 @@ class DataTrainingArguments: dataset_config_name: Optional[str] = field( default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} ) - text_column: Optional[str] = field( - default=None, - metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."}, - ) overwrite_cache: bool = field( default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} ) @@ -227,10 +229,13 @@ class DataCollatorSpeechSeq2SeqWithPadding: The processor used for processing the data. decoder_start_token_id (`int`) The begin-of-sentence of the decoder. + forward_attention_mask (`bool`) + Whether to return attention_mask. """ processor: Any decoder_start_token_id: int + forward_attention_mask: bool def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]: # split inputs and labels since they have to be of different lengths and need @@ -241,6 +246,9 @@ class DataCollatorSpeechSeq2SeqWithPadding: batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt") + if self.forward_attention_mask: + batch["attention_mask"] = torch.LongTensor([feature["attention_mask"] for feature in features]) + labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt") # replace padding with -100 to ignore loss correctly @@ -367,6 +375,10 @@ def main(): config.update({"forced_decoder_ids": model_args.forced_decoder_ids, "suppress_tokens": model_args.suppress_tokens}) + # SpecAugment for whisper models + if getattr(config, "model_type", None) == "whisper": + config.update({"apply_spec_augment": model_args.apply_spec_augment}) + feature_extractor = AutoFeatureExtractor.from_pretrained( model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path, cache_dir=model_args.cache_dir, @@ -418,6 +430,12 @@ def main(): text_column_name = data_args.text_column_name model_input_name = feature_extractor.model_input_names[0] do_lower_case = data_args.do_lower_case + # if SpecAugment is used for whisper models, return attention_mask to guide the mask along time axis + forward_attention_mask = ( + getattr(config, "model_type", None) == "whisper" + and getattr(config, "apply_spec_augment", False) + and getattr(config, "mask_time_prob", 0) > 0 + ) if data_args.max_train_samples is not None: raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples)) @@ -428,10 +446,14 @@ def main(): def prepare_dataset(batch): # process audio sample = batch[audio_column_name] - inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"]) + inputs = feature_extractor( + sample["array"], sampling_rate=sample["sampling_rate"], return_attention_mask=forward_attention_mask + ) # process audio length batch[model_input_name] = inputs.get(model_input_name)[0] batch["input_length"] = len(sample["array"]) + if forward_attention_mask: + batch["attention_mask"] = inputs.get("attention_mask")[0] # process targets input_str = batch[text_column_name].lower() if do_lower_case else batch[text_column_name] @@ -496,6 +518,7 @@ def main(): data_collator = DataCollatorSpeechSeq2SeqWithPadding( processor=processor, decoder_start_token_id=model.config.decoder_start_token_id, + forward_attention_mask=forward_attention_mask, ) # 11. Initialize Trainer