From 7ca46335553609e4852dcb018c73cd5215e6e25a Mon Sep 17 00:00:00 2001 From: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Date: Mon, 28 Mar 2022 14:14:10 +0200 Subject: [PATCH] [FlaxSpeechEncoderDecoderModel] Ensure Input and Output Word Embeddings Are **Not** Tied (#16444) * [FlaxSpeechEncoderDecoderModel] Ensure Input and Output Word Embeddings Are **Not** Tied * rebase --- .../modeling_flax_speech_encoder_decoder.py | 5 +++++ .../test_modeling_flax_speech_encoder_decoder.py | 1 + .../test_modeling_speech_encoder_decoder.py | 1 + 3 files changed, 7 insertions(+) diff --git a/src/transformers/models/speech_encoder_decoder/modeling_flax_speech_encoder_decoder.py b/src/transformers/models/speech_encoder_decoder/modeling_flax_speech_encoder_decoder.py index 16067e2c20..aff3953b84 100644 --- a/src/transformers/models/speech_encoder_decoder/modeling_flax_speech_encoder_decoder.py +++ b/src/transformers/models/speech_encoder_decoder/modeling_flax_speech_encoder_decoder.py @@ -347,6 +347,8 @@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel): f"and {config.encoder.hidden_size} for `config.encoder.hidden_size`." ) + # make sure input & output embeddings are not tied + config.tie_word_embeddings = False module = self.module_class(config=config, dtype=dtype, **kwargs) if input_shape is None: @@ -890,6 +892,9 @@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel): dtype = kwargs.pop("dtype", jnp.float32) config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs) + # make sure input & output word embeddings are not tied + config.tie_word_embeddings = False + # init model model = cls(config, dtype=dtype) model.params["encoder"] = encoder.params diff --git a/tests/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py b/tests/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py index 113e867f3a..403255c4ce 100644 --- a/tests/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py +++ b/tests/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py @@ -79,6 +79,7 @@ class FlaxEncoderDecoderMixin: enc_dec_model = FlaxSpeechEncoderDecoderModel(encoder_decoder_config) self.assertTrue(enc_dec_model.config.is_encoder_decoder) + self.assertFalse(enc_dec_model.config.tie_word_embeddings) outputs_encoder_decoder = enc_dec_model( inputs=inputs, diff --git a/tests/speech_encoder_decoder/test_modeling_speech_encoder_decoder.py b/tests/speech_encoder_decoder/test_modeling_speech_encoder_decoder.py index 7fc26a76ba..4bc7c52943 100644 --- a/tests/speech_encoder_decoder/test_modeling_speech_encoder_decoder.py +++ b/tests/speech_encoder_decoder/test_modeling_speech_encoder_decoder.py @@ -72,6 +72,7 @@ class EncoderDecoderMixin: enc_dec_model.eval() self.assertTrue(enc_dec_model.config.is_encoder_decoder) + self.assertFalse(enc_dec_model.config.tie_word_embeddings) outputs_encoder_decoder = enc_dec_model( input_values=input_values,