Generate: unify LogitsWarper and LogitsProcessor (#32626)
This commit is contained in:
@@ -293,15 +293,9 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
|
||||
attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long)
|
||||
return config, input_ids, attention_mask
|
||||
|
||||
@staticmethod
|
||||
def _get_logits_processor_and_warper_kwargs(
|
||||
input_length,
|
||||
forced_bos_token_id=None,
|
||||
forced_eos_token_id=None,
|
||||
):
|
||||
process_kwargs = {}
|
||||
warper_kwargs = {}
|
||||
return process_kwargs, warper_kwargs
|
||||
def _get_logits_processor_kwargs(self, do_sample=False):
|
||||
logits_processor_kwargs = {}
|
||||
return logits_processor_kwargs
|
||||
|
||||
def test_greedy_generate_stereo_outputs(self):
|
||||
for model_class in self.greedy_sample_model_classes:
|
||||
@@ -1483,15 +1477,9 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
|
||||
return output_generate
|
||||
|
||||
@staticmethod
|
||||
def _get_logits_processor_and_warper_kwargs(
|
||||
input_length,
|
||||
forced_bos_token_id=None,
|
||||
forced_eos_token_id=None,
|
||||
):
|
||||
process_kwargs = {}
|
||||
warper_kwargs = {}
|
||||
return process_kwargs, warper_kwargs
|
||||
def _get_logits_processor_kwargs(self, do_sample=False):
|
||||
logits_processor_kwargs = {}
|
||||
return logits_processor_kwargs
|
||||
|
||||
def test_greedy_generate_dict_outputs(self):
|
||||
for model_class in self.greedy_sample_model_classes:
|
||||
|
||||
Reference in New Issue
Block a user