[Wav2Vec2] Rename model's feature extractor to feature encoder (#14959)

* rename classes

* clean up more namings

* remove bogus file

* Apply suggestions from code review

* Apply suggestions from code review

* replace more names

* more regex replace

* make style

* correct

* correct more

* make style

* finish

* correct more in wav2vec2

* make style

* improve freeze_extractor

* add aliases

* add tf aliases
This commit is contained in:
Patrick von Platen
2021-12-28 20:33:23 +01:00
committed by GitHub
parent 1bfa347707
commit 600496fa50
32 changed files with 658 additions and 215 deletions

View File

@@ -17,6 +17,7 @@
import logging
import os
import sys
import warnings
from dataclasses import dataclass, field
from random import randint
from typing import Optional
@@ -76,24 +77,24 @@ class DataTrainingArguments:
eval_file: Optional[str] = field(
default=None, metadata={"help": "A file containing the validation audio paths and labels."}
)
train_split_name: Optional[str] = field(
train_split_name: str = field(
default="train",
metadata={
"help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
},
)
eval_split_name: Optional[str] = field(
eval_split_name: str = field(
default="validation",
metadata={
"help": "The name of the training data set split to use (via the datasets library). Defaults to "
"'validation'"
},
)
audio_column_name: Optional[str] = field(
audio_column_name: str = field(
default="audio",
metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
)
label_column_name: Optional[str] = field(
label_column_name: str = field(
default="label", metadata={"help": "The name of the dataset column containing the labels. Defaults to 'label'"}
)
max_train_samples: Optional[int] = field(
@@ -110,7 +111,7 @@ class DataTrainingArguments:
"value if set."
},
)
max_length_seconds: Optional[float] = field(
max_length_seconds: float = field(
default=20,
metadata={"help": "Audio clips will be randomly cut to this length during training if the value is set."},
)
@@ -136,11 +137,13 @@ class ModelArguments:
default="main",
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
)
feature_extractor_name: str = field(default=None, metadata={"help": "Name or path of preprocessor config."})
freeze_feature_extractor: Optional[bool] = field(
default=True, metadata={"help": "Whether to freeze the feature extractor layers of the model."}
feature_extractor_name: Optional[str] = field(
default=None, metadata={"help": "Name or path of preprocessor config."}
)
attention_mask: Optional[bool] = field(
freeze_feature_encoder: bool = field(
default=True, metadata={"help": "Whether to freeze the feature encoder layers of the model."}
)
attention_mask: bool = field(
default=True, metadata={"help": "Whether to generate an attention mask in the feature extractor."}
)
use_auth_token: bool = field(
@@ -150,6 +153,24 @@ class ModelArguments:
"with private models)."
},
)
freeze_feature_extractor: Optional[bool] = field(
default=None, metadata={"help": "Whether to freeze the feature extractor layers of the model."}
)
def __post_init__(self):
    """Reconcile the deprecated `freeze_feature_extractor` flag with `freeze_feature_encoder`.

    `freeze_feature_extractor` defaults to `None`, which means the deprecated
    flag was not supplied; in that case nothing is checked and nothing is
    emitted.

    Raises:
        ValueError: if `freeze_feature_extractor` is truthy while
            `freeze_feature_encoder` is falsy, i.e. the two flags contradict
            each other.
    """
    # `None` is the "not provided" sentinel of the deprecated flag. Without this
    # guard `not None` is truthy and the deprecation warning would be emitted on
    # every default run, even though the deprecated flag was never used.
    if self.freeze_feature_extractor is None:
        return

    if not self.freeze_feature_extractor and self.freeze_feature_encoder:
        # The deprecated flag is not applied; `freeze_feature_encoder` stays True.
        warnings.warn(
            "The argument `--freeze_feature_extractor` is deprecated and "
            "will be removed in a future version. Use `--freeze_feature_encoder` "
            "instead. Setting `freeze_feature_encoder==True`.",
            FutureWarning,
        )

    if self.freeze_feature_extractor and not self.freeze_feature_encoder:
        raise ValueError(
            "The argument `--freeze_feature_extractor` is deprecated and "
            "should not be used in combination with `--freeze_feature_encoder`. "
            "Only make use of `--freeze_feature_encoder`."
        )
def main():
@@ -302,8 +323,8 @@ def main():
)
# freeze the convolutional waveform encoder
if model_args.freeze_feature_extractor:
model.freeze_feature_extractor()
if model_args.freeze_feature_encoder:
model.freeze_feature_encoder()
if training_args.do_train:
if data_args.max_train_samples is not None: