[Examples] Fix typos in run speech recognition seq2seq (#19514)
This commit is contained in:
@@ -195,7 +195,7 @@ class DataCollatorSpeechSeq2SeqWithPadding:
|
|||||||
Data collator that will dynamically pad the inputs received.
|
Data collator that will dynamically pad the inputs received.
|
||||||
Args:
|
Args:
|
||||||
processor ([`Wav2Vec2Processor`])
|
processor ([`Wav2Vec2Processor`])
|
||||||
The processor used for proccessing the data.
|
The processor used for processing the data.
|
||||||
decoder_start_token_id (`int`)
|
decoder_start_token_id (`int`)
|
||||||
The begin-of-sentence of the decoder.
|
The begin-of-sentence of the decoder.
|
||||||
"""
|
"""
|
||||||
@@ -204,7 +204,7 @@ class DataCollatorSpeechSeq2SeqWithPadding:
|
|||||||
decoder_start_token_id: int
|
decoder_start_token_id: int
|
||||||
|
|
||||||
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
|
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
|
||||||
# split inputs and labels since they have to be of different lenghts and need
|
# split inputs and labels since they have to be of different lengths and need
|
||||||
# different padding methods
|
# different padding methods
|
||||||
input_features = [{"input_values": feature["input_values"]} for feature in features]
|
input_features = [{"input_values": feature["input_values"]} for feature in features]
|
||||||
label_features = [{"input_ids": feature["labels"]} for feature in features]
|
label_features = [{"input_ids": feature["labels"]} for feature in features]
|
||||||
@@ -271,7 +271,7 @@ def main():
|
|||||||
transformers.utils.logging.set_verbosity_info()
|
transformers.utils.logging.set_verbosity_info()
|
||||||
logger.info("Training/evaluation parameters %s", training_args)
|
logger.info("Training/evaluation parameters %s", training_args)
|
||||||
|
|
||||||
# 3. Detecting last checkpoint and eventualy continue from last checkpoint
|
# 3. Detecting last checkpoint and eventually continue from last checkpoint
|
||||||
last_checkpoint = None
|
last_checkpoint = None
|
||||||
if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
|
if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
|
||||||
last_checkpoint = get_last_checkpoint(training_args.output_dir)
|
last_checkpoint = get_last_checkpoint(training_args.output_dir)
|
||||||
@@ -360,7 +360,7 @@ def main():
|
|||||||
if model_args.freeze_feature_encoder:
|
if model_args.freeze_feature_encoder:
|
||||||
model.freeze_feature_encoder()
|
model.freeze_feature_encoder()
|
||||||
|
|
||||||
# 6. Resample speech dataset if necassary
|
# 6. Resample speech dataset if necessary
|
||||||
dataset_sampling_rate = next(iter(raw_datasets.values())).features[data_args.audio_column_name].sampling_rate
|
dataset_sampling_rate = next(iter(raw_datasets.values())).features[data_args.audio_column_name].sampling_rate
|
||||||
if dataset_sampling_rate != feature_extractor.sampling_rate:
|
if dataset_sampling_rate != feature_extractor.sampling_rate:
|
||||||
raw_datasets = raw_datasets.cast_column(
|
raw_datasets = raw_datasets.cast_column(
|
||||||
|
|||||||
Reference in New Issue
Block a user