[FlaxWav2Vec2Model] Fix bug in attention mask (#16725)

* [FlaxWav2Vec2Model] Fix bug in attention mask

* more fixes

* add (Flax)SpeechEncoderDecoderModel PT-FX cross-test
This commit is contained in:
Sanchit Gandhi
2022-04-12 19:48:24 +02:00
committed by GitHub
parent 6adefba3f0
commit a960406722
2 changed files with 24 additions and 17 deletions

View File

@@ -539,6 +539,12 @@ class FlaxEncoderDecoderMixin:
self.check_equivalence_pt_to_flax(config, decoder_config, inputs_dict)
self.check_equivalence_flax_to_pt(config, decoder_config, inputs_dict)
# check `add_adapter` works as expected
config.add_adapter = True
self.assertTrue(config.add_adapter)
self.check_equivalence_pt_to_flax(config, decoder_config, inputs_dict)
self.check_equivalence_flax_to_pt(config, decoder_config, inputs_dict)
@slow
def test_real_model_save_load_from_pretrained(self):
model_2 = self.get_pretrained_model()