fix (#13395)
This commit is contained in:
committed by
GitHub
parent
596bb85f2f
commit
efa4f5f0ea
@@ -388,8 +388,8 @@ class Wav2Vec2BertModelTest(EncoderDecoderMixin, unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def get_encoder_decoder_model(self, config, decoder_config):
|
def get_encoder_decoder_model(self, config, decoder_config):
|
||||||
encoder_model = Wav2Vec2Model(config)
|
encoder_model = Wav2Vec2Model(config).eval()
|
||||||
decoder_model = BertLMHeadModel(decoder_config)
|
decoder_model = BertLMHeadModel(decoder_config).eval()
|
||||||
return encoder_model, decoder_model
|
return encoder_model, decoder_model
|
||||||
|
|
||||||
def prepare_config_and_inputs(self):
|
def prepare_config_and_inputs(self):
|
||||||
@@ -439,8 +439,8 @@ class Speech2TextBertModelTest(EncoderDecoderMixin, unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def get_encoder_decoder_model(self, config, decoder_config):
|
def get_encoder_decoder_model(self, config, decoder_config):
|
||||||
encoder_model = Speech2TextEncoder(config)
|
encoder_model = Speech2TextEncoder(config).eval()
|
||||||
decoder_model = BertLMHeadModel(decoder_config)
|
decoder_model = BertLMHeadModel(decoder_config).eval()
|
||||||
return encoder_model, decoder_model
|
return encoder_model, decoder_model
|
||||||
|
|
||||||
def prepare_config_and_inputs(self):
|
def prepare_config_and_inputs(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user