Flax Speech-Encoder-Decoder Model (#15613)

* rebase

* Delete shift tokens func

* downsample decoder input seq len for init

* correct attention mask

* add tests

* pt flax cross test

* make fixup

* init file for import

* change pt-flax cross test threshold

* pt-flax test logits only

* move tests

* make repo-consistency

* consistent indentation

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Sanchit Gandhi
2022-02-28 12:22:36 +01:00
committed by GitHub
parent 935a76d90d
commit e3342edc4e
10 changed files with 1509 additions and 2 deletions

View File

@@ -230,7 +230,7 @@ Flax), PyTorch, and/or TensorFlow.
| SegFormer | ❌ | ❌ | ✅ | ❌ | ❌ |
| SEW | ❌ | ❌ | ✅ | ❌ | ❌ |
| SEW-D | ❌ | ❌ | ✅ | ❌ | ❌ |
| Speech Encoder decoder | ❌ | ❌ | ✅ | ❌ | |
| Speech Encoder decoder | ❌ | ❌ | ✅ | ❌ | |
| Speech2Text | ✅ | ❌ | ✅ | ✅ | ❌ |
| Speech2Text2 | ✅ | ❌ | ❌ | ❌ | ❌ |
| Splinter | ✅ | ✅ | ✅ | ❌ | ❌ |

View File

@@ -33,3 +33,9 @@ An example of how to use a [`SpeechEncoderDecoderModel`] for inference can be se
[[autodoc]] SpeechEncoderDecoderModel
- forward
- from_encoder_decoder_pretrained
## FlaxSpeechEncoderDecoderModel
[[autodoc]] FlaxSpeechEncoderDecoderModel
- __call__
- from_encoder_decoder_pretrained