Generate: consolidate output classes (#28494)

This commit is contained in:
Joao Gante
2024-01-15 17:04:08 +00:00
committed by GitHub
parent 72db39c065
commit 7e0ddf89f4
12 changed files with 177 additions and 459 deletions

View File

@@ -65,6 +65,10 @@ if is_torch_available():
DisjunctiveConstraint,
ForcedBOSTokenLogitsProcessor,
ForcedEOSTokenLogitsProcessor,
GenerateBeamDecoderOnlyOutput,
GenerateBeamEncoderDecoderOutput,
GenerateDecoderOnlyOutput,
GenerateEncoderDecoderOutput,
GreedySearchDecoderOnlyOutput,
GreedySearchEncoderDecoderOutput,
HammingDiversityLogitsProcessor,
@@ -730,9 +734,15 @@ class GenerationTesterMixin:
)
if model.config.is_encoder_decoder:
self.assertIsInstance(output_greedy, GenerateEncoderDecoderOutput)
self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
# Retrocompatibility check
self.assertIsInstance(output_greedy, GreedySearchEncoderDecoderOutput)
self.assertIsInstance(output_generate, GreedySearchEncoderDecoderOutput)
else:
self.assertIsInstance(output_greedy, GenerateDecoderOnlyOutput)
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
# Retrocompatibility check
self.assertIsInstance(output_greedy, GreedySearchDecoderOnlyOutput)
self.assertIsInstance(output_generate, GreedySearchDecoderOnlyOutput)
@@ -848,9 +858,15 @@ class GenerationTesterMixin:
)
if model.config.is_encoder_decoder:
self.assertIsInstance(output_sample, GenerateEncoderDecoderOutput)
self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
# Retrocompatibility check
self.assertIsInstance(output_sample, SampleEncoderDecoderOutput)
self.assertIsInstance(output_generate, SampleEncoderDecoderOutput)
else:
self.assertIsInstance(output_sample, GenerateDecoderOnlyOutput)
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
# Retrocompatibility check
self.assertIsInstance(output_sample, SampleDecoderOnlyOutput)
self.assertIsInstance(output_generate, SampleDecoderOnlyOutput)
@@ -952,9 +968,15 @@ class GenerationTesterMixin:
return_dict_in_generate=True,
)
if model.config.is_encoder_decoder:
self.assertIsInstance(output_beam_search, GenerateBeamEncoderDecoderOutput)
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
# Retrocompatibility check
self.assertIsInstance(output_beam_search, BeamSearchEncoderDecoderOutput)
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
else:
self.assertIsInstance(output_beam_search, GenerateBeamDecoderOnlyOutput)
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
# Retrocompatibility check
self.assertIsInstance(output_beam_search, BeamSearchDecoderOnlyOutput)
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
@@ -1109,9 +1131,15 @@ class GenerationTesterMixin:
)
if model.config.is_encoder_decoder:
self.assertIsInstance(output_beam_sample, GenerateBeamEncoderDecoderOutput)
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
# Retrocompatibility check
self.assertIsInstance(output_beam_sample, BeamSampleEncoderDecoderOutput)
self.assertIsInstance(output_generate, BeamSampleEncoderDecoderOutput)
else:
self.assertIsInstance(output_beam_sample, GenerateBeamDecoderOnlyOutput)
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
# Retrocompatibility check
self.assertIsInstance(output_beam_sample, BeamSampleDecoderOnlyOutput)
self.assertIsInstance(output_generate, BeamSampleDecoderOnlyOutput)
@@ -1238,9 +1266,15 @@ class GenerationTesterMixin:
return_dict_in_generate=True,
)
if model.config.is_encoder_decoder:
self.assertIsInstance(output_group_beam_search, GenerateBeamEncoderDecoderOutput)
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
# Retrocompatibility check
self.assertIsInstance(output_group_beam_search, BeamSearchEncoderDecoderOutput)
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
else:
self.assertIsInstance(output_group_beam_search, GenerateBeamDecoderOnlyOutput)
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
# Retrocompatibility check
self.assertIsInstance(output_group_beam_search, BeamSearchDecoderOnlyOutput)
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
@@ -1390,9 +1424,15 @@ class GenerationTesterMixin:
)
if model.config.is_encoder_decoder:
self.assertIsInstance(output_beam_search, GenerateBeamEncoderDecoderOutput)
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
# Retrocompatibility check
self.assertIsInstance(output_beam_search, BeamSearchEncoderDecoderOutput)
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
else:
self.assertIsInstance(output_beam_search, GenerateBeamDecoderOnlyOutput)
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
# Retrocompatibility check
self.assertIsInstance(output_beam_search, BeamSearchDecoderOnlyOutput)
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)