Generate: consolidate output classes (#28494)
This commit is contained in:
@@ -65,6 +65,10 @@ if is_torch_available():
|
||||
DisjunctiveConstraint,
|
||||
ForcedBOSTokenLogitsProcessor,
|
||||
ForcedEOSTokenLogitsProcessor,
|
||||
GenerateBeamDecoderOnlyOutput,
|
||||
GenerateBeamEncoderDecoderOutput,
|
||||
GenerateDecoderOnlyOutput,
|
||||
GenerateEncoderDecoderOutput,
|
||||
GreedySearchDecoderOnlyOutput,
|
||||
GreedySearchEncoderDecoderOutput,
|
||||
HammingDiversityLogitsProcessor,
|
||||
@@ -730,9 +734,15 @@ class GenerationTesterMixin:
|
||||
)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertIsInstance(output_greedy, GenerateEncoderDecoderOutput)
|
||||
self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
|
||||
# Retrocompatibility check
|
||||
self.assertIsInstance(output_greedy, GreedySearchEncoderDecoderOutput)
|
||||
self.assertIsInstance(output_generate, GreedySearchEncoderDecoderOutput)
|
||||
else:
|
||||
self.assertIsInstance(output_greedy, GenerateDecoderOnlyOutput)
|
||||
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
|
||||
# Retrocompatibility check
|
||||
self.assertIsInstance(output_greedy, GreedySearchDecoderOnlyOutput)
|
||||
self.assertIsInstance(output_generate, GreedySearchDecoderOnlyOutput)
|
||||
|
||||
@@ -848,9 +858,15 @@ class GenerationTesterMixin:
|
||||
)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertIsInstance(output_sample, GenerateEncoderDecoderOutput)
|
||||
self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
|
||||
# Retrocompatibility check
|
||||
self.assertIsInstance(output_sample, SampleEncoderDecoderOutput)
|
||||
self.assertIsInstance(output_generate, SampleEncoderDecoderOutput)
|
||||
else:
|
||||
self.assertIsInstance(output_sample, GenerateDecoderOnlyOutput)
|
||||
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
|
||||
# Retrocompatibility check
|
||||
self.assertIsInstance(output_sample, SampleDecoderOnlyOutput)
|
||||
self.assertIsInstance(output_generate, SampleDecoderOnlyOutput)
|
||||
|
||||
@@ -952,9 +968,15 @@ class GenerationTesterMixin:
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertIsInstance(output_beam_search, GenerateBeamEncoderDecoderOutput)
|
||||
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
|
||||
# Retrocompatibility check
|
||||
self.assertIsInstance(output_beam_search, BeamSearchEncoderDecoderOutput)
|
||||
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
|
||||
else:
|
||||
self.assertIsInstance(output_beam_search, GenerateBeamDecoderOnlyOutput)
|
||||
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
|
||||
# Retrocompatibility check
|
||||
self.assertIsInstance(output_beam_search, BeamSearchDecoderOnlyOutput)
|
||||
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
|
||||
|
||||
@@ -1109,9 +1131,15 @@ class GenerationTesterMixin:
|
||||
)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertIsInstance(output_beam_sample, GenerateBeamEncoderDecoderOutput)
|
||||
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
|
||||
# Retrocompatibility check
|
||||
self.assertIsInstance(output_beam_sample, BeamSampleEncoderDecoderOutput)
|
||||
self.assertIsInstance(output_generate, BeamSampleEncoderDecoderOutput)
|
||||
else:
|
||||
self.assertIsInstance(output_beam_sample, GenerateBeamDecoderOnlyOutput)
|
||||
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
|
||||
# Retrocompatibility check
|
||||
self.assertIsInstance(output_beam_sample, BeamSampleDecoderOnlyOutput)
|
||||
self.assertIsInstance(output_generate, BeamSampleDecoderOnlyOutput)
|
||||
|
||||
@@ -1238,9 +1266,15 @@ class GenerationTesterMixin:
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertIsInstance(output_group_beam_search, GenerateBeamEncoderDecoderOutput)
|
||||
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
|
||||
# Retrocompatibility check
|
||||
self.assertIsInstance(output_group_beam_search, BeamSearchEncoderDecoderOutput)
|
||||
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
|
||||
else:
|
||||
self.assertIsInstance(output_group_beam_search, GenerateBeamDecoderOnlyOutput)
|
||||
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
|
||||
# Retrocompatibility check
|
||||
self.assertIsInstance(output_group_beam_search, BeamSearchDecoderOnlyOutput)
|
||||
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
|
||||
|
||||
@@ -1390,9 +1424,15 @@ class GenerationTesterMixin:
|
||||
)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertIsInstance(output_beam_search, GenerateBeamEncoderDecoderOutput)
|
||||
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
|
||||
# Retrocompatibility check
|
||||
self.assertIsInstance(output_beam_search, BeamSearchEncoderDecoderOutput)
|
||||
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
|
||||
else:
|
||||
self.assertIsInstance(output_beam_search, GenerateBeamDecoderOnlyOutput)
|
||||
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
|
||||
# Retrocompatibility check
|
||||
self.assertIsInstance(output_beam_search, BeamSearchDecoderOnlyOutput)
|
||||
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user