Generate: unify LogitsWarper and LogitsProcessor (#32626)

This commit is contained in:
Joao Gante
2024-08-16 11:20:41 +01:00
committed by GitHub
parent 5fd7ca7bc9
commit 70d5df6107
20 changed files with 186 additions and 623 deletions

View File

@@ -293,15 +293,9 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long)
return config, input_ids, attention_mask
@staticmethod
def _get_logits_processor_and_warper_kwargs(
input_length,
forced_bos_token_id=None,
forced_eos_token_id=None,
):
process_kwargs = {}
warper_kwargs = {}
return process_kwargs, warper_kwargs
def _get_logits_processor_kwargs(self, do_sample=False):
logits_processor_kwargs = {}
return logits_processor_kwargs
def test_greedy_generate_stereo_outputs(self):
for model_class in self.greedy_sample_model_classes:
@@ -1483,15 +1477,9 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
return output_generate
@staticmethod
def _get_logits_processor_and_warper_kwargs(
input_length,
forced_bos_token_id=None,
forced_eos_token_id=None,
):
process_kwargs = {}
warper_kwargs = {}
return process_kwargs, warper_kwargs
def _get_logits_processor_kwargs(self, do_sample=False):
logits_processor_kwargs = {}
return logits_processor_kwargs
def test_greedy_generate_dict_outputs(self):
for model_class in self.greedy_sample_model_classes: