Tests: Musicgen tests + make fix-copies (#29734)
* make fix-copies * some tests fixed * tests fixed
This commit is contained in:
@@ -257,105 +257,6 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
|
||||
warper_kwargs = {}
|
||||
return process_kwargs, warper_kwargs
|
||||
|
||||
# override since we don't expect the outputs of `.generate` and `.greedy_search` to be the same, since we perform
|
||||
# additional post-processing in the former
|
||||
def test_greedy_generate_dict_outputs(self):
|
||||
for model_class in self.greedy_sample_model_classes:
|
||||
# disable cache
|
||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
||||
config.use_cache = False
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
output_generate = self._greedy_generate(
|
||||
model=model,
|
||||
input_ids=input_ids.to(torch_device),
|
||||
attention_mask=attention_mask.to(torch_device),
|
||||
max_length=max_length,
|
||||
output_scores=True,
|
||||
output_hidden_states=True,
|
||||
output_attentions=True,
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
|
||||
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
|
||||
self.assertNotIn(config.pad_token_id, output_generate)
|
||||
|
||||
# override since we don't expect the outputs of `.generate` and `.greedy_search` to be the same, since we perform
|
||||
# additional post-processing in the former
|
||||
def test_greedy_generate_dict_outputs_use_cache(self):
|
||||
for model_class in self.greedy_sample_model_classes:
|
||||
# enable cache
|
||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
||||
|
||||
config.use_cache = True
|
||||
config.is_decoder = True
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
output_generate = self._greedy_generate(
|
||||
model=model,
|
||||
input_ids=input_ids.to(torch_device),
|
||||
attention_mask=attention_mask.to(torch_device),
|
||||
max_length=max_length,
|
||||
output_scores=True,
|
||||
output_hidden_states=True,
|
||||
output_attentions=True,
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
|
||||
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
|
||||
|
||||
# override since we don't expect the outputs of `.generate` and `.sample` to be the same, since we perform
|
||||
# additional post-processing in the former
|
||||
def test_sample_generate(self):
|
||||
for model_class in self.greedy_sample_model_classes:
|
||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
|
||||
process_kwargs, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(
|
||||
input_ids.shape[-1],
|
||||
max_length=max_length,
|
||||
)
|
||||
|
||||
# check `generate()` and `sample()` are equal
|
||||
output_generate = self._sample_generate(
|
||||
model=model,
|
||||
input_ids=input_ids.to(torch_device),
|
||||
attention_mask=attention_mask.to(torch_device),
|
||||
max_length=max_length,
|
||||
num_return_sequences=3,
|
||||
logits_warper_kwargs=logits_warper_kwargs,
|
||||
process_kwargs=process_kwargs,
|
||||
)
|
||||
self.assertIsInstance(output_generate, torch.Tensor)
|
||||
|
||||
# override since we don't expect the outputs of `.generate` and `.sample` to be the same, since we perform
|
||||
# additional post-processing in the former
|
||||
def test_sample_generate_dict_output(self):
|
||||
for model_class in self.greedy_sample_model_classes:
|
||||
# disable cache
|
||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
||||
config.use_cache = False
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
|
||||
process_kwargs, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(
|
||||
input_ids.shape[-1],
|
||||
max_length=max_length,
|
||||
)
|
||||
|
||||
output_generate = self._sample_generate(
|
||||
model=model,
|
||||
input_ids=input_ids.to(torch_device),
|
||||
attention_mask=attention_mask.to(torch_device),
|
||||
max_length=max_length,
|
||||
num_return_sequences=1,
|
||||
logits_warper_kwargs=logits_warper_kwargs,
|
||||
process_kwargs=process_kwargs,
|
||||
output_scores=True,
|
||||
output_hidden_states=True,
|
||||
output_attentions=True,
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
|
||||
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
|
||||
|
||||
def test_greedy_generate_stereo_outputs(self):
|
||||
for model_class in self.greedy_sample_model_classes:
|
||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
||||
|
||||
Reference in New Issue
Block a user