diff --git a/tests/test_modeling_speech_encoder_decoder.py b/tests/test_modeling_speech_encoder_decoder.py index 5c42f6ec8c..453adbf73b 100644 --- a/tests/test_modeling_speech_encoder_decoder.py +++ b/tests/test_modeling_speech_encoder_decoder.py @@ -388,8 +388,8 @@ class Wav2Vec2BertModelTest(EncoderDecoderMixin, unittest.TestCase): ) def get_encoder_decoder_model(self, config, decoder_config): - encoder_model = Wav2Vec2Model(config) - decoder_model = BertLMHeadModel(decoder_config) + encoder_model = Wav2Vec2Model(config).eval() + decoder_model = BertLMHeadModel(decoder_config).eval() return encoder_model, decoder_model def prepare_config_and_inputs(self): @@ -439,8 +439,8 @@ class Speech2TextBertModelTest(EncoderDecoderMixin, unittest.TestCase): ) def get_encoder_decoder_model(self, config, decoder_config): - encoder_model = Speech2TextEncoder(config) - decoder_model = BertLMHeadModel(decoder_config) + encoder_model = Speech2TextEncoder(config).eval() + decoder_model = BertLMHeadModel(decoder_config).eval() return encoder_model, decoder_model def prepare_config_and_inputs(self):