Generate can return cross-attention weights too (#10493)

This commit is contained in:
Mehrad Moradshahi
2021-03-03 00:27:02 -08:00
committed by GitHub
parent b013842244
commit 1750e62900
2 changed files with 45 additions and 7 deletions

View File

@@ -39,6 +39,8 @@ if is_torch_available():
TopPLogitsWarper,
)
from transformers.generation_utils import (
BeamSampleDecoderOnlyOutput,
BeamSampleEncoderDecoderOutput,
BeamSearchDecoderOnlyOutput,
BeamSearchEncoderDecoderOutput,
GreedySearchDecoderOnlyOutput,
@@ -900,11 +902,11 @@ class GenerationTesterMixin:
)
if model.config.is_encoder_decoder:
self.assertIsInstance(output_beam_sample, BeamSearchEncoderDecoderOutput)
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
self.assertIsInstance(output_beam_sample, BeamSampleEncoderDecoderOutput)
self.assertIsInstance(output_generate, BeamSampleEncoderDecoderOutput)
else:
self.assertIsInstance(output_beam_sample, BeamSearchDecoderOnlyOutput)
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
self.assertIsInstance(output_beam_sample, BeamSampleDecoderOnlyOutput)
self.assertIsInstance(output_generate, BeamSampleDecoderOnlyOutput)
self.assertListEqual(output_generate.sequences.tolist(), output_beam_sample.sequences.tolist())
self.assertTrue(