Generate: consolidate output classes (#28494)
This commit is contained in:
@@ -53,12 +53,10 @@ if is_torch_available():
|
||||
set_seed,
|
||||
)
|
||||
from transformers.generation import (
|
||||
GreedySearchDecoderOnlyOutput,
|
||||
GreedySearchEncoderDecoderOutput,
|
||||
GenerateDecoderOnlyOutput,
|
||||
GenerateEncoderDecoderOutput,
|
||||
InfNanRemoveLogitsProcessor,
|
||||
LogitsProcessorList,
|
||||
SampleDecoderOnlyOutput,
|
||||
SampleEncoderDecoderOutput,
|
||||
)
|
||||
|
||||
|
||||
@@ -282,8 +280,8 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
|
||||
self.assertIsInstance(output_greedy, GreedySearchDecoderOnlyOutput)
|
||||
self.assertIsInstance(output_generate, GreedySearchDecoderOnlyOutput)
|
||||
self.assertIsInstance(output_greedy, GenerateDecoderOnlyOutput)
|
||||
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
|
||||
|
||||
self.assertNotIn(config.pad_token_id, output_generate)
|
||||
|
||||
@@ -308,8 +306,8 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
|
||||
self.assertIsInstance(output_greedy, GreedySearchDecoderOnlyOutput)
|
||||
self.assertIsInstance(output_generate, GreedySearchDecoderOnlyOutput)
|
||||
self.assertIsInstance(output_greedy, GenerateDecoderOnlyOutput)
|
||||
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
|
||||
|
||||
# override since we don't expect the outputs of `.generate` and `.sample` to be the same, since we perform
|
||||
# additional post-processing in the former
|
||||
@@ -376,8 +374,8 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
|
||||
self.assertIsInstance(output_sample, SampleDecoderOnlyOutput)
|
||||
self.assertIsInstance(output_generate, SampleDecoderOnlyOutput)
|
||||
self.assertIsInstance(output_sample, GenerateDecoderOnlyOutput)
|
||||
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
|
||||
|
||||
def test_greedy_generate_stereo_outputs(self):
|
||||
for model_class in self.greedy_sample_model_classes:
|
||||
@@ -395,8 +393,8 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
|
||||
self.assertIsInstance(output_greedy, GreedySearchDecoderOnlyOutput)
|
||||
self.assertIsInstance(output_generate, GreedySearchDecoderOnlyOutput)
|
||||
self.assertIsInstance(output_greedy, GenerateDecoderOnlyOutput)
|
||||
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
|
||||
|
||||
self.assertNotIn(config.pad_token_id, output_generate)
|
||||
|
||||
@@ -1001,8 +999,8 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
|
||||
self.assertIsInstance(output_greedy, GreedySearchEncoderDecoderOutput)
|
||||
self.assertIsInstance(output_generate, GreedySearchEncoderDecoderOutput)
|
||||
self.assertIsInstance(output_greedy, GenerateEncoderDecoderOutput)
|
||||
self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
|
||||
|
||||
self.assertNotIn(config.pad_token_id, output_generate)
|
||||
|
||||
@@ -1026,8 +1024,8 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
|
||||
self.assertIsInstance(output_greedy, GreedySearchEncoderDecoderOutput)
|
||||
self.assertIsInstance(output_generate, GreedySearchEncoderDecoderOutput)
|
||||
self.assertIsInstance(output_greedy, GenerateEncoderDecoderOutput)
|
||||
self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
|
||||
|
||||
def test_sample_generate(self):
|
||||
for model_class in self.greedy_sample_model_classes:
|
||||
@@ -1092,8 +1090,8 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
|
||||
self.assertIsInstance(output_sample, SampleEncoderDecoderOutput)
|
||||
self.assertIsInstance(output_generate, SampleEncoderDecoderOutput)
|
||||
self.assertIsInstance(output_sample, GenerateEncoderDecoderOutput)
|
||||
self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
|
||||
|
||||
def test_generate_without_input_ids(self):
|
||||
config, _, _, _, max_length = self._get_input_ids_and_config()
|
||||
@@ -1141,8 +1139,8 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
|
||||
self.assertIsInstance(output_greedy, GreedySearchEncoderDecoderOutput)
|
||||
self.assertIsInstance(output_generate, GreedySearchEncoderDecoderOutput)
|
||||
self.assertIsInstance(output_greedy, GenerateEncoderDecoderOutput)
|
||||
self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
|
||||
|
||||
self.assertNotIn(config.pad_token_id, output_generate)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user