[Speech2Text Doc] Fix docs (#16611)

* [Speech2Text Doc] Fix docs

* apply ydshiehs suggestions
This commit is contained in:
Patrick von Platen
2022-04-06 14:19:00 +02:00
committed by GitHub
parent fb3d0df454
commit c65633156b
2 changed files with 11 additions and 23 deletions

View File

@@ -47,25 +47,19 @@ be installed as follows: `apt install libsndfile1-dev`
>>> import torch
>>> from transformers import Speech2TextProcessor, Speech2TextForConditionalGeneration
>>> from datasets import load_dataset
>>> import soundfile as sf
>>> model = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-small-librispeech-asr")
>>> processor = Speech2TextProcessor.from_pretrained("facebook/s2t-small-librispeech-asr")
>>> def map_to_array(batch):
... speech, _ = sf.read(batch["file"])
... batch["speech"] = speech
... return batch
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> ds = ds.map(map_to_array)
>>> inputs = processor(ds["speech"][0], sampling_rate=16_000, return_tensors="pt")
>>> generated_ids = model.generate(input_ids=inputs["input_features"], attention_mask=inputs["attention_mask"])
>>> inputs = processor(ds[0]["audio"]["array"], sampling_rate=ds[0]["audio"]["sampling_rate"], return_tensors="pt")
>>> generated_ids = model.generate(inputs["input_features"], attention_mask=inputs["attention_mask"])
>>> transcription = processor.batch_decode(generated_ids)
>>> transcription
['mister quilter is the apostle of the middle classes and we are glad to welcome his gospel']
```
- Multilingual speech translation
@@ -80,29 +74,22 @@ be installed as follows: `apt install libsndfile1-dev`
>>> import torch
>>> from transformers import Speech2TextProcessor, Speech2TextForConditionalGeneration
>>> from datasets import load_dataset
>>> import soundfile as sf
>>> model = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-medium-mustc-multilingual-st")
>>> processor = Speech2TextProcessor.from_pretrained("facebook/s2t-medium-mustc-multilingual-st")
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
>>> def map_to_array(batch):
... speech, _ = sf.read(batch["file"])
... batch["speech"] = speech
... return batch
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> ds = ds.map(map_to_array)
>>> inputs = processor(ds["speech"][0], sampling_rate=16_000, return_tensors="pt")
>>> inputs = processor(ds[0]["audio"]["array"], sampling_rate=ds[0]["audio"]["sampling_rate"], return_tensors="pt")
>>> generated_ids = model.generate(
... input_ids=inputs["input_features"],
... inputs["input_features"],
... attention_mask=inputs["attention_mask"],
... forced_bos_token_id=processor.tokenizer.lang_code_to_id["fr"],
... )
>>> translation = processor.batch_decode(generated_ids)
>>> translation
["<lang:fr> (Vidéo) Si M. Kilder est l'apossible des classes moyennes, et nous sommes heureux d'être accueillis dans son évangile."]
```
See the [model hub](https://huggingface.co/models?filter=speech_to_text) to look for Speech2Text checkpoints.