Generate tests: modality-agnostic input preparation (#33685)

This commit is contained in:
Joao Gante
2024-10-03 14:01:24 +01:00
committed by GitHub
parent f2bf4fcf3d
commit d29738f5b4
34 changed files with 241 additions and 906 deletions

View File

@@ -60,10 +60,6 @@ if is_torch_available():
MusicgenModel,
set_seed,
)
from transformers.generation import (
GenerateDecoderOnlyOutput,
GenerateEncoderDecoderOutput,
)
def _config_zero_init(config):
@@ -124,6 +120,7 @@ class MusicgenDecoderTester:
pad_token_id=99,
bos_token_id=99,
num_codebooks=4,
audio_channels=1,
):
self.parent = parent
self.batch_size = batch_size
@@ -141,6 +138,7 @@ class MusicgenDecoderTester:
self.pad_token_id = pad_token_id
self.bos_token_id = bos_token_id
self.num_codebooks = num_codebooks
self.audio_channels = audio_channels
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size * self.num_codebooks, self.seq_length], self.vocab_size)
@@ -166,6 +164,7 @@ class MusicgenDecoderTester:
bos_token_id=self.bos_token_id,
num_codebooks=self.num_codebooks,
tie_word_embeddings=False,
audio_channels=self.audio_channels,
)
return config
@@ -282,47 +281,15 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
def test_tied_weights_keys(self):
pass
def _get_input_ids_and_config(self, batch_size=2):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
input_ids = inputs_dict["input_ids"]
_ = inputs_dict.pop("attention_mask", None)
inputs_dict = {
k: v[:batch_size, ...]
for k, v in inputs_dict.items()
if "head_mask" not in k and isinstance(v, torch.Tensor)
}
# take max batch_size
sequence_length = input_ids.shape[-1]
input_ids = input_ids[: batch_size * config.num_codebooks, :]
attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long)
return config, input_ids, attention_mask, inputs_dict
def _get_logits_processor_kwargs(self, do_sample=False, config=None):
logits_processor_kwargs = {}
return logits_processor_kwargs
def test_greedy_generate_stereo_outputs(self):
for model_class in self.greedy_sample_model_classes:
config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
config.audio_channels = 2
model = model_class(config).to(torch_device).eval()
output_generate = self._greedy_generate(
model=model,
input_ids=input_ids.to(torch_device),
attention_mask=attention_mask.to(torch_device),
output_scores=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
inputs_dict={},
)
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
self.assertNotIn(config.pad_token_id, output_generate)
original_audio_channels = self.model_tester.audio_channels
self.model_tester.audio_channels = 2
super().test_greedy_generate_dict_outputs()
self.model_tester.audio_channels = original_audio_channels
@require_flash_attn
@require_torch_gpu
@@ -998,6 +965,7 @@ class MusicgenTester:
num_codebooks=4,
num_filters=4,
codebook_size=128,
audio_channels=1,
):
self.parent = parent
self.batch_size = batch_size
@@ -1017,6 +985,7 @@ class MusicgenTester:
self.num_codebooks = num_codebooks
self.num_filters = num_filters
self.codebook_size = codebook_size
self.audio_channels = audio_channels
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
@@ -1052,6 +1021,7 @@ class MusicgenTester:
bos_token_id=self.bos_token_id,
num_codebooks=self.num_codebooks,
tie_word_embeddings=False,
audio_channels=self.audio_channels,
)
config = MusicgenConfig.from_sub_models_config(text_encoder_config, audio_encoder_config, decoder_config)
return config
@@ -1415,170 +1385,10 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
lm_heads = model.get_output_embeddings()
self.assertTrue(lm_heads is None or isinstance(lm_heads[0], torch.nn.Linear))
def _get_input_ids_and_config(self, batch_size=2):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
input_ids = inputs_dict["input_ids"]
# take max batch_size
sequence_length = input_ids.shape[-1]
input_ids = input_ids[:batch_size, :]
attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long)
return config, input_ids, attention_mask
# override since the `input_ids` cannot be used as the `decoder_input_ids` for musicgen (input / outputs are
# different modalities -> different shapes)
def _greedy_generate(
self,
model,
input_ids,
attention_mask,
output_scores=False,
output_attentions=False,
output_hidden_states=False,
return_dict_in_generate=False,
):
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
output_generate = model.generate(
input_ids,
do_sample=False,
num_beams=1,
max_new_tokens=self.max_new_tokens,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
output_scores=output_scores,
return_dict_in_generate=return_dict_in_generate,
remove_invalid_values=True,
**model_kwargs,
)
return output_generate
# override since the `input_ids` cannot be used as the `decoder_input_ids` for musicgen (input / outputs are
# different modalities -> different shapes)
def _sample_generate(
self,
model,
input_ids,
attention_mask,
num_return_sequences,
output_scores=False,
output_attentions=False,
output_hidden_states=False,
return_dict_in_generate=False,
):
torch.manual_seed(0)
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
output_generate = model.generate(
input_ids,
do_sample=True,
num_beams=1,
max_new_tokens=self.max_new_tokens,
num_return_sequences=num_return_sequences,
output_scores=output_scores,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict_in_generate=return_dict_in_generate,
remove_invalid_values=True,
**model_kwargs,
)
return output_generate
def _get_logits_processor_kwargs(self, do_sample=False, config=None):
logits_processor_kwargs = {}
return logits_processor_kwargs
def test_greedy_generate_dict_outputs(self):
for model_class in self.greedy_sample_model_classes:
# disable cache
config, input_ids, attention_mask = self._get_input_ids_and_config()
config.use_cache = False
model = model_class(config).to(torch_device).eval()
output_generate = self._greedy_generate(
model=model,
input_ids=input_ids.to(torch_device),
attention_mask=attention_mask.to(torch_device),
output_scores=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
)
self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
self.assertNotIn(config.pad_token_id, output_generate)
def test_greedy_generate_dict_outputs_use_cache(self):
for model_class in self.greedy_sample_model_classes:
# enable cache
config, input_ids, attention_mask = self._get_input_ids_and_config()
config.use_cache = True
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
output_generate = self._greedy_generate(
model=model,
input_ids=input_ids.to(torch_device),
attention_mask=attention_mask.to(torch_device),
output_scores=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
)
self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
def test_sample_generate(self):
for model_class in self.greedy_sample_model_classes:
config, input_ids, attention_mask = self._get_input_ids_and_config()
model = model_class(config).to(torch_device).eval()
# check `generate()` and `sample()` are equal
output_generate = self._sample_generate(
model=model,
input_ids=input_ids.to(torch_device),
attention_mask=attention_mask.to(torch_device),
num_return_sequences=1,
)
self.assertIsInstance(output_generate, torch.Tensor)
def test_sample_generate_dict_output(self):
for model_class in self.greedy_sample_model_classes:
# disable cache
config, input_ids, attention_mask = self._get_input_ids_and_config()
config.use_cache = False
model = model_class(config).to(torch_device).eval()
output_generate = self._sample_generate(
model=model,
input_ids=input_ids.to(torch_device),
attention_mask=attention_mask.to(torch_device),
num_return_sequences=3,
output_scores=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
)
self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
def test_generate_without_input_ids(self):
config, _, _ = self._get_input_ids_and_config()
# if no bos token id => cannot generate from None
if config.bos_token_id is None:
self.skipTest(reason="bos_token_id is None")
for model_class in self.greedy_sample_model_classes:
model = model_class(config).to(torch_device)
model.eval()
output_ids_generate = model.generate(
do_sample=False, max_new_tokens=self.max_new_tokens, remove_invalid_values=True
)
self.assertIsNotNone(output_ids_generate)
@require_torch_fp16
@require_torch_accelerator # not all operations are supported in fp16 on CPU
def test_generate_fp16(self):
@@ -1595,24 +1405,10 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
)
def test_greedy_generate_stereo_outputs(self):
for model_class in self.greedy_sample_model_classes:
config, input_ids, attention_mask = self._get_input_ids_and_config()
config.audio_channels = 2
model = model_class(config).to(torch_device).eval()
output_generate = self._greedy_generate(
model=model,
input_ids=input_ids.to(torch_device),
attention_mask=attention_mask.to(torch_device),
output_scores=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
)
self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
self.assertNotIn(config.pad_token_id, output_generate)
original_audio_channels = self.model_tester.audio_channels
self.model_tester.audio_channels = 2
super().test_greedy_generate_dict_outputs()
self.model_tester.audio_channels = original_audio_channels
@unittest.skip(
reason="MusicgenModel is actually not the base of MusicgenForCausalLM as the latter is a composit model"