Clean-up generation tests after moving methods to private (#29582)

* clean-up tests

* refine comments

* fix musicgen tests

* make style

* remove slow decorator from a test

* more clean-up

* fix other failing tests
This commit is contained in:
Raushan Turganbay
2024-03-19 22:03:31 +05:00
committed by GitHub
parent 56baa03380
commit 425ba56cdf
2 changed files with 156 additions and 1005 deletions

View File

@@ -55,8 +55,6 @@ if is_torch_available():
from transformers.generation import (
GenerateDecoderOnlyOutput,
GenerateEncoderDecoderOutput,
InfNanRemoveLogitsProcessor,
LogitsProcessorList,
)
@@ -247,19 +245,17 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
return config, input_ids, attention_mask, max_length
@staticmethod
def _get_logits_processor_and_kwargs(
def _get_logits_processor_and_warper_kwargs(
input_length,
eos_token_id,
forced_bos_token_id=None,
forced_eos_token_id=None,
max_length=None,
diversity_penalty=None,
):
process_kwargs = {
"min_length": input_length + 1 if max_length is None else max_length - 1,
}
logits_processor = LogitsProcessorList()
return process_kwargs, logits_processor
warper_kwargs = {}
return process_kwargs, warper_kwargs
# override since we don't expect the outputs of `.generate` and `.greedy_search` to be the same, since we perform
# additional post-processing in the former
@@ -269,7 +265,7 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
config.use_cache = False
model = model_class(config).to(torch_device).eval()
output_greedy, output_generate = self._greedy_generate(
output_generate = self._greedy_generate(
model=model,
input_ids=input_ids.to(torch_device),
attention_mask=attention_mask.to(torch_device),
@@ -280,9 +276,7 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
return_dict_in_generate=True,
)
self.assertIsInstance(output_greedy, GenerateDecoderOnlyOutput)
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
self.assertNotIn(config.pad_token_id, output_generate)
# override since we don't expect the outputs of `.generate` and `.greedy_search` to be the same, since we perform
@@ -295,7 +289,7 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
config.use_cache = True
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
output_greedy, output_generate = self._greedy_generate(
output_generate = self._greedy_generate(
model=model,
input_ids=input_ids.to(torch_device),
attention_mask=attention_mask.to(torch_device),
@@ -306,7 +300,6 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
return_dict_in_generate=True,
)
self.assertIsInstance(output_greedy, GenerateDecoderOnlyOutput)
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
# override since we don't expect the outputs of `.generate` and `.sample` to be the same, since we perform
@@ -316,28 +309,21 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
model = model_class(config).to(torch_device).eval()
process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
process_kwargs, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(
input_ids.shape[-1],
model.config.eos_token_id,
forced_bos_token_id=model.config.forced_bos_token_id,
forced_eos_token_id=model.config.forced_eos_token_id,
max_length=max_length,
)
logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=2)
# check `generate()` and `sample()` are equal
output_sample, output_generate = self._sample_generate(
output_generate = self._sample_generate(
model=model,
input_ids=input_ids.to(torch_device),
attention_mask=attention_mask.to(torch_device),
max_length=max_length,
num_return_sequences=3,
logits_processor=logits_processor,
logits_warper=logits_warper,
logits_warper_kwargs=logits_warper_kwargs,
process_kwargs=process_kwargs,
)
self.assertIsInstance(output_sample, torch.Tensor)
self.assertIsInstance(output_generate, torch.Tensor)
# override since we don't expect the outputs of `.generate` and `.sample` to be the same, since we perform
@@ -349,23 +335,17 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
config.use_cache = False
model = model_class(config).to(torch_device).eval()
process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
process_kwargs, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(
input_ids.shape[-1],
model.config.eos_token_id,
forced_bos_token_id=model.config.forced_bos_token_id,
forced_eos_token_id=model.config.forced_eos_token_id,
max_length=max_length,
)
logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1)
output_sample, output_generate = self._sample_generate(
output_generate = self._sample_generate(
model=model,
input_ids=input_ids.to(torch_device),
attention_mask=attention_mask.to(torch_device),
max_length=max_length,
num_return_sequences=1,
logits_processor=logits_processor,
logits_warper=logits_warper,
logits_warper_kwargs=logits_warper_kwargs,
process_kwargs=process_kwargs,
output_scores=True,
@@ -374,7 +354,6 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
return_dict_in_generate=True,
)
self.assertIsInstance(output_sample, GenerateDecoderOnlyOutput)
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
def test_greedy_generate_stereo_outputs(self):
@@ -382,7 +361,7 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
config.audio_channels = 2
model = model_class(config).to(torch_device).eval()
output_greedy, output_generate = self._greedy_generate(
output_generate = self._greedy_generate(
model=model,
input_ids=input_ids.to(torch_device),
attention_mask=attention_mask.to(torch_device),
@@ -393,7 +372,6 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
return_dict_in_generate=True,
)
self.assertIsInstance(output_greedy, GenerateDecoderOnlyOutput)
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
self.assertNotIn(config.pad_token_id, output_generate)
@@ -834,10 +812,8 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long)
# generate max 3 tokens
decoder_input_ids = inputs_dict["decoder_input_ids"]
max_length = decoder_input_ids.shape[-1] + 3
decoder_input_ids = decoder_input_ids[: batch_size * config.decoder.num_codebooks, :]
return config, input_ids, attention_mask, decoder_input_ids, max_length
max_length = 3
return config, input_ids, attention_mask, max_length
# override since the `input_ids` cannot be used as the `decoder_input_ids` for musicgen (input / outputs are
# different modalities -> different shapes)
@@ -846,18 +822,14 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
model,
input_ids,
attention_mask,
decoder_input_ids,
max_length,
output_scores=False,
output_attentions=False,
output_hidden_states=False,
return_dict_in_generate=False,
):
logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
input_ids.shape[-1],
eos_token_id=model.config.eos_token_id,
forced_bos_token_id=model.config.forced_bos_token_id,
forced_eos_token_id=model.config.forced_eos_token_id,
max_length=max_length,
)
@@ -876,28 +848,7 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
**model_kwargs,
)
encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs(
model,
input_ids,
attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
with torch.no_grad():
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
output_greedy = model.greedy_search(
decoder_input_ids,
max_length=max_length,
logits_processor=logits_processor,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
output_scores=output_scores,
return_dict_in_generate=return_dict_in_generate,
encoder_outputs=encoder_outputs,
**model_kwargs,
)
return output_greedy, output_generate
return output_generate
# override since the `input_ids` cannot be used as the `decoder_input_ids` for musicgen (input / outputs are
# different modalities -> different shapes)
@@ -906,11 +857,8 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
model,
input_ids,
attention_mask,
decoder_input_ids,
max_length,
num_return_sequences,
logits_processor,
logits_warper,
logits_warper_kwargs,
process_kwargs,
output_scores=False,
@@ -936,62 +884,31 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
**model_kwargs,
)
torch.manual_seed(0)
encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs(
model,
input_ids,
attention_mask,
num_interleave=num_return_sequences,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
# prevent flaky generation test failures
logits_processor.append(InfNanRemoveLogitsProcessor())
with torch.no_grad():
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
output_sample = model.sample(
decoder_input_ids.repeat_interleave(num_return_sequences, dim=0),
max_length=max_length,
logits_processor=logits_processor,
logits_warper=logits_warper,
output_scores=output_scores,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict_in_generate=return_dict_in_generate,
encoder_outputs=encoder_outputs,
**model_kwargs,
)
return output_sample, output_generate
return output_generate
@staticmethod
def _get_logits_processor_and_kwargs(
def _get_logits_processor_and_warper_kwargs(
input_length,
eos_token_id,
forced_bos_token_id=None,
forced_eos_token_id=None,
max_length=None,
diversity_penalty=None,
):
process_kwargs = {
"min_length": input_length + 1 if max_length is None else max_length - 1,
}
logits_processor = LogitsProcessorList()
return process_kwargs, logits_processor
warper_kwargs = {}
return process_kwargs, warper_kwargs
def test_greedy_generate_dict_outputs(self):
for model_class in self.greedy_sample_model_classes:
# disable cache
config, input_ids, attention_mask, decoder_input_ids, max_length = self._get_input_ids_and_config()
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
config.use_cache = False
model = model_class(config).to(torch_device).eval()
output_greedy, output_generate = self._greedy_generate(
output_generate = self._greedy_generate(
model=model,
input_ids=input_ids.to(torch_device),
attention_mask=attention_mask.to(torch_device),
decoder_input_ids=decoder_input_ids,
max_length=max_length,
output_scores=True,
output_hidden_states=True,
@@ -999,7 +916,6 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
return_dict_in_generate=True,
)
self.assertIsInstance(output_greedy, GenerateEncoderDecoderOutput)
self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
self.assertNotIn(config.pad_token_id, output_generate)
@@ -1007,16 +923,15 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
def test_greedy_generate_dict_outputs_use_cache(self):
for model_class in self.greedy_sample_model_classes:
# enable cache
config, input_ids, attention_mask, decoder_input_ids, max_length = self._get_input_ids_and_config()
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
config.use_cache = True
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
output_greedy, output_generate = self._greedy_generate(
output_generate = self._greedy_generate(
model=model,
input_ids=input_ids.to(torch_device),
attention_mask=attention_mask.to(torch_device),
decoder_input_ids=decoder_input_ids,
max_length=max_length,
output_scores=True,
output_hidden_states=True,
@@ -1024,64 +939,48 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
return_dict_in_generate=True,
)
self.assertIsInstance(output_greedy, GenerateEncoderDecoderOutput)
self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
def test_sample_generate(self):
for model_class in self.greedy_sample_model_classes:
config, input_ids, attention_mask, decoder_input_ids, max_length = self._get_input_ids_and_config()
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
model = model_class(config).to(torch_device).eval()
process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
process_kwargs, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(
input_ids.shape[-1],
model.config.eos_token_id,
forced_bos_token_id=model.config.forced_bos_token_id,
forced_eos_token_id=model.config.forced_eos_token_id,
max_length=max_length,
)
logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=2)
# check `generate()` and `sample()` are equal
output_sample, output_generate = self._sample_generate(
output_generate = self._sample_generate(
model=model,
input_ids=input_ids.to(torch_device),
attention_mask=attention_mask.to(torch_device),
decoder_input_ids=decoder_input_ids,
max_length=max_length,
num_return_sequences=1,
logits_processor=logits_processor,
logits_warper=logits_warper,
logits_warper_kwargs=logits_warper_kwargs,
process_kwargs=process_kwargs,
)
self.assertIsInstance(output_sample, torch.Tensor)
self.assertIsInstance(output_generate, torch.Tensor)
def test_sample_generate_dict_output(self):
for model_class in self.greedy_sample_model_classes:
# disable cache
config, input_ids, attention_mask, decoder_input_ids, max_length = self._get_input_ids_and_config()
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
config.use_cache = False
model = model_class(config).to(torch_device).eval()
process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
process_kwargs, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(
input_ids.shape[-1],
model.config.eos_token_id,
forced_bos_token_id=model.config.forced_bos_token_id,
forced_eos_token_id=model.config.forced_eos_token_id,
max_length=max_length,
)
logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1)
output_sample, output_generate = self._sample_generate(
output_generate = self._sample_generate(
model=model,
input_ids=input_ids.to(torch_device),
attention_mask=attention_mask.to(torch_device),
decoder_input_ids=decoder_input_ids,
max_length=max_length,
num_return_sequences=3,
logits_processor=logits_processor,
logits_warper=logits_warper,
logits_warper_kwargs=logits_warper_kwargs,
process_kwargs=process_kwargs,
output_scores=True,
@@ -1090,11 +989,10 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
return_dict_in_generate=True,
)
self.assertIsInstance(output_sample, GenerateEncoderDecoderOutput)
self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
def test_generate_without_input_ids(self):
config, _, _, _, max_length = self._get_input_ids_and_config()
config, _, _, max_length = self._get_input_ids_and_config()
# if no bos token id => cannot generate from None
if config.bos_token_id is None:
@@ -1123,15 +1021,14 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
def test_greedy_generate_stereo_outputs(self):
for model_class in self.greedy_sample_model_classes:
config, input_ids, attention_mask, decoder_input_ids, max_length = self._get_input_ids_and_config()
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
config.audio_channels = 2
model = model_class(config).to(torch_device).eval()
output_greedy, output_generate = self._greedy_generate(
output_generate = self._greedy_generate(
model=model,
input_ids=input_ids.to(torch_device),
attention_mask=attention_mask.to(torch_device),
decoder_input_ids=decoder_input_ids,
max_length=max_length,
output_scores=True,
output_hidden_states=True,
@@ -1139,7 +1036,6 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
return_dict_in_generate=True,
)
self.assertIsInstance(output_greedy, GenerateEncoderDecoderOutput)
self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
self.assertNotIn(config.pad_token_id, output_generate)