This commit is contained in:
Patrick von Platen
2021-09-02 18:11:26 +02:00
committed by GitHub
parent 596bb85f2f
commit efa4f5f0ea

View File

@@ -388,8 +388,8 @@ class Wav2Vec2BertModelTest(EncoderDecoderMixin, unittest.TestCase):
)
def get_encoder_decoder_model(self, config, decoder_config):
encoder_model = Wav2Vec2Model(config)
decoder_model = BertLMHeadModel(decoder_config)
encoder_model = Wav2Vec2Model(config).eval()
decoder_model = BertLMHeadModel(decoder_config).eval()
return encoder_model, decoder_model
def prepare_config_and_inputs(self):
@@ -439,8 +439,8 @@ class Speech2TextBertModelTest(EncoderDecoderMixin, unittest.TestCase):
)
def get_encoder_decoder_model(self, config, decoder_config):
encoder_model = Speech2TextEncoder(config)
decoder_model = BertLMHeadModel(decoder_config)
encoder_model = Speech2TextEncoder(config).eval()
decoder_model = BertLMHeadModel(decoder_config).eval()
return encoder_model, decoder_model
def prepare_config_and_inputs(self):