Generate: consolidate output classes (#28494)
This commit is contained in:
@@ -65,6 +65,10 @@ if is_torch_available():
|
||||
DisjunctiveConstraint,
|
||||
ForcedBOSTokenLogitsProcessor,
|
||||
ForcedEOSTokenLogitsProcessor,
|
||||
GenerateBeamDecoderOnlyOutput,
|
||||
GenerateBeamEncoderDecoderOutput,
|
||||
GenerateDecoderOnlyOutput,
|
||||
GenerateEncoderDecoderOutput,
|
||||
GreedySearchDecoderOnlyOutput,
|
||||
GreedySearchEncoderDecoderOutput,
|
||||
HammingDiversityLogitsProcessor,
|
||||
@@ -730,9 +734,15 @@ class GenerationTesterMixin:
|
||||
)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertIsInstance(output_greedy, GenerateEncoderDecoderOutput)
|
||||
self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
|
||||
# Retrocompatibility check
|
||||
self.assertIsInstance(output_greedy, GreedySearchEncoderDecoderOutput)
|
||||
self.assertIsInstance(output_generate, GreedySearchEncoderDecoderOutput)
|
||||
else:
|
||||
self.assertIsInstance(output_greedy, GenerateDecoderOnlyOutput)
|
||||
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
|
||||
# Retrocompatibility check
|
||||
self.assertIsInstance(output_greedy, GreedySearchDecoderOnlyOutput)
|
||||
self.assertIsInstance(output_generate, GreedySearchDecoderOnlyOutput)
|
||||
|
||||
@@ -848,9 +858,15 @@ class GenerationTesterMixin:
|
||||
)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertIsInstance(output_sample, GenerateEncoderDecoderOutput)
|
||||
self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
|
||||
# Retrocompatibility check
|
||||
self.assertIsInstance(output_sample, SampleEncoderDecoderOutput)
|
||||
self.assertIsInstance(output_generate, SampleEncoderDecoderOutput)
|
||||
else:
|
||||
self.assertIsInstance(output_sample, GenerateDecoderOnlyOutput)
|
||||
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
|
||||
# Retrocompatibility check
|
||||
self.assertIsInstance(output_sample, SampleDecoderOnlyOutput)
|
||||
self.assertIsInstance(output_generate, SampleDecoderOnlyOutput)
|
||||
|
||||
@@ -952,9 +968,15 @@ class GenerationTesterMixin:
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertIsInstance(output_beam_search, GenerateBeamEncoderDecoderOutput)
|
||||
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
|
||||
# Retrocompatibility check
|
||||
self.assertIsInstance(output_beam_search, BeamSearchEncoderDecoderOutput)
|
||||
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
|
||||
else:
|
||||
self.assertIsInstance(output_beam_search, GenerateBeamDecoderOnlyOutput)
|
||||
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
|
||||
# Retrocompatibility check
|
||||
self.assertIsInstance(output_beam_search, BeamSearchDecoderOnlyOutput)
|
||||
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
|
||||
|
||||
@@ -1109,9 +1131,15 @@ class GenerationTesterMixin:
|
||||
)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertIsInstance(output_beam_sample, GenerateBeamEncoderDecoderOutput)
|
||||
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
|
||||
# Retrocompatibility check
|
||||
self.assertIsInstance(output_beam_sample, BeamSampleEncoderDecoderOutput)
|
||||
self.assertIsInstance(output_generate, BeamSampleEncoderDecoderOutput)
|
||||
else:
|
||||
self.assertIsInstance(output_beam_sample, GenerateBeamDecoderOnlyOutput)
|
||||
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
|
||||
# Retrocompatibility check
|
||||
self.assertIsInstance(output_beam_sample, BeamSampleDecoderOnlyOutput)
|
||||
self.assertIsInstance(output_generate, BeamSampleDecoderOnlyOutput)
|
||||
|
||||
@@ -1238,9 +1266,15 @@ class GenerationTesterMixin:
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertIsInstance(output_group_beam_search, GenerateBeamEncoderDecoderOutput)
|
||||
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
|
||||
# Retrocompatibility check
|
||||
self.assertIsInstance(output_group_beam_search, BeamSearchEncoderDecoderOutput)
|
||||
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
|
||||
else:
|
||||
self.assertIsInstance(output_group_beam_search, GenerateBeamDecoderOnlyOutput)
|
||||
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
|
||||
# Retrocompatibility check
|
||||
self.assertIsInstance(output_group_beam_search, BeamSearchDecoderOnlyOutput)
|
||||
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
|
||||
|
||||
@@ -1390,9 +1424,15 @@ class GenerationTesterMixin:
|
||||
)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertIsInstance(output_beam_search, GenerateBeamEncoderDecoderOutput)
|
||||
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
|
||||
# Retrocompatibility check
|
||||
self.assertIsInstance(output_beam_search, BeamSearchEncoderDecoderOutput)
|
||||
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
|
||||
else:
|
||||
self.assertIsInstance(output_beam_search, GenerateBeamDecoderOnlyOutput)
|
||||
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
|
||||
# Retrocompatibility check
|
||||
self.assertIsInstance(output_beam_search, BeamSearchDecoderOnlyOutput)
|
||||
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
|
||||
|
||||
|
||||
@@ -53,12 +53,10 @@ if is_torch_available():
|
||||
set_seed,
|
||||
)
|
||||
from transformers.generation import (
|
||||
GreedySearchDecoderOnlyOutput,
|
||||
GreedySearchEncoderDecoderOutput,
|
||||
GenerateDecoderOnlyOutput,
|
||||
GenerateEncoderDecoderOutput,
|
||||
InfNanRemoveLogitsProcessor,
|
||||
LogitsProcessorList,
|
||||
SampleDecoderOnlyOutput,
|
||||
SampleEncoderDecoderOutput,
|
||||
)
|
||||
|
||||
|
||||
@@ -282,8 +280,8 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
|
||||
self.assertIsInstance(output_greedy, GreedySearchDecoderOnlyOutput)
|
||||
self.assertIsInstance(output_generate, GreedySearchDecoderOnlyOutput)
|
||||
self.assertIsInstance(output_greedy, GenerateDecoderOnlyOutput)
|
||||
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
|
||||
|
||||
self.assertNotIn(config.pad_token_id, output_generate)
|
||||
|
||||
@@ -308,8 +306,8 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
|
||||
self.assertIsInstance(output_greedy, GreedySearchDecoderOnlyOutput)
|
||||
self.assertIsInstance(output_generate, GreedySearchDecoderOnlyOutput)
|
||||
self.assertIsInstance(output_greedy, GenerateDecoderOnlyOutput)
|
||||
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
|
||||
|
||||
# override since we don't expect the outputs of `.generate` and `.sample` to be the same, since we perform
|
||||
# additional post-processing in the former
|
||||
@@ -376,8 +374,8 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
|
||||
self.assertIsInstance(output_sample, SampleDecoderOnlyOutput)
|
||||
self.assertIsInstance(output_generate, SampleDecoderOnlyOutput)
|
||||
self.assertIsInstance(output_sample, GenerateDecoderOnlyOutput)
|
||||
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
|
||||
|
||||
def test_greedy_generate_stereo_outputs(self):
|
||||
for model_class in self.greedy_sample_model_classes:
|
||||
@@ -395,8 +393,8 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
|
||||
self.assertIsInstance(output_greedy, GreedySearchDecoderOnlyOutput)
|
||||
self.assertIsInstance(output_generate, GreedySearchDecoderOnlyOutput)
|
||||
self.assertIsInstance(output_greedy, GenerateDecoderOnlyOutput)
|
||||
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
|
||||
|
||||
self.assertNotIn(config.pad_token_id, output_generate)
|
||||
|
||||
@@ -1001,8 +999,8 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
|
||||
self.assertIsInstance(output_greedy, GreedySearchEncoderDecoderOutput)
|
||||
self.assertIsInstance(output_generate, GreedySearchEncoderDecoderOutput)
|
||||
self.assertIsInstance(output_greedy, GenerateEncoderDecoderOutput)
|
||||
self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
|
||||
|
||||
self.assertNotIn(config.pad_token_id, output_generate)
|
||||
|
||||
@@ -1026,8 +1024,8 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
|
||||
self.assertIsInstance(output_greedy, GreedySearchEncoderDecoderOutput)
|
||||
self.assertIsInstance(output_generate, GreedySearchEncoderDecoderOutput)
|
||||
self.assertIsInstance(output_greedy, GenerateEncoderDecoderOutput)
|
||||
self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
|
||||
|
||||
def test_sample_generate(self):
|
||||
for model_class in self.greedy_sample_model_classes:
|
||||
@@ -1092,8 +1090,8 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
|
||||
self.assertIsInstance(output_sample, SampleEncoderDecoderOutput)
|
||||
self.assertIsInstance(output_generate, SampleEncoderDecoderOutput)
|
||||
self.assertIsInstance(output_sample, GenerateEncoderDecoderOutput)
|
||||
self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
|
||||
|
||||
def test_generate_without_input_ids(self):
|
||||
config, _, _, _, max_length = self._get_input_ids_and_config()
|
||||
@@ -1141,8 +1139,8 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
|
||||
self.assertIsInstance(output_greedy, GreedySearchEncoderDecoderOutput)
|
||||
self.assertIsInstance(output_generate, GreedySearchEncoderDecoderOutput)
|
||||
self.assertIsInstance(output_greedy, GenerateEncoderDecoderOutput)
|
||||
self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
|
||||
|
||||
self.assertNotIn(config.pad_token_id, output_generate)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user