Flax Speech-Encoder-Decoder Model (#15613)

* rebase

* Delete shift tokens func

* downsample decoder input seq len for init

* correct attention mask

* add tests

* pt flax cross test

* make fixup

* init file for import

* change pt-flax cross test threshold

* pt-flax test logits only

* move tests

* make repo-consistency

* consistent indentation

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Sanchit Gandhi
2022-02-28 12:22:36 +01:00
committed by GitHub
parent 935a76d90d
commit e3342edc4e
10 changed files with 1509 additions and 2 deletions

View File

@@ -215,6 +215,7 @@ def get_model_modules():
"modeling_flax_encoder_decoder",
"modeling_flax_utils",
"modeling_speech_encoder_decoder",
"modeling_flax_speech_encoder_decoder",
"modeling_flax_vision_encoder_decoder",
"modeling_transfo_xl_utilities",
"modeling_tf_auto",
@@ -290,6 +291,7 @@ def get_model_test_files():
"test_modeling_common",
"test_modeling_encoder_decoder",
"test_modeling_flax_encoder_decoder",
"test_modeling_flax_speech_encoder_decoder",
"test_modeling_marian",
"test_modeling_tf_common",
"test_modeling_tf_encoder_decoder",