diff --git a/tests/test_modeling_flax_encoder_decoder.py b/tests/test_modeling_flax_encoder_decoder.py index 3d0311cf85..b23d8ffcf1 100644 --- a/tests/test_modeling_flax_encoder_decoder.py +++ b/tests/test_modeling_flax_encoder_decoder.py @@ -357,8 +357,11 @@ class FlaxEncoderDecoderModelTest(unittest.TestCase): return config def _check_configuration_tie(self, model): - assert id(model.decoder.config) == id(model.config.decoder) - assert id(model.encoder.config) == id(model.config.encoder) + + module = model.module.bind(model.params) + + assert id(module.decoder.config) == id(model.config.decoder) + assert id(module.encoder.config) == id(model.config.encoder) @slow def test_configuration_tie(self):