[Doctests] Fix ignore bug and add more doc tests (#15911)

* finish speech doc tests

* finish

* boom

* Update src/transformers/models/speech_to_text/modeling_speech_to_text.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Patrick von Platen
2022-03-03 16:01:56 +01:00
committed by GitHub
parent b693cbf99c
commit 6cbfa7bf4c
10 changed files with 115 additions and 74 deletions

View File

@@ -185,6 +185,17 @@ class Speech2TextModelTester:
return input_lengths
def create_and_check_model_forward(self, config, inputs_dict):
model = Speech2TextModel(config=config).to(torch_device).eval()
input_features = inputs_dict["input_features"]
decoder_input_ids = inputs_dict["decoder_input_ids"]
# first forward pass
last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state
self.parent.assertTrue(last_hidden_state.shape, (13, 7, 16))
def create_and_check_decoder_model_past_large_inputs(self, config, inputs_dict):
model = Speech2TextModel(config=config).get_decoder().to(torch_device).eval()
input_ids = inputs_dict["decoder_input_ids"]
@@ -284,6 +295,10 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Tes
model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
self.assertEqual(info["missing_keys"], [])
def test_model_forward(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model_forward(*config_and_inputs)
def test_decoder_model_past_with_large_inputs(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)