Generate can return cross-attention weights too (#10493)
This commit is contained in:
committed by
GitHub
parent
b013842244
commit
1750e62900
@@ -39,6 +39,8 @@ if is_torch_available():
|
||||
TopPLogitsWarper,
|
||||
)
|
||||
from transformers.generation_utils import (
|
||||
BeamSampleDecoderOnlyOutput,
|
||||
BeamSampleEncoderDecoderOutput,
|
||||
BeamSearchDecoderOnlyOutput,
|
||||
BeamSearchEncoderDecoderOutput,
|
||||
GreedySearchDecoderOnlyOutput,
|
||||
@@ -900,11 +902,11 @@ class GenerationTesterMixin:
|
||||
)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertIsInstance(output_beam_sample, BeamSearchEncoderDecoderOutput)
|
||||
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
|
||||
self.assertIsInstance(output_beam_sample, BeamSampleEncoderDecoderOutput)
|
||||
self.assertIsInstance(output_generate, BeamSampleEncoderDecoderOutput)
|
||||
else:
|
||||
self.assertIsInstance(output_beam_sample, BeamSearchDecoderOnlyOutput)
|
||||
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
|
||||
self.assertIsInstance(output_beam_sample, BeamSampleDecoderOnlyOutput)
|
||||
self.assertIsInstance(output_generate, BeamSampleDecoderOnlyOutput)
|
||||
|
||||
self.assertListEqual(output_generate.sequences.tolist(), output_beam_sample.sequences.tolist())
|
||||
self.assertTrue(
|
||||
|
||||
Reference in New Issue
Block a user