[Wav2Vec2] Rename model's feature extractor to feature encoder (#14959)
* rename classes * clean up more namings * remove bogus file * Apply suggestions from code review * Apply suggestions from code review * replace more names * more regex replace * make style * correct * correct more * make style * finish * correct more in wav2vec2 * make style * improve freeze_extractor * add aliases * add tf aliases
This commit is contained in:
committed by
GitHub
parent
1bfa347707
commit
600496fa50
@@ -17,6 +17,7 @@
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import warnings
|
||||
from dataclasses import dataclass, field
|
||||
from random import randint
|
||||
from typing import Optional
|
||||
@@ -76,24 +77,24 @@ class DataTrainingArguments:
|
||||
eval_file: Optional[str] = field(
|
||||
default=None, metadata={"help": "A file containing the validation audio paths and labels."}
|
||||
)
|
||||
train_split_name: Optional[str] = field(
|
||||
train_split_name: str = field(
|
||||
default="train",
|
||||
metadata={
|
||||
"help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
|
||||
},
|
||||
)
|
||||
eval_split_name: Optional[str] = field(
|
||||
eval_split_name: str = field(
|
||||
default="validation",
|
||||
metadata={
|
||||
"help": "The name of the training data set split to use (via the datasets library). Defaults to "
|
||||
"'validation'"
|
||||
},
|
||||
)
|
||||
audio_column_name: Optional[str] = field(
|
||||
audio_column_name: str = field(
|
||||
default="audio",
|
||||
metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
|
||||
)
|
||||
label_column_name: Optional[str] = field(
|
||||
label_column_name: str = field(
|
||||
default="label", metadata={"help": "The name of the dataset column containing the labels. Defaults to 'label'"}
|
||||
)
|
||||
max_train_samples: Optional[int] = field(
|
||||
@@ -110,7 +111,7 @@ class DataTrainingArguments:
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
max_length_seconds: Optional[float] = field(
|
||||
max_length_seconds: float = field(
|
||||
default=20,
|
||||
metadata={"help": "Audio clips will be randomly cut to this length during training if the value is set."},
|
||||
)
|
||||
@@ -136,11 +137,13 @@ class ModelArguments:
|
||||
default="main",
|
||||
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
|
||||
)
|
||||
feature_extractor_name: str = field(default=None, metadata={"help": "Name or path of preprocessor config."})
|
||||
freeze_feature_extractor: Optional[bool] = field(
|
||||
default=True, metadata={"help": "Whether to freeze the feature extractor layers of the model."}
|
||||
feature_extractor_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "Name or path of preprocessor config."}
|
||||
)
|
||||
attention_mask: Optional[bool] = field(
|
||||
freeze_feature_encoder: bool = field(
|
||||
default=True, metadata={"help": "Whether to freeze the feature encoder layers of the model."}
|
||||
)
|
||||
attention_mask: bool = field(
|
||||
default=True, metadata={"help": "Whether to generate an attention mask in the feature extractor."}
|
||||
)
|
||||
use_auth_token: bool = field(
|
||||
@@ -150,6 +153,24 @@ class ModelArguments:
|
||||
"with private models)."
|
||||
},
|
||||
)
|
||||
freeze_feature_extractor: Optional[bool] = field(
|
||||
default=None, metadata={"help": "Whether to freeze the feature extractor layers of the model."}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if not self.freeze_feature_extractor and self.freeze_feature_encoder:
|
||||
warnings.warn(
|
||||
"The argument `--freeze_feature_extractor` is deprecated and "
|
||||
"will be removed in a future version. Use `--freeze_feature_encoder`"
|
||||
"instead. Setting `freeze_feature_encoder==True`.",
|
||||
FutureWarning,
|
||||
)
|
||||
if self.freeze_feature_extractor and not self.freeze_feature_encoder:
|
||||
raise ValueError(
|
||||
"The argument `--freeze_feature_extractor` is deprecated and "
|
||||
"should not be used in combination with `--freeze_feature_encoder`."
|
||||
"Only make use of `--freeze_feature_encoder`."
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
@@ -302,8 +323,8 @@ def main():
|
||||
)
|
||||
|
||||
# freeze the convolutional waveform encoder
|
||||
if model_args.freeze_feature_extractor:
|
||||
model.freeze_feature_extractor()
|
||||
if model_args.freeze_feature_encoder:
|
||||
model.freeze_feature_encoder()
|
||||
|
||||
if training_args.do_train:
|
||||
if data_args.max_train_samples is not None:
|
||||
|
||||
Reference in New Issue
Block a user