[Doctests] Fix ignore bug and add more doc tests (#15911)
* finish speech doc tests * finish * boom * Update src/transformers/models/speech_to_text/modeling_speech_to_text.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
b693cbf99c
commit
6cbfa7bf4c
@@ -185,6 +185,17 @@ class Speech2TextModelTester:
|
||||
|
||||
return input_lengths
|
||||
|
||||
def create_and_check_model_forward(self, config, inputs_dict):
|
||||
model = Speech2TextModel(config=config).to(torch_device).eval()
|
||||
|
||||
input_features = inputs_dict["input_features"]
|
||||
decoder_input_ids = inputs_dict["decoder_input_ids"]
|
||||
|
||||
# first forward pass
|
||||
last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state
|
||||
|
||||
self.parent.assertTrue(last_hidden_state.shape, (13, 7, 16))
|
||||
|
||||
def create_and_check_decoder_model_past_large_inputs(self, config, inputs_dict):
|
||||
model = Speech2TextModel(config=config).get_decoder().to(torch_device).eval()
|
||||
input_ids = inputs_dict["decoder_input_ids"]
|
||||
@@ -284,6 +295,10 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Tes
|
||||
model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
|
||||
self.assertEqual(info["missing_keys"], [])
|
||||
|
||||
def test_model_forward(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model_forward(*config_and_inputs)
|
||||
|
||||
def test_decoder_model_past_with_large_inputs(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
|
||||
|
||||
Reference in New Issue
Block a user