[Xtreme-S] fix some namings (#16183)
This commit is contained in:
committed by
GitHub
parent
99fd3eb4a5
commit
c2dc89be62
@@ -81,9 +81,9 @@ The following command shows how to fine-tune the [XLS-R](https://huggingface.co/
|
|||||||
python -m torch.distributed.launch \
|
python -m torch.distributed.launch \
|
||||||
--nproc_per_node=8 \
|
--nproc_per_node=8 \
|
||||||
run_xtreme_s.py \
|
run_xtreme_s.py \
|
||||||
|
--task="mls" \
|
||||||
|
--language="all" \
|
||||||
--model_name_or_path="facebook/wav2vec2-xls-r-300m" \
|
--model_name_or_path="facebook/wav2vec2-xls-r-300m" \
|
||||||
--dataset_name="google/xtreme_s" \
|
|
||||||
--dataset_config_name="mls.all" \
|
|
||||||
--eval_split_name="test" \
|
--eval_split_name="test" \
|
||||||
--output_dir="xtreme_s_xlsr_300m_mls" \
|
--output_dir="xtreme_s_xlsr_300m_mls" \
|
||||||
--overwrite_output_dir \
|
--overwrite_output_dir \
|
||||||
@@ -94,7 +94,6 @@ python -m torch.distributed.launch \
|
|||||||
--learning_rate="3e-4" \
|
--learning_rate="3e-4" \
|
||||||
--warmup_steps=3000 \
|
--warmup_steps=3000 \
|
||||||
--evaluation_strategy="steps" \
|
--evaluation_strategy="steps" \
|
||||||
--target_column_name="transcription" \
|
|
||||||
--max_duration_in_seconds=20 \
|
--max_duration_in_seconds=20 \
|
||||||
--save_steps=500 \
|
--save_steps=500 \
|
||||||
--eval_steps=500 \
|
--eval_steps=500 \
|
||||||
@@ -126,10 +125,9 @@ The following command shows how to fine-tune the [XLS-R](https://huggingface.co/
|
|||||||
python -m torch.distributed.launch \
|
python -m torch.distributed.launch \
|
||||||
--nproc_per_node=2 \
|
--nproc_per_node=2 \
|
||||||
run_xtreme_s.py \
|
run_xtreme_s.py \
|
||||||
|
--task="minds14" \
|
||||||
|
--language="all" \
|
||||||
--model_name_or_path="facebook/wav2vec2-xls-r-300m" \
|
--model_name_or_path="facebook/wav2vec2-xls-r-300m" \
|
||||||
--dataset_name="google/xtreme_s" \
|
|
||||||
--dataset_config_name="minds14.all" \
|
|
||||||
--eval_split_name="test" \
|
|
||||||
--output_dir="xtreme_s_xlsr_300m_minds14" \
|
--output_dir="xtreme_s_xlsr_300m_minds14" \
|
||||||
--overwrite_output_dir \
|
--overwrite_output_dir \
|
||||||
--num_train_epochs=50 \
|
--num_train_epochs=50 \
|
||||||
@@ -139,7 +137,6 @@ python -m torch.distributed.launch \
|
|||||||
--learning_rate="3e-4" \
|
--learning_rate="3e-4" \
|
||||||
--warmup_steps=1500 \
|
--warmup_steps=1500 \
|
||||||
--evaluation_strategy="steps" \
|
--evaluation_strategy="steps" \
|
||||||
--target_column_name="intent_class" \
|
|
||||||
--max_duration_in_seconds=30 \
|
--max_duration_in_seconds=30 \
|
||||||
--save_steps=200 \
|
--save_steps=200 \
|
||||||
--eval_steps=200 \
|
--eval_steps=200 \
|
||||||
@@ -62,6 +62,17 @@ def list_field(default=None, metadata=None):
|
|||||||
return field(default_factory=lambda: default, metadata=metadata)
|
return field(default_factory=lambda: default, metadata=metadata)
|
||||||
|
|
||||||
|
|
||||||
|
TASK_TO_TARGET_COLUMN_NAME = {
|
||||||
|
"fleurs-asr": "transcription",
|
||||||
|
"fleurs-lang_id": "lang_id",
|
||||||
|
"mls": "transcription",
|
||||||
|
"voxpopuli": "transcription",
|
||||||
|
"covost2": "translation",
|
||||||
|
"minds14": "intent_class",
|
||||||
|
"babel": "transcription",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ModelArguments:
|
class ModelArguments:
|
||||||
"""
|
"""
|
||||||
@@ -144,8 +155,16 @@ class DataTrainingArguments:
|
|||||||
default="xtreme_s",
|
default="xtreme_s",
|
||||||
metadata={"help": "The name of the dataset to use (via the datasets library). Defaults to 'xtreme_s'"},
|
metadata={"help": "The name of the dataset to use (via the datasets library). Defaults to 'xtreme_s'"},
|
||||||
)
|
)
|
||||||
dataset_config_name: str = field(
|
task: str = field(
|
||||||
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "The task name of the benchmark to use (via the datasets library). Should be on of: "
|
||||||
|
"'fleurs-asr', 'mls', 'voxpopuli', 'covost2', 'minds14', 'fleurs-lang_id', 'babel'."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
language: str = field(
|
||||||
|
default="all",
|
||||||
|
metadata={"help": "The language id as defined in the datasets config name or `all` for all languages."},
|
||||||
)
|
)
|
||||||
train_split_name: str = field(
|
train_split_name: str = field(
|
||||||
default="train",
|
default="train",
|
||||||
@@ -160,6 +179,13 @@ class DataTrainingArguments:
|
|||||||
"Defaults to 'validation'"
|
"Defaults to 'validation'"
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
predict_split_name: str = field(
|
||||||
|
default="test",
|
||||||
|
metadata={
|
||||||
|
"help": "The name of the prediction data set split to use (via the datasets library). "
|
||||||
|
"Defaults to 'test'"
|
||||||
|
},
|
||||||
|
)
|
||||||
audio_column_name: str = field(
|
audio_column_name: str = field(
|
||||||
default="audio",
|
default="audio",
|
||||||
metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
|
metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
|
||||||
@@ -192,6 +218,13 @@ class DataTrainingArguments:
|
|||||||
"value if set."
|
"value if set."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
max_predict_samples: Optional[int] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
|
||||||
|
"value if set."
|
||||||
|
},
|
||||||
|
)
|
||||||
chars_to_ignore: Optional[List[str]] = list_field(
|
chars_to_ignore: Optional[List[str]] = list_field(
|
||||||
default=', ? . ! - ; : " “ % ‘ ” <20>'.split(" "),
|
default=', ? . ! - ; : " “ % ‘ ” <20>'.split(" "),
|
||||||
metadata={"help": "A list of characters to remove from the transcripts."},
|
metadata={"help": "A list of characters to remove from the transcripts."},
|
||||||
@@ -387,22 +420,31 @@ def main():
|
|||||||
|
|
||||||
# 1. First, let's load the dataset
|
# 1. First, let's load the dataset
|
||||||
raw_datasets = DatasetDict()
|
raw_datasets = DatasetDict()
|
||||||
if data_args.dataset_config_name is None:
|
task_name = data_args.task
|
||||||
|
lang_id = data_args.language
|
||||||
|
|
||||||
|
if task_name is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Set --dataset_config_name should be set to '<xtreme_s_subset>.<language(s)>' "
|
"Set --task should be set to '<xtreme_s_task>' " "(e.g. 'fleurs-asr', 'mls', 'covost2', 'minds14') "
|
||||||
"(e.g. 'mls.pl', 'covost2.en.tr', 'minds14.fr-FR') "
|
)
|
||||||
"or '<xtreme_s_subset>.all' for multi-lingual fine-tuning."
|
if lang_id is None:
|
||||||
|
raise ValueError(
|
||||||
|
"Set --language should be set to the language id of the sub dataset "
|
||||||
|
"config to be used (e.g. 'pl', 'en.tr', 'fr-FR') or 'all'"
|
||||||
|
" for multi-lingual fine-tuning."
|
||||||
)
|
)
|
||||||
|
|
||||||
task_name = data_args.dataset_config_name.split(".")[0]
|
target_column_name = TASK_TO_TARGET_COLUMN_NAME[task_name]
|
||||||
target_column_name = data_args.target_column_name
|
|
||||||
# here we differentiate between tasks with text as the target and classification tasks
|
# here we differentiate between tasks with text as the target and classification tasks
|
||||||
is_text_target = target_column_name in ("transcription", "translation")
|
is_text_target = target_column_name in ("transcription", "translation")
|
||||||
|
|
||||||
|
config_name = ".".join([task_name.split("-")[0], lang_id])
|
||||||
|
|
||||||
if training_args.do_train:
|
if training_args.do_train:
|
||||||
raw_datasets["train"] = load_dataset(
|
raw_datasets["train"] = load_dataset(
|
||||||
data_args.dataset_name,
|
data_args.dataset_name,
|
||||||
data_args.dataset_config_name,
|
config_name,
|
||||||
split=data_args.train_split_name,
|
split=data_args.train_split_name,
|
||||||
use_auth_token=data_args.use_auth_token,
|
use_auth_token=data_args.use_auth_token,
|
||||||
cache_dir=model_args.cache_dir,
|
cache_dir=model_args.cache_dir,
|
||||||
@@ -432,7 +474,7 @@ def main():
|
|||||||
if training_args.do_eval:
|
if training_args.do_eval:
|
||||||
raw_datasets["eval"] = load_dataset(
|
raw_datasets["eval"] = load_dataset(
|
||||||
data_args.dataset_name,
|
data_args.dataset_name,
|
||||||
data_args.dataset_config_name,
|
config_name,
|
||||||
split=data_args.eval_split_name,
|
split=data_args.eval_split_name,
|
||||||
use_auth_token=data_args.use_auth_token,
|
use_auth_token=data_args.use_auth_token,
|
||||||
cache_dir=model_args.cache_dir,
|
cache_dir=model_args.cache_dir,
|
||||||
@@ -441,6 +483,18 @@ def main():
|
|||||||
if data_args.max_eval_samples is not None:
|
if data_args.max_eval_samples is not None:
|
||||||
raw_datasets["eval"] = raw_datasets["eval"].select(range(data_args.max_eval_samples))
|
raw_datasets["eval"] = raw_datasets["eval"].select(range(data_args.max_eval_samples))
|
||||||
|
|
||||||
|
if training_args.do_predict:
|
||||||
|
raw_datasets["predict"] = load_dataset(
|
||||||
|
data_args.dataset_name,
|
||||||
|
config_name,
|
||||||
|
split=data_args.predict_split_name,
|
||||||
|
use_auth_token=data_args.use_auth_token,
|
||||||
|
cache_dir=model_args.cache_dir,
|
||||||
|
)
|
||||||
|
|
||||||
|
if data_args.max_predict_samples is not None:
|
||||||
|
raw_datasets["predict"] = raw_datasets["predict"].select(range(data_args.max_predict_samples))
|
||||||
|
|
||||||
# 2. We remove some special characters from the datasets
|
# 2. We remove some special characters from the datasets
|
||||||
# that make training complicated and do not help in transcribing the speech
|
# that make training complicated and do not help in transcribing the speech
|
||||||
# E.g. characters, such as `,` and `.` do not really have an acoustic characteristic
|
# E.g. characters, such as `,` and `.` do not really have an acoustic characteristic
|
||||||
@@ -757,24 +811,25 @@ def main():
|
|||||||
|
|
||||||
# Evaluation
|
# Evaluation
|
||||||
results = {}
|
results = {}
|
||||||
if training_args.do_eval:
|
if training_args.do_predict:
|
||||||
logger.info("*** Evaluate ***")
|
logger.info("*** Predicte ***")
|
||||||
metrics = trainer.evaluate()
|
metrics = trainer.evaluate(vectorized_datasets["predict"])
|
||||||
max_eval_samples = (
|
max_predict_samples = (
|
||||||
data_args.max_eval_samples if data_args.max_eval_samples is not None else len(vectorized_datasets["eval"])
|
data_args.max_predict_samples
|
||||||
|
if data_args.max_predict_samples is not None
|
||||||
|
else len(vectorized_datasets["predict"])
|
||||||
)
|
)
|
||||||
metrics["eval_samples"] = min(max_eval_samples, len(vectorized_datasets["eval"]))
|
metrics["predict_samples"] = min(max_predict_samples, len(vectorized_datasets["predict"]))
|
||||||
|
|
||||||
trainer.log_metrics("eval", metrics)
|
trainer.log_metrics("predict", metrics)
|
||||||
trainer.save_metrics("eval", metrics)
|
trainer.save_metrics("predict", metrics)
|
||||||
|
|
||||||
# Write model card and (optionally) push to hub
|
# Write model card and (optionally) push to hub
|
||||||
config_name = data_args.dataset_config_name if data_args.dataset_config_name is not None else "na"
|
|
||||||
kwargs = {
|
kwargs = {
|
||||||
"finetuned_from": model_args.model_name_or_path,
|
"finetuned_from": model_args.model_name_or_path,
|
||||||
"tasks": "speech-recognition",
|
"tasks": task_name,
|
||||||
"tags": ["automatic-speech-recognition", data_args.dataset_name],
|
"tags": [task_name, data_args.dataset_name],
|
||||||
"dataset_args": f"Config: {config_name}, Training split: {data_args.train_split_name}, Eval split: {data_args.eval_split_name}",
|
"dataset_args": f"Config: {config_name}, Training split: {data_args.train_split_name}, Eval split: {data_args.eval_split_name}, Predict split: {data_args.predict_split_name}",
|
||||||
"dataset": f"{data_args.dataset_name.upper()} - {config_name.upper()}",
|
"dataset": f"{data_args.dataset_name.upper()} - {config_name.upper()}",
|
||||||
}
|
}
|
||||||
if "common_voice" in data_args.dataset_name:
|
if "common_voice" in data_args.dataset_name:
|
||||||
Reference in New Issue
Block a user