Generate tests: modality-agnostic input preparation (#33685)
This commit is contained in:
@@ -282,20 +282,6 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTest
|
||||
test_pruning = False
|
||||
test_missing_keys = False
|
||||
|
||||
input_name = "input_features"
|
||||
|
||||
def _get_input_ids_and_config(self, batch_size=2):
|
||||
config, input_ids, attention_mask, inputs_dict = GenerationTesterMixin._get_input_ids_and_config(self)
|
||||
|
||||
# `input_ids` is actually `input_features` which is a 3D tensor.
|
||||
# We must overwrite the mask to make it 2D since the original `_get_input_ids_and_config` creates an
|
||||
# attention mask of the same shape as `input_ids`.
|
||||
if len(attention_mask.shape) > 2:
|
||||
sequence_length = input_ids.shape[1]
|
||||
attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long, device=attention_mask.device)
|
||||
|
||||
return config, input_ids, attention_mask, inputs_dict
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = Speech2TextModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=Speech2TextConfig)
|
||||
@@ -632,46 +618,12 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTest
|
||||
def test_generate_without_input_ids(self):
|
||||
pass
|
||||
|
||||
def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
|
||||
batch_size, seq_length = input_ids.shape[:2]
|
||||
subsampled_seq_length = self.model_tester.get_subsampled_output_lengths(seq_length)
|
||||
num_sequences_in_output = batch_size * num_return_sequences
|
||||
gen_len = (
|
||||
output.sequences.shape[-1] - 1 if config.is_encoder_decoder else output.sequences.shape[-1] - seq_length
|
||||
)
|
||||
|
||||
# scores
|
||||
self._check_scores(num_sequences_in_output, output.scores, length=gen_len, config=config)
|
||||
|
||||
# Attentions
|
||||
# encoder
|
||||
self._check_encoder_attention_for_generate(
|
||||
output.encoder_attentions, batch_size, config, subsampled_seq_length
|
||||
)
|
||||
# decoder
|
||||
self._check_attentions_for_generate(
|
||||
num_sequences_in_output,
|
||||
output.decoder_attentions,
|
||||
min_length=1,
|
||||
max_length=output.sequences.shape[-1],
|
||||
config=config,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
|
||||
# Hidden States
|
||||
# encoder
|
||||
self._check_encoder_hidden_states_for_generate(
|
||||
output.encoder_hidden_states, batch_size, config, subsampled_seq_length
|
||||
)
|
||||
|
||||
# decoder
|
||||
self._check_hidden_states_for_generate(
|
||||
num_sequences_in_output,
|
||||
output.decoder_hidden_states,
|
||||
min_length=1,
|
||||
max_length=output.sequences.shape[-1],
|
||||
config=config,
|
||||
use_cache=use_cache,
|
||||
def _check_outputs(self, output, main_input, config, use_cache=False, num_return_sequences=1):
|
||||
# In this model, the index of `batch_size` and `sequence_length`` in `main_input` is different: they are the
|
||||
# first two dimensions of the tensor.
|
||||
main_input = main_input[:, :, 0]
|
||||
super()._check_outputs(
|
||||
output, main_input, config, use_cache=use_cache, num_return_sequences=num_return_sequences
|
||||
)
|
||||
|
||||
def _create_and_check_torchscript(self, config, inputs_dict):
|
||||
|
||||
Reference in New Issue
Block a user