Generate can return cross-attention weights too (#10493)
This commit is contained in:
committed by
GitHub
parent
b013842244
commit
1750e62900
@@ -96,6 +96,9 @@ class GreedySearchEncoderDecoderOutput(ModelOutput):
|
|||||||
decoder_attentions (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
|
decoder_attentions (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
|
||||||
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
||||||
:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_heads, generated_length, sequence_length)`.
|
:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_heads, generated_length, sequence_length)`.
|
||||||
|
cross_attentions (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
|
||||||
|
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
||||||
|
:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_heads, generated_length, sequence_length)`.
|
||||||
decoder_hidden_states (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
decoder_hidden_states (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||||
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
||||||
:obj:`torch.FloatTensor` of shape :obj:`(batch_size, generated_length, hidden_size)`.
|
:obj:`torch.FloatTensor` of shape :obj:`(batch_size, generated_length, hidden_size)`.
|
||||||
@@ -106,6 +109,7 @@ class GreedySearchEncoderDecoderOutput(ModelOutput):
|
|||||||
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||||
|
cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||||
decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||||
|
|
||||||
|
|
||||||
@@ -164,6 +168,9 @@ class SampleEncoderDecoderOutput(ModelOutput):
|
|||||||
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
||||||
:obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_return_sequences, num_heads, generated_length,
|
:obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_return_sequences, num_heads, generated_length,
|
||||||
sequence_length)`.
|
sequence_length)`.
|
||||||
|
cross_attentions (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
|
||||||
|
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
||||||
|
:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_heads, generated_length, sequence_length)`.
|
||||||
decoder_hidden_states (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
decoder_hidden_states (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||||
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
||||||
:obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_return_sequences, generated_length, hidden_size)`.
|
:obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_return_sequences, generated_length, hidden_size)`.
|
||||||
@@ -174,6 +181,7 @@ class SampleEncoderDecoderOutput(ModelOutput):
|
|||||||
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||||
|
cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||||
decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||||
|
|
||||||
|
|
||||||
@@ -239,6 +247,9 @@ class BeamSearchEncoderDecoderOutput(ModelOutput):
|
|||||||
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
||||||
:obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_beams*num_return_sequences, num_heads,
|
:obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_beams*num_return_sequences, num_heads,
|
||||||
generated_length, sequence_length)`.
|
generated_length, sequence_length)`.
|
||||||
|
cross_attentions (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
|
||||||
|
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
||||||
|
:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_heads, generated_length, sequence_length)`.
|
||||||
decoder_hidden_states (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
decoder_hidden_states (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||||
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
||||||
:obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_beams*num_return_sequences, generated_length,
|
:obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_beams*num_return_sequences, generated_length,
|
||||||
@@ -251,6 +262,7 @@ class BeamSearchEncoderDecoderOutput(ModelOutput):
|
|||||||
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||||
|
cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||||
decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||||
|
|
||||||
|
|
||||||
@@ -314,6 +326,9 @@ class BeamSampleEncoderDecoderOutput(ModelOutput):
|
|||||||
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
||||||
:obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_beams, num_heads, generated_length,
|
:obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_beams, num_heads, generated_length,
|
||||||
sequence_length)`.
|
sequence_length)`.
|
||||||
|
cross_attentions (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
|
||||||
|
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
||||||
|
:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_heads, generated_length, sequence_length)`.
|
||||||
decoder_hidden_states (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
decoder_hidden_states (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||||
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
||||||
:obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_beams, generated_length, hidden_size)`.
|
:obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_beams, generated_length, hidden_size)`.
|
||||||
@@ -325,6 +340,7 @@ class BeamSampleEncoderDecoderOutput(ModelOutput):
|
|||||||
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||||
|
cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||||
decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||||
|
|
||||||
|
|
||||||
@@ -1177,6 +1193,7 @@ class GenerationMixin:
|
|||||||
# init attention / hidden states / scores tuples
|
# init attention / hidden states / scores tuples
|
||||||
scores = () if (return_dict_in_generate and output_scores) else None
|
scores = () if (return_dict_in_generate and output_scores) else None
|
||||||
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
|
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
|
||||||
|
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
|
||||||
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
|
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
|
||||||
|
|
||||||
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
|
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
|
||||||
@@ -1212,6 +1229,8 @@ class GenerationMixin:
|
|||||||
decoder_attentions += (
|
decoder_attentions += (
|
||||||
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
|
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
|
||||||
)
|
)
|
||||||
|
if self.config.is_encoder_decoder:
|
||||||
|
cross_attentions += (outputs.cross_attentions,)
|
||||||
|
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
decoder_hidden_states += (
|
decoder_hidden_states += (
|
||||||
@@ -1260,6 +1279,7 @@ class GenerationMixin:
|
|||||||
encoder_attentions=encoder_attentions,
|
encoder_attentions=encoder_attentions,
|
||||||
encoder_hidden_states=encoder_hidden_states,
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
decoder_attentions=decoder_attentions,
|
decoder_attentions=decoder_attentions,
|
||||||
|
cross_attentions=cross_attentions,
|
||||||
decoder_hidden_states=decoder_hidden_states,
|
decoder_hidden_states=decoder_hidden_states,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -1384,6 +1404,7 @@ class GenerationMixin:
|
|||||||
# init attention / hidden states / scores tuples
|
# init attention / hidden states / scores tuples
|
||||||
scores = () if (return_dict_in_generate and output_scores) else None
|
scores = () if (return_dict_in_generate and output_scores) else None
|
||||||
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
|
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
|
||||||
|
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
|
||||||
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
|
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
|
||||||
|
|
||||||
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
|
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
|
||||||
@@ -1424,6 +1445,8 @@ class GenerationMixin:
|
|||||||
decoder_attentions += (
|
decoder_attentions += (
|
||||||
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
|
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
|
||||||
)
|
)
|
||||||
|
if self.config.is_encoder_decoder:
|
||||||
|
cross_attentions += (outputs.cross_attentions,)
|
||||||
|
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
decoder_hidden_states += (
|
decoder_hidden_states += (
|
||||||
@@ -1468,6 +1491,7 @@ class GenerationMixin:
|
|||||||
encoder_attentions=encoder_attentions,
|
encoder_attentions=encoder_attentions,
|
||||||
encoder_hidden_states=encoder_hidden_states,
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
decoder_attentions=decoder_attentions,
|
decoder_attentions=decoder_attentions,
|
||||||
|
cross_attentions=cross_attentions,
|
||||||
decoder_hidden_states=decoder_hidden_states,
|
decoder_hidden_states=decoder_hidden_states,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -1604,6 +1628,7 @@ class GenerationMixin:
|
|||||||
# init attention / hidden states / scores tuples
|
# init attention / hidden states / scores tuples
|
||||||
scores = () if (return_dict_in_generate and output_scores) else None
|
scores = () if (return_dict_in_generate and output_scores) else None
|
||||||
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
|
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
|
||||||
|
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
|
||||||
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
|
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
|
||||||
|
|
||||||
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
|
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
|
||||||
@@ -1656,6 +1681,8 @@ class GenerationMixin:
|
|||||||
decoder_attentions += (
|
decoder_attentions += (
|
||||||
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
|
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
|
||||||
)
|
)
|
||||||
|
if self.config.is_encoder_decoder:
|
||||||
|
cross_attentions += (outputs.cross_attentions,)
|
||||||
|
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
decoder_hidden_states += (
|
decoder_hidden_states += (
|
||||||
@@ -1716,6 +1743,7 @@ class GenerationMixin:
|
|||||||
encoder_attentions=encoder_attentions,
|
encoder_attentions=encoder_attentions,
|
||||||
encoder_hidden_states=encoder_hidden_states,
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
decoder_attentions=decoder_attentions,
|
decoder_attentions=decoder_attentions,
|
||||||
|
cross_attentions=cross_attentions,
|
||||||
decoder_hidden_states=decoder_hidden_states,
|
decoder_hidden_states=decoder_hidden_states,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -1865,6 +1893,7 @@ class GenerationMixin:
|
|||||||
# init attention / hidden states / scores tuples
|
# init attention / hidden states / scores tuples
|
||||||
scores = () if (return_dict_in_generate and output_scores) else None
|
scores = () if (return_dict_in_generate and output_scores) else None
|
||||||
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
|
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
|
||||||
|
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
|
||||||
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
|
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
|
||||||
|
|
||||||
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
|
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
|
||||||
@@ -1913,6 +1942,8 @@ class GenerationMixin:
|
|||||||
decoder_attentions += (
|
decoder_attentions += (
|
||||||
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
|
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
|
||||||
)
|
)
|
||||||
|
if self.config.is_encoder_decoder:
|
||||||
|
cross_attentions += (outputs.cross_attentions,)
|
||||||
|
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
decoder_hidden_states += (
|
decoder_hidden_states += (
|
||||||
@@ -1968,17 +1999,18 @@ class GenerationMixin:
|
|||||||
if not output_scores:
|
if not output_scores:
|
||||||
sequence_outputs["sequence_scores"] = None
|
sequence_outputs["sequence_scores"] = None
|
||||||
if self.config.is_encoder_decoder:
|
if self.config.is_encoder_decoder:
|
||||||
return BeamSearchEncoderDecoderOutput(
|
return BeamSampleEncoderDecoderOutput(
|
||||||
sequences=sequence_outputs["sequences"],
|
sequences=sequence_outputs["sequences"],
|
||||||
sequences_scores=sequence_outputs["sequence_scores"],
|
sequences_scores=sequence_outputs["sequence_scores"],
|
||||||
scores=scores,
|
scores=scores,
|
||||||
encoder_attentions=encoder_attentions,
|
encoder_attentions=encoder_attentions,
|
||||||
encoder_hidden_states=encoder_hidden_states,
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
decoder_attentions=decoder_attentions,
|
decoder_attentions=decoder_attentions,
|
||||||
|
cross_attentions=cross_attentions,
|
||||||
decoder_hidden_states=decoder_hidden_states,
|
decoder_hidden_states=decoder_hidden_states,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return BeamSearchDecoderOnlyOutput(
|
return BeamSampleDecoderOnlyOutput(
|
||||||
sequences=sequence_outputs["sequences"],
|
sequences=sequence_outputs["sequences"],
|
||||||
sequences_scores=sequence_outputs["sequence_scores"],
|
sequences_scores=sequence_outputs["sequence_scores"],
|
||||||
scores=scores,
|
scores=scores,
|
||||||
@@ -2115,6 +2147,7 @@ class GenerationMixin:
|
|||||||
# init attention / hidden states / scores tuples
|
# init attention / hidden states / scores tuples
|
||||||
scores = () if (return_dict_in_generate and output_scores) else None
|
scores = () if (return_dict_in_generate and output_scores) else None
|
||||||
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
|
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
|
||||||
|
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
|
||||||
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
|
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
|
||||||
|
|
||||||
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
|
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
|
||||||
@@ -2238,6 +2271,8 @@ class GenerationMixin:
|
|||||||
decoder_attentions += (
|
decoder_attentions += (
|
||||||
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
|
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
|
||||||
)
|
)
|
||||||
|
if self.config.is_encoder_decoder:
|
||||||
|
cross_attentions += (outputs.cross_attentions,)
|
||||||
|
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
decoder_hidden_states += (
|
decoder_hidden_states += (
|
||||||
@@ -2263,7 +2298,7 @@ class GenerationMixin:
|
|||||||
|
|
||||||
if return_dict_in_generate:
|
if return_dict_in_generate:
|
||||||
if not output_scores:
|
if not output_scores:
|
||||||
sequence_outputs["sequence_scores"]
|
sequence_outputs["sequence_scores"] = None
|
||||||
if self.config.is_encoder_decoder:
|
if self.config.is_encoder_decoder:
|
||||||
return BeamSearchEncoderDecoderOutput(
|
return BeamSearchEncoderDecoderOutput(
|
||||||
sequences=sequence_outputs["sequences"],
|
sequences=sequence_outputs["sequences"],
|
||||||
@@ -2272,6 +2307,7 @@ class GenerationMixin:
|
|||||||
encoder_attentions=encoder_attentions,
|
encoder_attentions=encoder_attentions,
|
||||||
encoder_hidden_states=encoder_hidden_states,
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
decoder_attentions=decoder_attentions,
|
decoder_attentions=decoder_attentions,
|
||||||
|
cross_attentions=cross_attentions,
|
||||||
decoder_hidden_states=decoder_hidden_states,
|
decoder_hidden_states=decoder_hidden_states,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -39,6 +39,8 @@ if is_torch_available():
|
|||||||
TopPLogitsWarper,
|
TopPLogitsWarper,
|
||||||
)
|
)
|
||||||
from transformers.generation_utils import (
|
from transformers.generation_utils import (
|
||||||
|
BeamSampleDecoderOnlyOutput,
|
||||||
|
BeamSampleEncoderDecoderOutput,
|
||||||
BeamSearchDecoderOnlyOutput,
|
BeamSearchDecoderOnlyOutput,
|
||||||
BeamSearchEncoderDecoderOutput,
|
BeamSearchEncoderDecoderOutput,
|
||||||
GreedySearchDecoderOnlyOutput,
|
GreedySearchDecoderOnlyOutput,
|
||||||
@@ -900,11 +902,11 @@ class GenerationTesterMixin:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if model.config.is_encoder_decoder:
|
if model.config.is_encoder_decoder:
|
||||||
self.assertIsInstance(output_beam_sample, BeamSearchEncoderDecoderOutput)
|
self.assertIsInstance(output_beam_sample, BeamSampleEncoderDecoderOutput)
|
||||||
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
|
self.assertIsInstance(output_generate, BeamSampleEncoderDecoderOutput)
|
||||||
else:
|
else:
|
||||||
self.assertIsInstance(output_beam_sample, BeamSearchDecoderOnlyOutput)
|
self.assertIsInstance(output_beam_sample, BeamSampleDecoderOnlyOutput)
|
||||||
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
|
self.assertIsInstance(output_generate, BeamSampleDecoderOnlyOutput)
|
||||||
|
|
||||||
self.assertListEqual(output_generate.sequences.tolist(), output_beam_sample.sequences.tolist())
|
self.assertListEqual(output_generate.sequences.tolist(), output_beam_sample.sequences.tolist())
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
|
|||||||
Reference in New Issue
Block a user