Generation tests: update imagegpt input name, remove unused functions (#33663)

This commit is contained in:
Joao Gante
2024-09-24 16:40:48 +01:00
committed by GitHub
parent 6f7d750b73
commit a7734238ff
18 changed files with 23 additions and 656 deletions

View File

@@ -632,27 +632,6 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTest
def test_generate_without_input_ids(self):
pass
@staticmethod
def _get_encoder_outputs(
model, input_ids, attention_mask, output_attentions=None, output_hidden_states=None, num_interleave=1
):
encoder = model.get_encoder()
encoder_outputs = encoder(
input_ids,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.repeat_interleave(
num_interleave, dim=0
)
input_ids = input_ids[:, :, 0]
generation_config = copy.deepcopy(model.generation_config)
model._prepare_special_tokens(generation_config)
input_ids = torch.zeros_like(input_ids[:, :1]) + generation_config.decoder_start_token_id
attention_mask = None
return encoder_outputs, input_ids, attention_mask
def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
batch_size, seq_length = input_ids.shape[:2]
subsampled_seq_length = self.model_tester.get_subsampled_output_lengths(seq_length)

View File

@@ -416,24 +416,6 @@ class TFSpeech2TextModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.T
def test_generate_without_input_ids(self):
pass
@staticmethod
def _get_encoder_outputs(
model, input_ids, attention_mask, output_attentions=None, output_hidden_states=None, num_interleave=1
):
encoder = model.get_encoder()
encoder_outputs = encoder(
input_ids,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
encoder_outputs["last_hidden_state"] = tf.repeat(encoder_outputs.last_hidden_state, num_interleave, axis=0)
input_ids = input_ids[:, :, 0]
input_ids = tf.zeros_like(input_ids[:, :1], dtype=tf.int64) + model._get_decoder_start_token_id()
attention_mask = None
return encoder_outputs, input_ids, attention_mask
def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
batch_size, seq_length = input_ids.shape[:2]
subsampled_seq_length = self.model_tester.get_subsampled_output_lengths(seq_length)