Flax Speech-Encoder-Decoder Model (#15613)

* rebase

* Delete shift tokens func

* downsample decoder input seq len for init

* correct attention mask

* add tests

* pt flax cross test

* make fixup

* init file for import

* change pt-flax cross test threshold

* pt-flax test logits only

* move tests

* make repo-consistency

* consistent indentation

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Sanchit Gandhi
2022-02-28 12:22:36 +01:00
committed by GitHub
parent 935a76d90d
commit e3342edc4e
10 changed files with 1509 additions and 2 deletions

View File

@@ -33,3 +33,9 @@ An example of how to use a [`SpeechEncoderDecoderModel`] for inference can be se
[[autodoc]] SpeechEncoderDecoderModel
- forward
- from_encoder_decoder_pretrained
## FlaxSpeechEncoderDecoderModel
[[autodoc]] FlaxSpeechEncoderDecoderModel
- __call__
- from_encoder_decoder_pretrained