[Wav2Vec2ProcessorWithLM] improve decoder download (#15040)

This commit is contained in:
Patrick von Platen
2022-01-11 11:59:38 +01:00
committed by GitHub
parent 6ea6266625
commit efb35a4107
2 changed files with 22 additions and 1 deletions

View File

@@ -18,6 +18,7 @@ import shutil
import tempfile
import unittest
from multiprocessing import Pool
from pathlib import Path
import numpy as np
@@ -234,3 +235,16 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
self.assertListEqual(decoded_decoder, decoded_processor)
self.assertListEqual(["<s> </s> </s>", "<s> <s> </s>"], decoded_processor)
def test_decoder_download_ignores_files(self):
    """Check that loading the processor fetches only decoder-relevant files.

    Loads the processor from the ``hf-internal-testing/processor_with_lm``
    repository, locates the directory two levels above the kenlm model file
    used by the decoder, and asserts that this directory contains exactly
    ``alphabet.json`` and ``language_model``.
    """
    processor = Wav2Vec2ProcessorWithLM.from_pretrained("hf-internal-testing/processor_with_lm")

    language_model = processor.decoder.model_container[processor.decoder._model_key]
    # The kenlm model reports its file path as bytes; decode it and walk up two
    # levels (presumably file -> language_model dir -> cached decoder dir).
    path_to_cached_dir = Path(language_model._kenlm_model.path.decode("utf-8")).parent.parent.absolute()

    # os.listdir returns entries in arbitrary order, so sort them before the
    # order-sensitive comparison below to keep the test deterministic.
    downloaded_decoder_files = sorted(os.listdir(path_to_cached_dir))

    # test that only decoder relevant files from
    # https://huggingface.co/hf-internal-testing/processor_with_lm/tree/main
    # are downloaded and none of the rest (e.g. README.md, ...)
    self.assertListEqual(downloaded_decoder_files, ["alphabet.json", "language_model"])