make wav2vec2 test deterministic (#10714)
This commit is contained in:
committed by
GitHub
parent
6bef764506
commit
d9e693e1d0
@@ -515,6 +515,8 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
import soundfile as sf
|
import soundfile as sf
|
||||||
|
|
||||||
|
ids = [f"1272-141231-000{i}" for i in range(num_samples)]
|
||||||
|
|
||||||
# map files to raw
|
# map files to raw
|
||||||
def map_to_array(batch):
|
def map_to_array(batch):
|
||||||
speech, _ = sf.read(batch["file"])
|
speech, _ = sf.read(batch["file"])
|
||||||
@@ -522,7 +524,8 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
|
|||||||
return batch
|
return batch
|
||||||
|
|
||||||
ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
|
ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
|
||||||
ds = ds.select(range(num_samples)).map(map_to_array)
|
|
||||||
|
ds = ds.filter(lambda x: x["id"] in ids).sort("id").map(map_to_array)
|
||||||
|
|
||||||
return ds["speech"][:num_samples]
|
return ds["speech"][:num_samples]
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user