Do not remove half seq length in generation tests (#30016)

* remove seq length from generation tests

* style and quality

* [test_all] & PR suggestion

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update tests/generation/test_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* [test all] remove unused variables

---------

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Raushan Turganbay
2024-04-19 21:32:52 +05:00
committed by GitHub
parent b4fd49b6c5
commit b1cd48740e
10 changed files with 180 additions and 261 deletions

View File

@@ -245,34 +245,28 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
sequence_length = input_ids.shape[-1]
input_ids = input_ids[: batch_size * config.num_codebooks, :]
# generate max 3 tokens
max_length = input_ids.shape[-1] + 3
attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long)
return config, input_ids, attention_mask, max_length
return config, input_ids, attention_mask
@staticmethod
def _get_logits_processor_and_warper_kwargs(
input_length,
forced_bos_token_id=None,
forced_eos_token_id=None,
max_length=None,
):
process_kwargs = {
"min_length": input_length + 1 if max_length is None else max_length - 1,
}
process_kwargs = {}
warper_kwargs = {}
return process_kwargs, warper_kwargs
def test_greedy_generate_stereo_outputs(self):
for model_class in self.greedy_sample_model_classes:
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
config, input_ids, attention_mask = self._get_input_ids_and_config()
config.audio_channels = 2
model = model_class(config).to(torch_device).eval()
output_generate = self._greedy_generate(
model=model,
input_ids=input_ids.to(torch_device),
attention_mask=attention_mask.to(torch_device),
max_length=max_length,
output_scores=True,
output_hidden_states=True,
output_attentions=True,
@@ -1327,9 +1321,7 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
input_ids = input_ids[:batch_size, :]
attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long)
# generate max 3 tokens
max_length = 3
return config, input_ids, attention_mask, max_length
return config, input_ids, attention_mask
# override since the `input_ids` cannot be used as the `decoder_input_ids` for musicgen (input / outputs are
# different modalities -> different shapes)
@@ -1338,29 +1330,22 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
model,
input_ids,
attention_mask,
max_length,
output_scores=False,
output_attentions=False,
output_hidden_states=False,
return_dict_in_generate=False,
):
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
input_ids.shape[-1],
max_length=max_length,
)
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
output_generate = model.generate(
input_ids,
do_sample=False,
num_beams=1,
max_length=max_length,
max_new_tokens=self.max_new_tokens,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
output_scores=output_scores,
return_dict_in_generate=return_dict_in_generate,
remove_invalid_values=True,
**logits_process_kwargs,
**model_kwargs,
)
@@ -1373,10 +1358,7 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
model,
input_ids,
attention_mask,
max_length,
num_return_sequences,
logits_warper_kwargs,
process_kwargs,
output_scores=False,
output_attentions=False,
output_hidden_states=False,
@@ -1388,15 +1370,13 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
input_ids,
do_sample=True,
num_beams=1,
max_length=max_length,
max_new_tokens=self.max_new_tokens,
num_return_sequences=num_return_sequences,
output_scores=output_scores,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict_in_generate=return_dict_in_generate,
remove_invalid_values=True,
**logits_warper_kwargs,
**process_kwargs,
**model_kwargs,
)
@@ -1407,25 +1387,21 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
input_length,
forced_bos_token_id=None,
forced_eos_token_id=None,
max_length=None,
):
process_kwargs = {
"min_length": input_length + 1 if max_length is None else max_length - 1,
}
process_kwargs = {}
warper_kwargs = {}
return process_kwargs, warper_kwargs
def test_greedy_generate_dict_outputs(self):
for model_class in self.greedy_sample_model_classes:
# disable cache
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
config, input_ids, attention_mask = self._get_input_ids_and_config()
config.use_cache = False
model = model_class(config).to(torch_device).eval()
output_generate = self._greedy_generate(
model=model,
input_ids=input_ids.to(torch_device),
attention_mask=attention_mask.to(torch_device),
max_length=max_length,
output_scores=True,
output_hidden_states=True,
output_attentions=True,
@@ -1439,7 +1415,7 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
def test_greedy_generate_dict_outputs_use_cache(self):
for model_class in self.greedy_sample_model_classes:
# enable cache
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
config, input_ids, attention_mask = self._get_input_ids_and_config()
config.use_cache = True
config.is_decoder = True
@@ -1448,7 +1424,6 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
model=model,
input_ids=input_ids.to(torch_device),
attention_mask=attention_mask.to(torch_device),
max_length=max_length,
output_scores=True,
output_hidden_states=True,
output_attentions=True,
@@ -1459,46 +1434,30 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
def test_sample_generate(self):
for model_class in self.greedy_sample_model_classes:
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
config, input_ids, attention_mask = self._get_input_ids_and_config()
model = model_class(config).to(torch_device).eval()
process_kwargs, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(
input_ids.shape[-1],
max_length=max_length,
)
# check `generate()` and `sample()` are equal
output_generate = self._sample_generate(
model=model,
input_ids=input_ids.to(torch_device),
attention_mask=attention_mask.to(torch_device),
max_length=max_length,
num_return_sequences=1,
logits_warper_kwargs=logits_warper_kwargs,
process_kwargs=process_kwargs,
)
self.assertIsInstance(output_generate, torch.Tensor)
def test_sample_generate_dict_output(self):
for model_class in self.greedy_sample_model_classes:
# disable cache
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
config, input_ids, attention_mask = self._get_input_ids_and_config()
config.use_cache = False
model = model_class(config).to(torch_device).eval()
process_kwargs, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(
input_ids.shape[-1],
max_length=max_length,
)
output_generate = self._sample_generate(
model=model,
input_ids=input_ids.to(torch_device),
attention_mask=attention_mask.to(torch_device),
max_length=max_length,
num_return_sequences=3,
logits_warper_kwargs=logits_warper_kwargs,
process_kwargs=process_kwargs,
output_scores=True,
output_hidden_states=True,
output_attentions=True,
@@ -1508,7 +1467,7 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
def test_generate_without_input_ids(self):
config, _, _, max_length = self._get_input_ids_and_config()
config, _, _ = self._get_input_ids_and_config()
# if no bos token id => cannot generate from None
if config.bos_token_id is None:
@@ -1518,7 +1477,9 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
model = model_class(config).to(torch_device)
model.eval()
output_ids_generate = model.generate(do_sample=False, max_length=max_length, remove_invalid_values=True)
output_ids_generate = model.generate(
do_sample=False, max_new_tokens=self.max_new_tokens, remove_invalid_values=True
)
self.assertIsNotNone(output_ids_generate)
@require_torch_fp16
@@ -1537,7 +1498,7 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
def test_greedy_generate_stereo_outputs(self):
for model_class in self.greedy_sample_model_classes:
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
config, input_ids, attention_mask = self._get_input_ids_and_config()
config.audio_channels = 2
model = model_class(config).to(torch_device).eval()
@@ -1545,7 +1506,6 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
model=model,
input_ids=input_ids.to(torch_device),
attention_mask=attention_mask.to(torch_device),
max_length=max_length,
output_scores=True,
output_hidden_states=True,
output_attentions=True,