[FlaxSpeechEncoderDecoderModel] Ensure Input and Output Word Embeddings Are **Not** Tied (#16444)
* [FlaxSpeechEncoderDecoderModel] Ensure Input and Output Word Embeddings Are **Not** Tied
* rebase
This commit is contained in:
@@ -347,6 +347,8 @@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
|
|||||||
f"and {config.encoder.hidden_size} for `config.encoder.hidden_size`."
|
f"and {config.encoder.hidden_size} for `config.encoder.hidden_size`."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# make sure input & output embeddings are not tied
|
||||||
|
config.tie_word_embeddings = False
|
||||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||||
|
|
||||||
if input_shape is None:
|
if input_shape is None:
|
||||||
@@ -890,6 +892,9 @@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
|
|||||||
dtype = kwargs.pop("dtype", jnp.float32)
|
dtype = kwargs.pop("dtype", jnp.float32)
|
||||||
config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs)
|
config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs)
|
||||||
|
|
||||||
|
# make sure input & output word embeddings are not tied
|
||||||
|
config.tie_word_embeddings = False
|
||||||
|
|
||||||
# init model
|
# init model
|
||||||
model = cls(config, dtype=dtype)
|
model = cls(config, dtype=dtype)
|
||||||
model.params["encoder"] = encoder.params
|
model.params["encoder"] = encoder.params
|
||||||
|
|||||||
@@ -79,6 +79,7 @@ class FlaxEncoderDecoderMixin:
|
|||||||
enc_dec_model = FlaxSpeechEncoderDecoderModel(encoder_decoder_config)
|
enc_dec_model = FlaxSpeechEncoderDecoderModel(encoder_decoder_config)
|
||||||
|
|
||||||
self.assertTrue(enc_dec_model.config.is_encoder_decoder)
|
self.assertTrue(enc_dec_model.config.is_encoder_decoder)
|
||||||
|
self.assertFalse(enc_dec_model.config.tie_word_embeddings)
|
||||||
|
|
||||||
outputs_encoder_decoder = enc_dec_model(
|
outputs_encoder_decoder = enc_dec_model(
|
||||||
inputs=inputs,
|
inputs=inputs,
|
||||||
|
|||||||
@@ -72,6 +72,7 @@ class EncoderDecoderMixin:
|
|||||||
enc_dec_model.eval()
|
enc_dec_model.eval()
|
||||||
|
|
||||||
self.assertTrue(enc_dec_model.config.is_encoder_decoder)
|
self.assertTrue(enc_dec_model.config.is_encoder_decoder)
|
||||||
|
self.assertFalse(enc_dec_model.config.tie_word_embeddings)
|
||||||
|
|
||||||
outputs_encoder_decoder = enc_dec_model(
|
outputs_encoder_decoder = enc_dec_model(
|
||||||
input_values=input_values,
|
input_values=input_values,
|
||||||
|
|||||||
Reference in New Issue
Block a user