Generate tests: modality-agnostic input preparation (#33685)
This commit is contained in:
@@ -60,10 +60,6 @@ if is_torch_available():
|
||||
MusicgenModel,
|
||||
set_seed,
|
||||
)
|
||||
from transformers.generation import (
|
||||
GenerateDecoderOnlyOutput,
|
||||
GenerateEncoderDecoderOutput,
|
||||
)
|
||||
|
||||
|
||||
def _config_zero_init(config):
|
||||
@@ -124,6 +120,7 @@ class MusicgenDecoderTester:
|
||||
pad_token_id=99,
|
||||
bos_token_id=99,
|
||||
num_codebooks=4,
|
||||
audio_channels=1,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
@@ -141,6 +138,7 @@ class MusicgenDecoderTester:
|
||||
self.pad_token_id = pad_token_id
|
||||
self.bos_token_id = bos_token_id
|
||||
self.num_codebooks = num_codebooks
|
||||
self.audio_channels = audio_channels
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_ids = ids_tensor([self.batch_size * self.num_codebooks, self.seq_length], self.vocab_size)
|
||||
@@ -166,6 +164,7 @@ class MusicgenDecoderTester:
|
||||
bos_token_id=self.bos_token_id,
|
||||
num_codebooks=self.num_codebooks,
|
||||
tie_word_embeddings=False,
|
||||
audio_channels=self.audio_channels,
|
||||
)
|
||||
return config
|
||||
|
||||
@@ -282,47 +281,15 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
|
||||
def test_tied_weights_keys(self):
|
||||
pass
|
||||
|
||||
def _get_input_ids_and_config(self, batch_size=2):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
input_ids = inputs_dict["input_ids"]
|
||||
|
||||
_ = inputs_dict.pop("attention_mask", None)
|
||||
inputs_dict = {
|
||||
k: v[:batch_size, ...]
|
||||
for k, v in inputs_dict.items()
|
||||
if "head_mask" not in k and isinstance(v, torch.Tensor)
|
||||
}
|
||||
|
||||
# take max batch_size
|
||||
sequence_length = input_ids.shape[-1]
|
||||
input_ids = input_ids[: batch_size * config.num_codebooks, :]
|
||||
|
||||
attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long)
|
||||
return config, input_ids, attention_mask, inputs_dict
|
||||
|
||||
def _get_logits_processor_kwargs(self, do_sample=False, config=None):
|
||||
logits_processor_kwargs = {}
|
||||
return logits_processor_kwargs
|
||||
|
||||
def test_greedy_generate_stereo_outputs(self):
|
||||
for model_class in self.greedy_sample_model_classes:
|
||||
config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
|
||||
config.audio_channels = 2
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
output_generate = self._greedy_generate(
|
||||
model=model,
|
||||
input_ids=input_ids.to(torch_device),
|
||||
attention_mask=attention_mask.to(torch_device),
|
||||
output_scores=True,
|
||||
output_hidden_states=True,
|
||||
output_attentions=True,
|
||||
return_dict_in_generate=True,
|
||||
inputs_dict={},
|
||||
)
|
||||
|
||||
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
|
||||
|
||||
self.assertNotIn(config.pad_token_id, output_generate)
|
||||
original_audio_channels = self.model_tester.audio_channels
|
||||
self.model_tester.audio_channels = 2
|
||||
super().test_greedy_generate_dict_outputs()
|
||||
self.model_tester.audio_channels = original_audio_channels
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@@ -998,6 +965,7 @@ class MusicgenTester:
|
||||
num_codebooks=4,
|
||||
num_filters=4,
|
||||
codebook_size=128,
|
||||
audio_channels=1,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
@@ -1017,6 +985,7 @@ class MusicgenTester:
|
||||
self.num_codebooks = num_codebooks
|
||||
self.num_filters = num_filters
|
||||
self.codebook_size = codebook_size
|
||||
self.audio_channels = audio_channels
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
@@ -1052,6 +1021,7 @@ class MusicgenTester:
|
||||
bos_token_id=self.bos_token_id,
|
||||
num_codebooks=self.num_codebooks,
|
||||
tie_word_embeddings=False,
|
||||
audio_channels=self.audio_channels,
|
||||
)
|
||||
config = MusicgenConfig.from_sub_models_config(text_encoder_config, audio_encoder_config, decoder_config)
|
||||
return config
|
||||
@@ -1415,170 +1385,10 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
lm_heads = model.get_output_embeddings()
|
||||
self.assertTrue(lm_heads is None or isinstance(lm_heads[0], torch.nn.Linear))
|
||||
|
||||
def _get_input_ids_and_config(self, batch_size=2):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
input_ids = inputs_dict["input_ids"]
|
||||
|
||||
# take max batch_size
|
||||
sequence_length = input_ids.shape[-1]
|
||||
input_ids = input_ids[:batch_size, :]
|
||||
attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long)
|
||||
|
||||
return config, input_ids, attention_mask
|
||||
|
||||
# override since the `input_ids` cannot be used as the `decoder_input_ids` for musicgen (input / outputs are
|
||||
# different modalities -> different shapes)
|
||||
def _greedy_generate(
|
||||
self,
|
||||
model,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
output_scores=False,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
return_dict_in_generate=False,
|
||||
):
|
||||
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
||||
output_generate = model.generate(
|
||||
input_ids,
|
||||
do_sample=False,
|
||||
num_beams=1,
|
||||
max_new_tokens=self.max_new_tokens,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
output_scores=output_scores,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
remove_invalid_values=True,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
return output_generate
|
||||
|
||||
# override since the `input_ids` cannot be used as the `decoder_input_ids` for musicgen (input / outputs are
|
||||
# different modalities -> different shapes)
|
||||
def _sample_generate(
|
||||
self,
|
||||
model,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
num_return_sequences,
|
||||
output_scores=False,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
return_dict_in_generate=False,
|
||||
):
|
||||
torch.manual_seed(0)
|
||||
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
||||
output_generate = model.generate(
|
||||
input_ids,
|
||||
do_sample=True,
|
||||
num_beams=1,
|
||||
max_new_tokens=self.max_new_tokens,
|
||||
num_return_sequences=num_return_sequences,
|
||||
output_scores=output_scores,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
remove_invalid_values=True,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
return output_generate
|
||||
|
||||
def _get_logits_processor_kwargs(self, do_sample=False, config=None):
|
||||
logits_processor_kwargs = {}
|
||||
return logits_processor_kwargs
|
||||
|
||||
def test_greedy_generate_dict_outputs(self):
|
||||
for model_class in self.greedy_sample_model_classes:
|
||||
# disable cache
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
config.use_cache = False
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
output_generate = self._greedy_generate(
|
||||
model=model,
|
||||
input_ids=input_ids.to(torch_device),
|
||||
attention_mask=attention_mask.to(torch_device),
|
||||
output_scores=True,
|
||||
output_hidden_states=True,
|
||||
output_attentions=True,
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
|
||||
self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
|
||||
|
||||
self.assertNotIn(config.pad_token_id, output_generate)
|
||||
|
||||
def test_greedy_generate_dict_outputs_use_cache(self):
|
||||
for model_class in self.greedy_sample_model_classes:
|
||||
# enable cache
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
|
||||
config.use_cache = True
|
||||
config.is_decoder = True
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
output_generate = self._greedy_generate(
|
||||
model=model,
|
||||
input_ids=input_ids.to(torch_device),
|
||||
attention_mask=attention_mask.to(torch_device),
|
||||
output_scores=True,
|
||||
output_hidden_states=True,
|
||||
output_attentions=True,
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
|
||||
self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
|
||||
|
||||
def test_sample_generate(self):
|
||||
for model_class in self.greedy_sample_model_classes:
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
|
||||
# check `generate()` and `sample()` are equal
|
||||
output_generate = self._sample_generate(
|
||||
model=model,
|
||||
input_ids=input_ids.to(torch_device),
|
||||
attention_mask=attention_mask.to(torch_device),
|
||||
num_return_sequences=1,
|
||||
)
|
||||
self.assertIsInstance(output_generate, torch.Tensor)
|
||||
|
||||
def test_sample_generate_dict_output(self):
|
||||
for model_class in self.greedy_sample_model_classes:
|
||||
# disable cache
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
config.use_cache = False
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
|
||||
output_generate = self._sample_generate(
|
||||
model=model,
|
||||
input_ids=input_ids.to(torch_device),
|
||||
attention_mask=attention_mask.to(torch_device),
|
||||
num_return_sequences=3,
|
||||
output_scores=True,
|
||||
output_hidden_states=True,
|
||||
output_attentions=True,
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
|
||||
self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
|
||||
|
||||
def test_generate_without_input_ids(self):
|
||||
config, _, _ = self._get_input_ids_and_config()
|
||||
|
||||
# if no bos token id => cannot generate from None
|
||||
if config.bos_token_id is None:
|
||||
self.skipTest(reason="bos_token_id is None")
|
||||
|
||||
for model_class in self.greedy_sample_model_classes:
|
||||
model = model_class(config).to(torch_device)
|
||||
model.eval()
|
||||
|
||||
output_ids_generate = model.generate(
|
||||
do_sample=False, max_new_tokens=self.max_new_tokens, remove_invalid_values=True
|
||||
)
|
||||
self.assertIsNotNone(output_ids_generate)
|
||||
|
||||
@require_torch_fp16
|
||||
@require_torch_accelerator # not all operations are supported in fp16 on CPU
|
||||
def test_generate_fp16(self):
|
||||
@@ -1595,24 +1405,10 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
)
|
||||
|
||||
def test_greedy_generate_stereo_outputs(self):
|
||||
for model_class in self.greedy_sample_model_classes:
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
config.audio_channels = 2
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
output_generate = self._greedy_generate(
|
||||
model=model,
|
||||
input_ids=input_ids.to(torch_device),
|
||||
attention_mask=attention_mask.to(torch_device),
|
||||
output_scores=True,
|
||||
output_hidden_states=True,
|
||||
output_attentions=True,
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
|
||||
self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
|
||||
|
||||
self.assertNotIn(config.pad_token_id, output_generate)
|
||||
original_audio_channels = self.model_tester.audio_channels
|
||||
self.model_tester.audio_channels = 2
|
||||
super().test_greedy_generate_dict_outputs()
|
||||
self.model_tester.audio_channels = original_audio_channels
|
||||
|
||||
@unittest.skip(
|
||||
reason="MusicgenModel is actually not the base of MusicgenForCausalLM as the latter is a composit model"
|
||||
|
||||
Reference in New Issue
Block a user