Making CTC training example more general (#28582)
* add w2v2bert compatibility
* Update examples/pytorch/speech-recognition/run_speech_recognition_ctc.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -132,10 +132,17 @@ class ModelArguments:
|
|||||||
ctc_loss_reduction: Optional[str] = field(
|
ctc_loss_reduction: Optional[str] = field(
|
||||||
default="mean", metadata={"help": "The way the ctc loss should be reduced. Should be one of 'mean' or 'sum'."}
|
default="mean", metadata={"help": "The way the ctc loss should be reduced. Should be one of 'mean' or 'sum'."}
|
||||||
)
|
)
|
||||||
|
ctc_zero_infinity: Optional[bool] = field(
|
||||||
|
default=False,
|
||||||
|
metadata={
|
||||||
|
"help": "Whether to zero infinite losses and the associated gradients of `torch.nn.CTCLoss`. Infinite losses mainly"
|
||||||
|
" occur when the inputs are too short to be aligned to the targets."
|
||||||
|
},
|
||||||
|
)
|
||||||
add_adapter: Optional[bool] = field(
|
add_adapter: Optional[bool] = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Whether a convolutional attention network should be stacked on top of the Wav2Vec2BERT Encoder. Can be very"
|
"help": "Whether a convolutional attention network should be stacked on top of the Wav2Vec2Bert Encoder. Can be very"
|
||||||
"useful to downsample the output length."
|
"useful to downsample the output length."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@@ -316,11 +323,14 @@ class DataCollatorCTCWithPadding:
|
|||||||
padding: Union[bool, str] = "longest"
|
padding: Union[bool, str] = "longest"
|
||||||
pad_to_multiple_of: Optional[int] = None
|
pad_to_multiple_of: Optional[int] = None
|
||||||
pad_to_multiple_of_labels: Optional[int] = None
|
pad_to_multiple_of_labels: Optional[int] = None
|
||||||
|
feature_extractor_input_name: Optional[str] = "input_values"
|
||||||
|
|
||||||
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
|
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
|
||||||
# split inputs and labels since they have to be of different lengths and need
|
# split inputs and labels since they have to be of different lengths and need
|
||||||
# different padding methods
|
# different padding methods
|
||||||
input_features = [{"input_values": feature["input_values"]} for feature in features]
|
input_features = [
|
||||||
|
{self.feature_extractor_input_name: feature[self.feature_extractor_input_name]} for feature in features
|
||||||
|
]
|
||||||
label_features = [{"input_ids": feature["labels"]} for feature in features]
|
label_features = [{"input_ids": feature["labels"]} for feature in features]
|
||||||
|
|
||||||
batch = self.processor.pad(
|
batch = self.processor.pad(
|
||||||
@@ -606,6 +616,7 @@ def main():
|
|||||||
"gradient_checkpointing": training_args.gradient_checkpointing,
|
"gradient_checkpointing": training_args.gradient_checkpointing,
|
||||||
"layerdrop": model_args.layerdrop,
|
"layerdrop": model_args.layerdrop,
|
||||||
"ctc_loss_reduction": model_args.ctc_loss_reduction,
|
"ctc_loss_reduction": model_args.ctc_loss_reduction,
|
||||||
|
"ctc_zero_infinity": model_args.ctc_zero_infinity,
|
||||||
"pad_token_id": tokenizer.pad_token_id,
|
"pad_token_id": tokenizer.pad_token_id,
|
||||||
"vocab_size": len(tokenizer),
|
"vocab_size": len(tokenizer),
|
||||||
"activation_dropout": model_args.activation_dropout,
|
"activation_dropout": model_args.activation_dropout,
|
||||||
@@ -643,6 +654,7 @@ def main():
|
|||||||
min_input_length = data_args.min_duration_in_seconds * feature_extractor.sampling_rate
|
min_input_length = data_args.min_duration_in_seconds * feature_extractor.sampling_rate
|
||||||
audio_column_name = data_args.audio_column_name
|
audio_column_name = data_args.audio_column_name
|
||||||
num_workers = data_args.preprocessing_num_workers
|
num_workers = data_args.preprocessing_num_workers
|
||||||
|
feature_extractor_input_name = feature_extractor.model_input_names[0]
|
||||||
|
|
||||||
# `phoneme_language` is only relevant if the model is fine-tuned on phoneme classification
|
# `phoneme_language` is only relevant if the model is fine-tuned on phoneme classification
|
||||||
phoneme_language = data_args.phoneme_language
|
phoneme_language = data_args.phoneme_language
|
||||||
@@ -654,8 +666,9 @@ def main():
|
|||||||
sample = batch[audio_column_name]
|
sample = batch[audio_column_name]
|
||||||
|
|
||||||
inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
|
inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
|
||||||
batch["input_values"] = inputs.input_values[0]
|
batch[feature_extractor_input_name] = getattr(inputs, feature_extractor_input_name)[0]
|
||||||
batch["input_length"] = len(batch["input_values"])
|
# take length of raw audio waveform
|
||||||
|
batch["input_length"] = len(sample["array"].squeeze())
|
||||||
|
|
||||||
# encode targets
|
# encode targets
|
||||||
additional_kwargs = {}
|
additional_kwargs = {}
|
||||||
@@ -736,7 +749,9 @@ def main():
|
|||||||
processor = Wav2Vec2Processor.from_pretrained(training_args.output_dir)
|
processor = Wav2Vec2Processor.from_pretrained(training_args.output_dir)
|
||||||
|
|
||||||
# Instantiate custom data collator
|
# Instantiate custom data collator
|
||||||
data_collator = DataCollatorCTCWithPadding(processor=processor)
|
data_collator = DataCollatorCTCWithPadding(
|
||||||
|
processor=processor, feature_extractor_input_name=feature_extractor_input_name
|
||||||
|
)
|
||||||
|
|
||||||
# Initialize Trainer
|
# Initialize Trainer
|
||||||
trainer = Trainer(
|
trainer = Trainer(
|
||||||
|
|||||||
Reference in New Issue
Block a user