[Examples] Add an official audio classification example (#13722)
* Restore broken merge * Additional args, DDP, remove CommonLanguage * Update examples for V100, add training results * Style * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Remove custom datasets for simplicity, apply suggestions from code review * Add the attention_mask flag, reorganize README Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -40,12 +40,14 @@ SRC_DIRS = [
|
||||
"translation",
|
||||
"image-classification",
|
||||
"speech-recognition",
|
||||
"audio-classification",
|
||||
]
|
||||
]
|
||||
sys.path.extend(SRC_DIRS)
|
||||
|
||||
|
||||
if SRC_DIRS is not None:
|
||||
import run_audio_classification
|
||||
import run_clm
|
||||
import run_generation
|
||||
import run_glue
|
||||
@@ -410,3 +412,38 @@ class ExamplesTests(TestCasePlus):
|
||||
run_speech_recognition_ctc.main()
|
||||
result = get_results(tmp_dir)
|
||||
self.assertLess(result["eval_loss"], result["train_loss"])
|
||||
|
||||
def test_run_audio_classification(self):
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||
testargs = f"""
|
||||
run_audio_classification.py
|
||||
--output_dir {tmp_dir}
|
||||
--model_name_or_path hf-internal-testing/tiny-random-wav2vec2
|
||||
--dataset_name anton-l/superb_demo
|
||||
--dataset_config_name ks
|
||||
--train_split_name test
|
||||
--eval_split_name test
|
||||
--audio_column_name file
|
||||
--label_column_name label
|
||||
--do_train
|
||||
--do_eval
|
||||
--learning_rate 1e-4
|
||||
--per_device_train_batch_size 2
|
||||
--per_device_eval_batch_size 1
|
||||
--remove_unused_columns False
|
||||
--overwrite_output_dir True
|
||||
--num_train_epochs 10
|
||||
--max_steps 50
|
||||
--seed 42
|
||||
""".split()
|
||||
|
||||
if is_cuda_and_apex_available():
|
||||
testargs.append("--fp16")
|
||||
|
||||
with patch.object(sys, "argv", testargs):
|
||||
run_audio_classification.main()
|
||||
result = get_results(tmp_dir)
|
||||
self.assertLess(result["eval_loss"], result["train_loss"])
|
||||
|
||||
Reference in New Issue
Block a user