Generate: consolidate output classes (#28494)

This commit is contained in:
Joao Gante
2024-01-15 17:04:08 +00:00
committed by GitHub
parent 72db39c065
commit 7e0ddf89f4
12 changed files with 177 additions and 459 deletions

View File

@@ -65,6 +65,10 @@ if is_torch_available():
DisjunctiveConstraint,
ForcedBOSTokenLogitsProcessor,
ForcedEOSTokenLogitsProcessor,
GenerateBeamDecoderOnlyOutput,
GenerateBeamEncoderDecoderOutput,
GenerateDecoderOnlyOutput,
GenerateEncoderDecoderOutput,
GreedySearchDecoderOnlyOutput,
GreedySearchEncoderDecoderOutput,
HammingDiversityLogitsProcessor,
@@ -730,9 +734,15 @@ class GenerationTesterMixin:
)
if model.config.is_encoder_decoder:
self.assertIsInstance(output_greedy, GenerateEncoderDecoderOutput)
self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
# Retrocompatibility check
self.assertIsInstance(output_greedy, GreedySearchEncoderDecoderOutput)
self.assertIsInstance(output_generate, GreedySearchEncoderDecoderOutput)
else:
self.assertIsInstance(output_greedy, GenerateDecoderOnlyOutput)
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
# Retrocompatibility check
self.assertIsInstance(output_greedy, GreedySearchDecoderOnlyOutput)
self.assertIsInstance(output_generate, GreedySearchDecoderOnlyOutput)
@@ -848,9 +858,15 @@ class GenerationTesterMixin:
)
if model.config.is_encoder_decoder:
self.assertIsInstance(output_sample, GenerateEncoderDecoderOutput)
self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
# Retrocompatibility check
self.assertIsInstance(output_sample, SampleEncoderDecoderOutput)
self.assertIsInstance(output_generate, SampleEncoderDecoderOutput)
else:
self.assertIsInstance(output_sample, GenerateDecoderOnlyOutput)
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
# Retrocompatibility check
self.assertIsInstance(output_sample, SampleDecoderOnlyOutput)
self.assertIsInstance(output_generate, SampleDecoderOnlyOutput)
@@ -952,9 +968,15 @@ class GenerationTesterMixin:
return_dict_in_generate=True,
)
if model.config.is_encoder_decoder:
self.assertIsInstance(output_beam_search, GenerateBeamEncoderDecoderOutput)
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
# Retrocompatibility check
self.assertIsInstance(output_beam_search, BeamSearchEncoderDecoderOutput)
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
else:
self.assertIsInstance(output_beam_search, GenerateBeamDecoderOnlyOutput)
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
# Retrocompatibility check
self.assertIsInstance(output_beam_search, BeamSearchDecoderOnlyOutput)
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
@@ -1109,9 +1131,15 @@ class GenerationTesterMixin:
)
if model.config.is_encoder_decoder:
self.assertIsInstance(output_beam_sample, GenerateBeamEncoderDecoderOutput)
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
# Retrocompatibility check
self.assertIsInstance(output_beam_sample, BeamSampleEncoderDecoderOutput)
self.assertIsInstance(output_generate, BeamSampleEncoderDecoderOutput)
else:
self.assertIsInstance(output_beam_sample, GenerateBeamDecoderOnlyOutput)
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
# Retrocompatibility check
self.assertIsInstance(output_beam_sample, BeamSampleDecoderOnlyOutput)
self.assertIsInstance(output_generate, BeamSampleDecoderOnlyOutput)
@@ -1238,9 +1266,15 @@ class GenerationTesterMixin:
return_dict_in_generate=True,
)
if model.config.is_encoder_decoder:
self.assertIsInstance(output_group_beam_search, GenerateBeamEncoderDecoderOutput)
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
# Retrocompatibility check
self.assertIsInstance(output_group_beam_search, BeamSearchEncoderDecoderOutput)
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
else:
self.assertIsInstance(output_group_beam_search, GenerateBeamDecoderOnlyOutput)
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
# Retrocompatibility check
self.assertIsInstance(output_group_beam_search, BeamSearchDecoderOnlyOutput)
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
@@ -1390,9 +1424,15 @@ class GenerationTesterMixin:
)
if model.config.is_encoder_decoder:
self.assertIsInstance(output_beam_search, GenerateBeamEncoderDecoderOutput)
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
# Retrocompatibility check
self.assertIsInstance(output_beam_search, BeamSearchEncoderDecoderOutput)
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
else:
self.assertIsInstance(output_beam_search, GenerateBeamDecoderOnlyOutput)
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
# Retrocompatibility check
self.assertIsInstance(output_beam_search, BeamSearchDecoderOnlyOutput)
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)

View File

@@ -53,12 +53,10 @@ if is_torch_available():
set_seed,
)
from transformers.generation import (
GreedySearchDecoderOnlyOutput,
GreedySearchEncoderDecoderOutput,
GenerateDecoderOnlyOutput,
GenerateEncoderDecoderOutput,
InfNanRemoveLogitsProcessor,
LogitsProcessorList,
SampleDecoderOnlyOutput,
SampleEncoderDecoderOutput,
)
@@ -282,8 +280,8 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
return_dict_in_generate=True,
)
self.assertIsInstance(output_greedy, GreedySearchDecoderOnlyOutput)
self.assertIsInstance(output_generate, GreedySearchDecoderOnlyOutput)
self.assertIsInstance(output_greedy, GenerateDecoderOnlyOutput)
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
self.assertNotIn(config.pad_token_id, output_generate)
@@ -308,8 +306,8 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
return_dict_in_generate=True,
)
self.assertIsInstance(output_greedy, GreedySearchDecoderOnlyOutput)
self.assertIsInstance(output_generate, GreedySearchDecoderOnlyOutput)
self.assertIsInstance(output_greedy, GenerateDecoderOnlyOutput)
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
# override since we don't expect the outputs of `.generate` and `.sample` to be the same, since we perform
# additional post-processing in the former
@@ -376,8 +374,8 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
return_dict_in_generate=True,
)
self.assertIsInstance(output_sample, SampleDecoderOnlyOutput)
self.assertIsInstance(output_generate, SampleDecoderOnlyOutput)
self.assertIsInstance(output_sample, GenerateDecoderOnlyOutput)
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
def test_greedy_generate_stereo_outputs(self):
for model_class in self.greedy_sample_model_classes:
@@ -395,8 +393,8 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
return_dict_in_generate=True,
)
self.assertIsInstance(output_greedy, GreedySearchDecoderOnlyOutput)
self.assertIsInstance(output_generate, GreedySearchDecoderOnlyOutput)
self.assertIsInstance(output_greedy, GenerateDecoderOnlyOutput)
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
self.assertNotIn(config.pad_token_id, output_generate)
@@ -1001,8 +999,8 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
return_dict_in_generate=True,
)
self.assertIsInstance(output_greedy, GreedySearchEncoderDecoderOutput)
self.assertIsInstance(output_generate, GreedySearchEncoderDecoderOutput)
self.assertIsInstance(output_greedy, GenerateEncoderDecoderOutput)
self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
self.assertNotIn(config.pad_token_id, output_generate)
@@ -1026,8 +1024,8 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
return_dict_in_generate=True,
)
self.assertIsInstance(output_greedy, GreedySearchEncoderDecoderOutput)
self.assertIsInstance(output_generate, GreedySearchEncoderDecoderOutput)
self.assertIsInstance(output_greedy, GenerateEncoderDecoderOutput)
self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
def test_sample_generate(self):
for model_class in self.greedy_sample_model_classes:
@@ -1092,8 +1090,8 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
return_dict_in_generate=True,
)
self.assertIsInstance(output_sample, SampleEncoderDecoderOutput)
self.assertIsInstance(output_generate, SampleEncoderDecoderOutput)
self.assertIsInstance(output_sample, GenerateEncoderDecoderOutput)
self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
def test_generate_without_input_ids(self):
config, _, _, _, max_length = self._get_input_ids_and_config()
@@ -1141,8 +1139,8 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
return_dict_in_generate=True,
)
self.assertIsInstance(output_greedy, GreedySearchEncoderDecoderOutput)
self.assertIsInstance(output_generate, GreedySearchEncoderDecoderOutput)
self.assertIsInstance(output_greedy, GenerateEncoderDecoderOutput)
self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
self.assertNotIn(config.pad_token_id, output_generate)