Tests: Musicgen tests + make fix-copies (#29734)
* make fix-copies * some tests fixed * tests fixed
This commit is contained in:
@@ -1294,7 +1294,7 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 11. run greedy search
|
# 11. run greedy search
|
||||||
outputs = self.greedy_search(
|
outputs = self._greedy_search(
|
||||||
input_ids,
|
input_ids,
|
||||||
logits_processor=logits_processor,
|
logits_processor=logits_processor,
|
||||||
stopping_criteria=stopping_criteria,
|
stopping_criteria=stopping_criteria,
|
||||||
@@ -1319,7 +1319,7 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 12. run sample
|
# 12. run sample
|
||||||
outputs = self.sample(
|
outputs = self._sample(
|
||||||
input_ids,
|
input_ids,
|
||||||
logits_processor=logits_processor,
|
logits_processor=logits_processor,
|
||||||
logits_warper=logits_warper,
|
logits_warper=logits_warper,
|
||||||
|
|||||||
@@ -257,105 +257,6 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
|
|||||||
warper_kwargs = {}
|
warper_kwargs = {}
|
||||||
return process_kwargs, warper_kwargs
|
return process_kwargs, warper_kwargs
|
||||||
|
|
||||||
# override since we don't expect the outputs of `.generate` and `.greedy_search` to be the same, since we perform
|
|
||||||
# additional post-processing in the former
|
|
||||||
def test_greedy_generate_dict_outputs(self):
|
|
||||||
for model_class in self.greedy_sample_model_classes:
|
|
||||||
# disable cache
|
|
||||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
|
||||||
config.use_cache = False
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
|
||||||
output_generate = self._greedy_generate(
|
|
||||||
model=model,
|
|
||||||
input_ids=input_ids.to(torch_device),
|
|
||||||
attention_mask=attention_mask.to(torch_device),
|
|
||||||
max_length=max_length,
|
|
||||||
output_scores=True,
|
|
||||||
output_hidden_states=True,
|
|
||||||
output_attentions=True,
|
|
||||||
return_dict_in_generate=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
|
|
||||||
self.assertNotIn(config.pad_token_id, output_generate)
|
|
||||||
|
|
||||||
# override since we don't expect the outputs of `.generate` and `.greedy_search` to be the same, since we perform
|
|
||||||
# additional post-processing in the former
|
|
||||||
def test_greedy_generate_dict_outputs_use_cache(self):
|
|
||||||
for model_class in self.greedy_sample_model_classes:
|
|
||||||
# enable cache
|
|
||||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
|
||||||
|
|
||||||
config.use_cache = True
|
|
||||||
config.is_decoder = True
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
|
||||||
output_generate = self._greedy_generate(
|
|
||||||
model=model,
|
|
||||||
input_ids=input_ids.to(torch_device),
|
|
||||||
attention_mask=attention_mask.to(torch_device),
|
|
||||||
max_length=max_length,
|
|
||||||
output_scores=True,
|
|
||||||
output_hidden_states=True,
|
|
||||||
output_attentions=True,
|
|
||||||
return_dict_in_generate=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
|
|
||||||
|
|
||||||
# override since we don't expect the outputs of `.generate` and `.sample` to be the same, since we perform
|
|
||||||
# additional post-processing in the former
|
|
||||||
def test_sample_generate(self):
|
|
||||||
for model_class in self.greedy_sample_model_classes:
|
|
||||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
|
||||||
|
|
||||||
process_kwargs, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(
|
|
||||||
input_ids.shape[-1],
|
|
||||||
max_length=max_length,
|
|
||||||
)
|
|
||||||
|
|
||||||
# check `generate()` and `sample()` are equal
|
|
||||||
output_generate = self._sample_generate(
|
|
||||||
model=model,
|
|
||||||
input_ids=input_ids.to(torch_device),
|
|
||||||
attention_mask=attention_mask.to(torch_device),
|
|
||||||
max_length=max_length,
|
|
||||||
num_return_sequences=3,
|
|
||||||
logits_warper_kwargs=logits_warper_kwargs,
|
|
||||||
process_kwargs=process_kwargs,
|
|
||||||
)
|
|
||||||
self.assertIsInstance(output_generate, torch.Tensor)
|
|
||||||
|
|
||||||
# override since we don't expect the outputs of `.generate` and `.sample` to be the same, since we perform
|
|
||||||
# additional post-processing in the former
|
|
||||||
def test_sample_generate_dict_output(self):
|
|
||||||
for model_class in self.greedy_sample_model_classes:
|
|
||||||
# disable cache
|
|
||||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
|
||||||
config.use_cache = False
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
|
||||||
|
|
||||||
process_kwargs, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(
|
|
||||||
input_ids.shape[-1],
|
|
||||||
max_length=max_length,
|
|
||||||
)
|
|
||||||
|
|
||||||
output_generate = self._sample_generate(
|
|
||||||
model=model,
|
|
||||||
input_ids=input_ids.to(torch_device),
|
|
||||||
attention_mask=attention_mask.to(torch_device),
|
|
||||||
max_length=max_length,
|
|
||||||
num_return_sequences=1,
|
|
||||||
logits_warper_kwargs=logits_warper_kwargs,
|
|
||||||
process_kwargs=process_kwargs,
|
|
||||||
output_scores=True,
|
|
||||||
output_hidden_states=True,
|
|
||||||
output_attentions=True,
|
|
||||||
return_dict_in_generate=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
|
|
||||||
|
|
||||||
def test_greedy_generate_stereo_outputs(self):
|
def test_greedy_generate_stereo_outputs(self):
|
||||||
for model_class in self.greedy_sample_model_classes:
|
for model_class in self.greedy_sample_model_classes:
|
||||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
||||||
|
|||||||
@@ -55,8 +55,6 @@ if is_torch_available():
|
|||||||
)
|
)
|
||||||
from transformers.generation import (
|
from transformers.generation import (
|
||||||
GenerateDecoderOnlyOutput,
|
GenerateDecoderOnlyOutput,
|
||||||
InfNanRemoveLogitsProcessor,
|
|
||||||
LogitsProcessorList,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_torchaudio_available():
|
if is_torchaudio_available():
|
||||||
@@ -248,142 +246,24 @@ class MusicgenMelodyDecoderTest(ModelTesterMixin, GenerationTesterMixin, unittes
|
|||||||
return config, input_ids, attention_mask, max_length
|
return config, input_ids, attention_mask, max_length
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_logits_processor_and_kwargs(
|
def _get_logits_processor_and_warper_kwargs(
|
||||||
input_length,
|
input_length,
|
||||||
eos_token_id,
|
|
||||||
forced_bos_token_id=None,
|
forced_bos_token_id=None,
|
||||||
forced_eos_token_id=None,
|
forced_eos_token_id=None,
|
||||||
max_length=None,
|
max_length=None,
|
||||||
diversity_penalty=None,
|
|
||||||
):
|
):
|
||||||
process_kwargs = {
|
process_kwargs = {
|
||||||
"min_length": input_length + 1 if max_length is None else max_length - 1,
|
"min_length": input_length + 1 if max_length is None else max_length - 1,
|
||||||
}
|
}
|
||||||
logits_processor = LogitsProcessorList()
|
warper_kwargs = {}
|
||||||
return process_kwargs, logits_processor
|
return process_kwargs, warper_kwargs
|
||||||
|
|
||||||
# override since we don't expect the outputs of `.generate` and `.greedy_search` to be the same, since we perform
|
|
||||||
# additional post-processing in the former
|
|
||||||
def test_greedy_generate_dict_outputs(self):
|
|
||||||
for model_class in self.greedy_sample_model_classes:
|
|
||||||
# disable cache
|
|
||||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
|
||||||
config.use_cache = False
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
|
||||||
output_greedy, output_generate = self._greedy_generate(
|
|
||||||
model=model,
|
|
||||||
input_ids=input_ids.to(torch_device),
|
|
||||||
attention_mask=attention_mask.to(torch_device),
|
|
||||||
max_length=max_length,
|
|
||||||
output_scores=True,
|
|
||||||
output_hidden_states=True,
|
|
||||||
output_attentions=True,
|
|
||||||
return_dict_in_generate=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertIsInstance(output_greedy, GenerateDecoderOnlyOutput)
|
|
||||||
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
|
|
||||||
|
|
||||||
self.assertNotIn(config.pad_token_id, output_generate)
|
|
||||||
|
|
||||||
# override since we don't expect the outputs of `.generate` and `.greedy_search` to be the same, since we perform
|
|
||||||
# additional post-processing in the former
|
|
||||||
def test_greedy_generate_dict_outputs_use_cache(self):
|
|
||||||
for model_class in self.greedy_sample_model_classes:
|
|
||||||
# enable cache
|
|
||||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
|
||||||
|
|
||||||
config.use_cache = True
|
|
||||||
config.is_decoder = True
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
|
||||||
output_greedy, output_generate = self._greedy_generate(
|
|
||||||
model=model,
|
|
||||||
input_ids=input_ids.to(torch_device),
|
|
||||||
attention_mask=attention_mask.to(torch_device),
|
|
||||||
max_length=max_length,
|
|
||||||
output_scores=True,
|
|
||||||
output_hidden_states=True,
|
|
||||||
output_attentions=True,
|
|
||||||
return_dict_in_generate=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertIsInstance(output_greedy, GenerateDecoderOnlyOutput)
|
|
||||||
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
|
|
||||||
|
|
||||||
# override since we don't expect the outputs of `.generate` and `.sample` to be the same, since we perform
|
|
||||||
# additional post-processing in the former
|
|
||||||
def test_sample_generate(self):
|
|
||||||
for model_class in self.greedy_sample_model_classes:
|
|
||||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
|
||||||
|
|
||||||
process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
|
|
||||||
input_ids.shape[-1],
|
|
||||||
model.config.eos_token_id,
|
|
||||||
forced_bos_token_id=model.config.forced_bos_token_id,
|
|
||||||
forced_eos_token_id=model.config.forced_eos_token_id,
|
|
||||||
max_length=max_length,
|
|
||||||
)
|
|
||||||
logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=2)
|
|
||||||
|
|
||||||
# check `generate()` and `sample()` are equal
|
|
||||||
output_sample, output_generate = self._sample_generate(
|
|
||||||
model=model,
|
|
||||||
input_ids=input_ids.to(torch_device),
|
|
||||||
attention_mask=attention_mask.to(torch_device),
|
|
||||||
max_length=max_length,
|
|
||||||
num_return_sequences=3,
|
|
||||||
logits_processor=logits_processor,
|
|
||||||
logits_warper=logits_warper,
|
|
||||||
logits_warper_kwargs=logits_warper_kwargs,
|
|
||||||
process_kwargs=process_kwargs,
|
|
||||||
)
|
|
||||||
self.assertIsInstance(output_sample, torch.Tensor)
|
|
||||||
self.assertIsInstance(output_generate, torch.Tensor)
|
|
||||||
|
|
||||||
# override since we don't expect the outputs of `.generate` and `.sample` to be the same, since we perform
|
|
||||||
# additional post-processing in the former
|
|
||||||
def test_sample_generate_dict_output(self):
|
|
||||||
for model_class in self.greedy_sample_model_classes:
|
|
||||||
# disable cache
|
|
||||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
|
||||||
config.use_cache = False
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
|
||||||
|
|
||||||
process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
|
|
||||||
input_ids.shape[-1],
|
|
||||||
model.config.eos_token_id,
|
|
||||||
forced_bos_token_id=model.config.forced_bos_token_id,
|
|
||||||
forced_eos_token_id=model.config.forced_eos_token_id,
|
|
||||||
max_length=max_length,
|
|
||||||
)
|
|
||||||
logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1)
|
|
||||||
|
|
||||||
output_sample, output_generate = self._sample_generate(
|
|
||||||
model=model,
|
|
||||||
input_ids=input_ids.to(torch_device),
|
|
||||||
attention_mask=attention_mask.to(torch_device),
|
|
||||||
max_length=max_length,
|
|
||||||
num_return_sequences=1,
|
|
||||||
logits_processor=logits_processor,
|
|
||||||
logits_warper=logits_warper,
|
|
||||||
logits_warper_kwargs=logits_warper_kwargs,
|
|
||||||
process_kwargs=process_kwargs,
|
|
||||||
output_scores=True,
|
|
||||||
output_hidden_states=True,
|
|
||||||
output_attentions=True,
|
|
||||||
return_dict_in_generate=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertIsInstance(output_sample, GenerateDecoderOnlyOutput)
|
|
||||||
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
|
|
||||||
|
|
||||||
def test_greedy_generate_stereo_outputs(self):
|
def test_greedy_generate_stereo_outputs(self):
|
||||||
for model_class in self.greedy_sample_model_classes:
|
for model_class in self.greedy_sample_model_classes:
|
||||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
||||||
config.audio_channels = 2
|
config.audio_channels = 2
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
output_greedy, output_generate = self._greedy_generate(
|
output_generate = self._greedy_generate(
|
||||||
model=model,
|
model=model,
|
||||||
input_ids=input_ids.to(torch_device),
|
input_ids=input_ids.to(torch_device),
|
||||||
attention_mask=attention_mask.to(torch_device),
|
attention_mask=attention_mask.to(torch_device),
|
||||||
@@ -394,9 +274,7 @@ class MusicgenMelodyDecoderTest(ModelTesterMixin, GenerationTesterMixin, unittes
|
|||||||
return_dict_in_generate=True,
|
return_dict_in_generate=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertIsInstance(output_greedy, GenerateDecoderOnlyOutput)
|
|
||||||
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
|
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
|
||||||
|
|
||||||
self.assertNotIn(config.pad_token_id, output_generate)
|
self.assertNotIn(config.pad_token_id, output_generate)
|
||||||
|
|
||||||
|
|
||||||
@@ -817,10 +695,8 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
|||||||
attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long)
|
attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long)
|
||||||
|
|
||||||
# generate max 3 tokens
|
# generate max 3 tokens
|
||||||
decoder_input_ids = inputs_dict["decoder_input_ids"]
|
max_length = 3
|
||||||
max_length = decoder_input_ids.shape[-1] + 3
|
return config, input_ids, attention_mask, max_length
|
||||||
decoder_input_ids = decoder_input_ids[: batch_size * config.decoder.num_codebooks, :]
|
|
||||||
return config, input_ids, attention_mask, decoder_input_ids, max_length
|
|
||||||
|
|
||||||
# override since the `input_ids` cannot be used as the `decoder_input_ids` for musicgen_melody (input / outputs are
|
# override since the `input_ids` cannot be used as the `decoder_input_ids` for musicgen_melody (input / outputs are
|
||||||
# different modalities -> different shapes)
|
# different modalities -> different shapes)
|
||||||
@@ -829,18 +705,14 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
|||||||
model,
|
model,
|
||||||
input_ids,
|
input_ids,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
decoder_input_ids,
|
|
||||||
max_length,
|
max_length,
|
||||||
output_scores=False,
|
output_scores=False,
|
||||||
output_attentions=False,
|
output_attentions=False,
|
||||||
output_hidden_states=False,
|
output_hidden_states=False,
|
||||||
return_dict_in_generate=False,
|
return_dict_in_generate=False,
|
||||||
):
|
):
|
||||||
logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
|
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
|
||||||
input_ids.shape[-1],
|
input_ids.shape[-1],
|
||||||
eos_token_id=model.config.eos_token_id,
|
|
||||||
forced_bos_token_id=model.config.forced_bos_token_id,
|
|
||||||
forced_eos_token_id=model.config.forced_eos_token_id,
|
|
||||||
max_length=max_length,
|
max_length=max_length,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -859,34 +731,17 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
|||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
with torch.no_grad():
|
return output_generate
|
||||||
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
|
||||||
output_greedy = model.greedy_search(
|
|
||||||
decoder_input_ids,
|
|
||||||
max_length=max_length,
|
|
||||||
logits_processor=logits_processor,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
output_hidden_states=output_hidden_states,
|
|
||||||
output_scores=output_scores,
|
|
||||||
return_dict_in_generate=return_dict_in_generate,
|
|
||||||
# Ignore copy
|
|
||||||
**model_kwargs,
|
|
||||||
)
|
|
||||||
return output_greedy, output_generate
|
|
||||||
|
|
||||||
# override since the `input_ids` cannot be used as the `decoder_input_ids` for musicgen_melody (input / outputs are
|
# override since the `input_ids` cannot be used as the `decoder_input_ids` for musicgen_melody (input / outputs are
|
||||||
# different modalities -> different shapes)
|
# different modalities -> different shapes)
|
||||||
# Ignore copy
|
|
||||||
def _sample_generate(
|
def _sample_generate(
|
||||||
self,
|
self,
|
||||||
model,
|
model,
|
||||||
input_ids,
|
input_ids,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
decoder_input_ids,
|
|
||||||
max_length,
|
max_length,
|
||||||
num_return_sequences,
|
num_return_sequences,
|
||||||
logits_processor,
|
|
||||||
logits_warper,
|
|
||||||
logits_warper_kwargs,
|
logits_warper_kwargs,
|
||||||
process_kwargs,
|
process_kwargs,
|
||||||
output_scores=False,
|
output_scores=False,
|
||||||
@@ -912,53 +767,31 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
|||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
torch.manual_seed(0)
|
return output_generate
|
||||||
|
|
||||||
# prevent flaky generation test failures
|
|
||||||
logits_processor.append(InfNanRemoveLogitsProcessor())
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
|
||||||
output_sample = model.sample(
|
|
||||||
decoder_input_ids.repeat_interleave(num_return_sequences, dim=0),
|
|
||||||
max_length=max_length,
|
|
||||||
logits_processor=logits_processor,
|
|
||||||
logits_warper=logits_warper,
|
|
||||||
output_scores=output_scores,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
output_hidden_states=output_hidden_states,
|
|
||||||
return_dict_in_generate=return_dict_in_generate,
|
|
||||||
**model_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
return output_sample, output_generate
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_logits_processor_and_kwargs(
|
def _get_logits_processor_and_warper_kwargs(
|
||||||
input_length,
|
input_length,
|
||||||
eos_token_id,
|
|
||||||
forced_bos_token_id=None,
|
forced_bos_token_id=None,
|
||||||
forced_eos_token_id=None,
|
forced_eos_token_id=None,
|
||||||
max_length=None,
|
max_length=None,
|
||||||
diversity_penalty=None,
|
|
||||||
):
|
):
|
||||||
process_kwargs = {
|
process_kwargs = {
|
||||||
"min_length": input_length + 1 if max_length is None else max_length - 1,
|
"min_length": input_length + 1 if max_length is None else max_length - 1,
|
||||||
}
|
}
|
||||||
logits_processor = LogitsProcessorList()
|
warper_kwargs = {}
|
||||||
return process_kwargs, logits_processor
|
return process_kwargs, warper_kwargs
|
||||||
|
|
||||||
def test_greedy_generate_dict_outputs(self):
|
def test_greedy_generate_dict_outputs(self):
|
||||||
for model_class in self.greedy_sample_model_classes:
|
for model_class in self.greedy_sample_model_classes:
|
||||||
# disable cache
|
# disable cache
|
||||||
config, input_ids, attention_mask, decoder_input_ids, max_length = self._get_input_ids_and_config()
|
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
||||||
config.use_cache = False
|
config.use_cache = False
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
output_greedy, output_generate = self._greedy_generate(
|
output_generate = self._greedy_generate(
|
||||||
model=model,
|
model=model,
|
||||||
input_ids=input_ids.to(torch_device),
|
input_ids=input_ids.to(torch_device),
|
||||||
attention_mask=attention_mask.to(torch_device),
|
attention_mask=attention_mask.to(torch_device),
|
||||||
decoder_input_ids=decoder_input_ids,
|
|
||||||
max_length=max_length,
|
max_length=max_length,
|
||||||
output_scores=True,
|
output_scores=True,
|
||||||
output_hidden_states=True,
|
output_hidden_states=True,
|
||||||
@@ -966,7 +799,6 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
|||||||
return_dict_in_generate=True,
|
return_dict_in_generate=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertIsInstance(output_greedy, GenerateDecoderOnlyOutput)
|
|
||||||
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
|
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
|
||||||
|
|
||||||
self.assertNotIn(config.pad_token_id, output_generate)
|
self.assertNotIn(config.pad_token_id, output_generate)
|
||||||
@@ -974,16 +806,15 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
|||||||
def test_greedy_generate_dict_outputs_use_cache(self):
|
def test_greedy_generate_dict_outputs_use_cache(self):
|
||||||
for model_class in self.greedy_sample_model_classes:
|
for model_class in self.greedy_sample_model_classes:
|
||||||
# enable cache
|
# enable cache
|
||||||
config, input_ids, attention_mask, decoder_input_ids, max_length = self._get_input_ids_and_config()
|
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
||||||
|
|
||||||
config.use_cache = True
|
config.use_cache = True
|
||||||
config.is_decoder = True
|
config.is_decoder = True
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
output_greedy, output_generate = self._greedy_generate(
|
output_generate = self._greedy_generate(
|
||||||
model=model,
|
model=model,
|
||||||
input_ids=input_ids.to(torch_device),
|
input_ids=input_ids.to(torch_device),
|
||||||
attention_mask=attention_mask.to(torch_device),
|
attention_mask=attention_mask.to(torch_device),
|
||||||
decoder_input_ids=decoder_input_ids,
|
|
||||||
max_length=max_length,
|
max_length=max_length,
|
||||||
output_scores=True,
|
output_scores=True,
|
||||||
output_hidden_states=True,
|
output_hidden_states=True,
|
||||||
@@ -991,64 +822,48 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
|||||||
return_dict_in_generate=True,
|
return_dict_in_generate=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertIsInstance(output_greedy, GenerateDecoderOnlyOutput)
|
|
||||||
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
|
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
|
||||||
|
|
||||||
def test_sample_generate(self):
|
def test_sample_generate(self):
|
||||||
for model_class in self.greedy_sample_model_classes:
|
for model_class in self.greedy_sample_model_classes:
|
||||||
config, input_ids, attention_mask, decoder_input_ids, max_length = self._get_input_ids_and_config()
|
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
|
|
||||||
process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
|
process_kwargs, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(
|
||||||
input_ids.shape[-1],
|
input_ids.shape[-1],
|
||||||
model.config.eos_token_id,
|
|
||||||
forced_bos_token_id=model.config.forced_bos_token_id,
|
|
||||||
forced_eos_token_id=model.config.forced_eos_token_id,
|
|
||||||
max_length=max_length,
|
max_length=max_length,
|
||||||
)
|
)
|
||||||
logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=2)
|
|
||||||
|
|
||||||
# check `generate()` and `sample()` are equal
|
# check `generate()` and `sample()` are equal
|
||||||
output_sample, output_generate = self._sample_generate(
|
output_generate = self._sample_generate(
|
||||||
model=model,
|
model=model,
|
||||||
input_ids=input_ids.to(torch_device),
|
input_ids=input_ids.to(torch_device),
|
||||||
attention_mask=attention_mask.to(torch_device),
|
attention_mask=attention_mask.to(torch_device),
|
||||||
decoder_input_ids=decoder_input_ids,
|
|
||||||
max_length=max_length,
|
max_length=max_length,
|
||||||
num_return_sequences=1,
|
num_return_sequences=1,
|
||||||
logits_processor=logits_processor,
|
|
||||||
logits_warper=logits_warper,
|
|
||||||
logits_warper_kwargs=logits_warper_kwargs,
|
logits_warper_kwargs=logits_warper_kwargs,
|
||||||
process_kwargs=process_kwargs,
|
process_kwargs=process_kwargs,
|
||||||
)
|
)
|
||||||
self.assertIsInstance(output_sample, torch.Tensor)
|
|
||||||
self.assertIsInstance(output_generate, torch.Tensor)
|
self.assertIsInstance(output_generate, torch.Tensor)
|
||||||
|
|
||||||
def test_sample_generate_dict_output(self):
|
def test_sample_generate_dict_output(self):
|
||||||
for model_class in self.greedy_sample_model_classes:
|
for model_class in self.greedy_sample_model_classes:
|
||||||
# disable cache
|
# disable cache
|
||||||
config, input_ids, attention_mask, decoder_input_ids, max_length = self._get_input_ids_and_config()
|
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
||||||
config.use_cache = False
|
config.use_cache = False
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
|
|
||||||
process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
|
process_kwargs, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(
|
||||||
input_ids.shape[-1],
|
input_ids.shape[-1],
|
||||||
model.config.eos_token_id,
|
|
||||||
forced_bos_token_id=model.config.forced_bos_token_id,
|
|
||||||
forced_eos_token_id=model.config.forced_eos_token_id,
|
|
||||||
max_length=max_length,
|
max_length=max_length,
|
||||||
)
|
)
|
||||||
logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1)
|
|
||||||
|
|
||||||
output_sample, output_generate = self._sample_generate(
|
output_generate = self._sample_generate(
|
||||||
model=model,
|
model=model,
|
||||||
input_ids=input_ids.to(torch_device),
|
input_ids=input_ids.to(torch_device),
|
||||||
attention_mask=attention_mask.to(torch_device),
|
attention_mask=attention_mask.to(torch_device),
|
||||||
decoder_input_ids=decoder_input_ids,
|
|
||||||
max_length=max_length,
|
max_length=max_length,
|
||||||
num_return_sequences=3,
|
num_return_sequences=3,
|
||||||
logits_processor=logits_processor,
|
|
||||||
logits_warper=logits_warper,
|
|
||||||
logits_warper_kwargs=logits_warper_kwargs,
|
logits_warper_kwargs=logits_warper_kwargs,
|
||||||
process_kwargs=process_kwargs,
|
process_kwargs=process_kwargs,
|
||||||
output_scores=True,
|
output_scores=True,
|
||||||
@@ -1057,11 +872,10 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
|||||||
return_dict_in_generate=True,
|
return_dict_in_generate=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertIsInstance(output_sample, GenerateDecoderOnlyOutput)
|
|
||||||
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
|
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
|
||||||
|
|
||||||
def test_generate_without_input_ids(self):
|
def test_generate_without_input_ids(self):
|
||||||
config, _, _, _, max_length = self._get_input_ids_and_config()
|
config, _, _, max_length = self._get_input_ids_and_config()
|
||||||
|
|
||||||
# if no bos token id => cannot generate from None
|
# if no bos token id => cannot generate from None
|
||||||
if config.bos_token_id is None:
|
if config.bos_token_id is None:
|
||||||
@@ -1090,15 +904,14 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
|||||||
|
|
||||||
def test_greedy_generate_stereo_outputs(self):
|
def test_greedy_generate_stereo_outputs(self):
|
||||||
for model_class in self.greedy_sample_model_classes:
|
for model_class in self.greedy_sample_model_classes:
|
||||||
config, input_ids, attention_mask, decoder_input_ids, max_length = self._get_input_ids_and_config()
|
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
||||||
config.audio_channels = 2
|
config.audio_channels = 2
|
||||||
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
output_greedy, output_generate = self._greedy_generate(
|
output_generate = self._greedy_generate(
|
||||||
model=model,
|
model=model,
|
||||||
input_ids=input_ids.to(torch_device),
|
input_ids=input_ids.to(torch_device),
|
||||||
attention_mask=attention_mask.to(torch_device),
|
attention_mask=attention_mask.to(torch_device),
|
||||||
decoder_input_ids=decoder_input_ids,
|
|
||||||
max_length=max_length,
|
max_length=max_length,
|
||||||
output_scores=True,
|
output_scores=True,
|
||||||
output_hidden_states=True,
|
output_hidden_states=True,
|
||||||
@@ -1106,7 +919,6 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
|||||||
return_dict_in_generate=True,
|
return_dict_in_generate=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertIsInstance(output_greedy, GenerateDecoderOnlyOutput)
|
|
||||||
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
|
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
|
||||||
|
|
||||||
self.assertNotIn(config.pad_token_id, output_generate)
|
self.assertNotIn(config.pad_token_id, output_generate)
|
||||||
|
|||||||
Reference in New Issue
Block a user