VLM generate: tests can't generate image/video tokens (#33623)
This commit is contained in:
@@ -300,7 +300,7 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
|
||||
attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long)
|
||||
return config, input_ids, attention_mask, inputs_dict
|
||||
|
||||
def _get_logits_processor_kwargs(self, do_sample=False):
|
||||
def _get_logits_processor_kwargs(self, do_sample=False, config=None):
|
||||
logits_processor_kwargs = {}
|
||||
return logits_processor_kwargs
|
||||
|
||||
@@ -1485,7 +1485,7 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
|
||||
return output_generate
|
||||
|
||||
def _get_logits_processor_kwargs(self, do_sample=False):
|
||||
def _get_logits_processor_kwargs(self, do_sample=False, config=None):
|
||||
logits_processor_kwargs = {}
|
||||
return logits_processor_kwargs
|
||||
|
||||
|
||||
@@ -303,7 +303,7 @@ class MusicgenMelodyDecoderTest(ModelTesterMixin, GenerationTesterMixin, unittes
|
||||
attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long)
|
||||
return config, input_ids, attention_mask, inputs_dict
|
||||
|
||||
def _get_logits_processor_kwargs(self, do_sample=False):
|
||||
def _get_logits_processor_kwargs(self, do_sample=False, config=None):
|
||||
logits_processor_kwargs = {}
|
||||
return logits_processor_kwargs
|
||||
|
||||
@@ -1469,7 +1469,7 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
||||
|
||||
return output_generate
|
||||
|
||||
def _get_logits_processor_kwargs(self, do_sample=False):
|
||||
def _get_logits_processor_kwargs(self, do_sample=False, config=None):
|
||||
logits_processor_kwargs = {}
|
||||
return logits_processor_kwargs
|
||||
|
||||
|
||||
@@ -411,9 +411,9 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
|
||||
return False
|
||||
|
||||
def _get_logits_processor_kwargs(self, do_sample=False):
|
||||
def _get_logits_processor_kwargs(self, do_sample=False, config=None):
|
||||
# Overwritten from `GenerationTesterMixin`, Whisper needs `"temperature": 0.0` to be able to do beam search
|
||||
logits_processor_kwargs = super()._get_logits_processor_kwargs(do_sample=do_sample)
|
||||
logits_processor_kwargs = super()._get_logits_processor_kwargs(do_sample=do_sample, config=config)
|
||||
logits_processor_kwargs["temperature"] = 0.0
|
||||
return logits_processor_kwargs
|
||||
|
||||
|
||||
Reference in New Issue
Block a user