Generation tests: update imagegpt input name, remove unused functions (#33663)
This commit is contained in:
@@ -868,26 +868,6 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
def test_generate_without_input_ids(self):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _get_encoder_outputs(
|
||||
model, input_ids, attention_mask, output_attentions=None, output_hidden_states=None, num_interleave=1
|
||||
):
|
||||
encoder = model.get_encoder()
|
||||
encoder_outputs = encoder(
|
||||
input_ids,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
)
|
||||
encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.repeat_interleave(
|
||||
num_interleave, dim=0
|
||||
)
|
||||
generation_config = copy.deepcopy(model.generation_config)
|
||||
model._prepare_special_tokens(generation_config)
|
||||
input_ids = input_ids[:, :, 0]
|
||||
input_ids = torch.zeros_like(input_ids[:, :1], dtype=torch.long) + generation_config.decoder_start_token_id
|
||||
attention_mask = None
|
||||
return encoder_outputs, input_ids, attention_mask
|
||||
|
||||
def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
|
||||
batch_size, mel, seq_length = input_ids.shape
|
||||
subsampled_seq_length = self.model_tester.get_subsampled_output_lengths(seq_length)
|
||||
@@ -3894,13 +3874,6 @@ class WhisperStandaloneDecoderModelTester:
|
||||
|
||||
return config, inputs_dict
|
||||
|
||||
def prepare_config_and_inputs_for_decoder(self):
|
||||
config, input_features = self.prepare_config_and_inputs()
|
||||
input_ids = input_features["input_ids"]
|
||||
encoder_hidden_states = floats_tensor([self.batch_size, self.decoder_seq_length, self.hidden_size])
|
||||
|
||||
return (config, input_ids, encoder_hidden_states)
|
||||
|
||||
def create_and_check_decoder_model_past(self, config, input_ids):
|
||||
config.use_cache = True
|
||||
model = WhisperDecoder(config=config).to(torch_device).eval()
|
||||
|
||||
Reference in New Issue
Block a user