[Docs] Fix Speech Encoder Decoder doc sample (#18346)
* [Docs] Fix Speech Encoder Decoder doc sample * improve pre-processing comment * make style
This commit is contained in:
@@ -85,25 +85,26 @@ As you can see, only 2 inputs are required for the model in order to compute a l
|
|||||||
speech inputs) and `labels` (which are the `input_ids` of the encoded target sequence).
|
speech inputs) and `labels` (which are the `input_ids` of the encoded target sequence).
|
||||||
|
|
||||||
```python
|
```python
|
||||||
>>> from transformers import Wav2Vec2Processor, SpeechEncoderDecoderModel
|
>>> from transformers import AutoTokenizer, AutoFeatureExtractor, SpeechEncoderDecoderModel
|
||||||
>>> from datasets import load_dataset
|
>>> from datasets import load_dataset
|
||||||
|
|
||||||
>>> processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
|
>>> encoder_id = "facebook/wav2vec2-base-960h" # acoustic model encoder
|
||||||
>>> tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
|
>>> decoder_id = "bert-base-uncased" # text decoder
|
||||||
>>> model = SpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
|
|
||||||
... "facebook/wav2vec2-base-960h", "bert-base-uncased"
|
|
||||||
... )
|
|
||||||
|
|
||||||
>>> model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
|
>>> feature_extractor = AutoFeatureExtractor.from_pretrained(encoder_id)
|
||||||
>>> model.config.pad_token_id = processor.tokenizer.pad_token_id
|
>>> tokenizer = AutoTokenizer.from_pretrained(decoder_id)
|
||||||
|
>>> # Combine pre-trained encoder and pre-trained decoder to form a Seq2Seq model
|
||||||
|
>>> model = SpeechEncoderDecoderModel.from_encoder_decoder_pretrained(encoder_id, decoder_id)
|
||||||
|
|
||||||
>>> # load a speech input
|
>>> model.config.decoder_start_token_id = tokenizer.cls_token_id
|
||||||
|
>>> model.config.pad_token_id = tokenizer.pad_token_id
|
||||||
|
|
||||||
|
>>> # load an audio input and pre-process (normalise mean/std to 0/1)
|
||||||
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||||
>>> input_values = processor(ds[0]["audio"]["array"], return_tensors="pt").input_values
|
>>> input_values = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt").input_values
|
||||||
|
|
||||||
>>> # load its corresponding transcription
|
>>> # load its corresponding transcription and tokenize to generate labels
|
||||||
>>> with processor.as_target_processor():
|
>>> labels = tokenizer(ds[0]["text"], return_tensors="pt").input_ids
|
||||||
... labels = processor(ds[0]["text"], return_tensors="pt").input_ids
|
|
||||||
|
|
||||||
>>> # the forward function automatically creates the correct decoder_input_ids
|
>>> # the forward function automatically creates the correct decoder_input_ids
|
||||||
>>> loss = model(input_values, labels=labels).loss
|
>>> loss = model(input_values, labels=labels).loss
|
||||||
|
|||||||
Reference in New Issue
Block a user