Generate: return past_key_values (#25086)
This commit is contained in:
@@ -104,12 +104,20 @@ class GreedySearchDecoderOnlyOutput(ModelOutput):
|
|||||||
hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
||||||
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
||||||
`torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
|
`torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
|
||||||
|
past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||||
|
NOTE: some models have a different `past_key_values` format, confirm with the model's documentation.
|
||||||
|
Usually a Tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value
|
||||||
|
tensor). The first Tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape
|
||||||
|
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
|
||||||
|
`config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
|
||||||
|
encoder_sequence_length, embed_size_per_head)`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
sequences: torch.LongTensor = None
|
sequences: torch.LongTensor = None
|
||||||
scores: Optional[Tuple[torch.FloatTensor]] = None
|
scores: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||||
hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||||
|
past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -140,6 +148,13 @@ class ContrastiveSearchEncoderDecoderOutput(ModelOutput):
|
|||||||
decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
||||||
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
||||||
`torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
|
`torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
|
||||||
|
past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||||
|
NOTE: some models have a different `past_key_values` format, confirm with the model's documentation.
|
||||||
|
Usually a Tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value
|
||||||
|
tensor). The first Tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape
|
||||||
|
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
|
||||||
|
`config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
|
||||||
|
encoder_sequence_length, embed_size_per_head)`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
sequences: torch.LongTensor = None
|
sequences: torch.LongTensor = None
|
||||||
@@ -149,6 +164,7 @@ class ContrastiveSearchEncoderDecoderOutput(ModelOutput):
|
|||||||
decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||||
cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||||
decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||||
|
past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -169,15 +185,23 @@ class ContrastiveSearchDecoderOnlyOutput(ModelOutput):
|
|||||||
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
||||||
`torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
|
`torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
|
||||||
hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is
|
hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is
|
||||||
passed or when `config.output_hidden_states=True`):
|
passed or when `config.output_hidden_states=True`): Tuple (one element for each generated token) of tuples
|
||||||
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
(one element for each layer of the decoder) of `torch.FloatTensor` of shape `(batch_size, generated_length,
|
||||||
`torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
|
hidden_size)`.
|
||||||
|
past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||||
|
NOTE: some models have a different `past_key_values` format, confirm with the model's documentation.
|
||||||
|
Usually a Tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value
|
||||||
|
tensor). The first Tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape
|
||||||
|
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
|
||||||
|
`config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
|
||||||
|
encoder_sequence_length, embed_size_per_head)`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
sequences: torch.LongTensor = None
|
sequences: torch.LongTensor = None
|
||||||
scores: Optional[Tuple[torch.FloatTensor]] = None
|
scores: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||||
hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||||
|
past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -211,6 +235,13 @@ class GreedySearchEncoderDecoderOutput(ModelOutput):
|
|||||||
decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
||||||
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
||||||
`torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
|
`torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
|
||||||
|
past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||||
|
NOTE: some models have a different `past_key_values` format, confirm with the model's documentation.
|
||||||
|
Usually a Tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value
|
||||||
|
tensor). The first Tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape
|
||||||
|
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
|
||||||
|
`config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
|
||||||
|
encoder_sequence_length, embed_size_per_head)`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
sequences: torch.LongTensor = None
|
sequences: torch.LongTensor = None
|
||||||
@@ -220,6 +251,7 @@ class GreedySearchEncoderDecoderOutput(ModelOutput):
|
|||||||
decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||||
cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||||
decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||||
|
past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -243,12 +275,20 @@ class SampleDecoderOnlyOutput(ModelOutput):
|
|||||||
hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
||||||
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
||||||
`torch.FloatTensor` of shape `(num_return_sequences*batch_size, generated_length, hidden_size)`.
|
`torch.FloatTensor` of shape `(num_return_sequences*batch_size, generated_length, hidden_size)`.
|
||||||
|
past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||||
|
NOTE: some models have a different `past_key_values` format, confirm with the model's documentation.
|
||||||
|
Usually a Tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value
|
||||||
|
tensor). The first Tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape
|
||||||
|
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
|
||||||
|
`config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
|
||||||
|
encoder_sequence_length, embed_size_per_head)`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
sequences: torch.LongTensor = None
|
sequences: torch.LongTensor = None
|
||||||
scores: Optional[Tuple[torch.FloatTensor]] = None
|
scores: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||||
hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||||
|
past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -283,6 +323,13 @@ class SampleEncoderDecoderOutput(ModelOutput):
|
|||||||
decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
||||||
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
||||||
`torch.FloatTensor` of shape `(batch_size*num_return_sequences, generated_length, hidden_size)`.
|
`torch.FloatTensor` of shape `(batch_size*num_return_sequences, generated_length, hidden_size)`.
|
||||||
|
past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||||
|
NOTE: some models have a different `past_key_values` format, confirm with the model's documentation.
|
||||||
|
Usually a Tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value
|
||||||
|
tensor). The first Tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape
|
||||||
|
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
|
||||||
|
`config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
|
||||||
|
encoder_sequence_length, embed_size_per_head)`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
sequences: torch.LongTensor = None
|
sequences: torch.LongTensor = None
|
||||||
@@ -292,6 +339,7 @@ class SampleEncoderDecoderOutput(ModelOutput):
|
|||||||
decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||||
cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||||
decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||||
|
past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -319,6 +367,13 @@ class BeamSearchDecoderOnlyOutput(ModelOutput):
|
|||||||
hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
||||||
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
||||||
`torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, generated_length, hidden_size)`.
|
`torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, generated_length, hidden_size)`.
|
||||||
|
past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||||
|
NOTE: some models have a different `past_key_values` format, confirm with the model's documentation.
|
||||||
|
Usually a Tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value
|
||||||
|
tensor). The first Tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape
|
||||||
|
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
|
||||||
|
`config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
|
||||||
|
encoder_sequence_length, embed_size_per_head)`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
sequences: torch.LongTensor = None
|
sequences: torch.LongTensor = None
|
||||||
@@ -327,6 +382,7 @@ class BeamSearchDecoderOnlyOutput(ModelOutput):
|
|||||||
beam_indices: Optional[torch.LongTensor] = None
|
beam_indices: Optional[torch.LongTensor] = None
|
||||||
attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||||
hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||||
|
past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -366,6 +422,13 @@ class BeamSearchEncoderDecoderOutput(ModelOutput):
|
|||||||
decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
||||||
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
||||||
`torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, generated_length, hidden_size)`.
|
`torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, generated_length, hidden_size)`.
|
||||||
|
past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||||
|
NOTE: some models have a different `past_key_values` format, confirm with the model's documentation.
|
||||||
|
Usually a Tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value
|
||||||
|
tensor). The first Tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape
|
||||||
|
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
|
||||||
|
`config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
|
||||||
|
encoder_sequence_length, embed_size_per_head)`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
sequences: torch.LongTensor = None
|
sequences: torch.LongTensor = None
|
||||||
@@ -377,6 +440,7 @@ class BeamSearchEncoderDecoderOutput(ModelOutput):
|
|||||||
decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||||
cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||||
decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||||
|
past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -404,6 +468,13 @@ class BeamSampleDecoderOnlyOutput(ModelOutput):
|
|||||||
hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
||||||
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
||||||
`torch.FloatTensor` of shape `(batch_size*num_beams, generated_length, hidden_size)`.
|
`torch.FloatTensor` of shape `(batch_size*num_beams, generated_length, hidden_size)`.
|
||||||
|
past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||||
|
NOTE: some models have a different `past_key_values` format, confirm with the model's documentation.
|
||||||
|
Usually a Tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value
|
||||||
|
tensor). The first Tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape
|
||||||
|
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
|
||||||
|
`config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
|
||||||
|
encoder_sequence_length, embed_size_per_head)`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
sequences: torch.LongTensor = None
|
sequences: torch.LongTensor = None
|
||||||
@@ -412,6 +483,7 @@ class BeamSampleDecoderOnlyOutput(ModelOutput):
|
|||||||
beam_indices: Optional[torch.LongTensor] = None
|
beam_indices: Optional[torch.LongTensor] = None
|
||||||
attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||||
hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||||
|
past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -450,6 +522,13 @@ class BeamSampleEncoderDecoderOutput(ModelOutput):
|
|||||||
decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
||||||
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
||||||
`torch.FloatTensor` of shape `(batch_size*num_beams, generated_length, hidden_size)`.
|
`torch.FloatTensor` of shape `(batch_size*num_beams, generated_length, hidden_size)`.
|
||||||
|
past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||||
|
NOTE: some models have a different `past_key_values` format, confirm with the model's documentation.
|
||||||
|
Usually a Tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value
|
||||||
|
tensor). The first Tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape
|
||||||
|
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
|
||||||
|
`config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
|
||||||
|
encoder_sequence_length, embed_size_per_head)`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
sequences: torch.LongTensor = None
|
sequences: torch.LongTensor = None
|
||||||
@@ -461,6 +540,7 @@ class BeamSampleEncoderDecoderOutput(ModelOutput):
|
|||||||
decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||||
cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||||
decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||||
|
past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None
|
||||||
|
|
||||||
|
|
||||||
GreedySearchOutput = Union[GreedySearchEncoderDecoderOutput, GreedySearchDecoderOnlyOutput]
|
GreedySearchOutput = Union[GreedySearchEncoderDecoderOutput, GreedySearchDecoderOnlyOutput]
|
||||||
@@ -2148,8 +2228,8 @@ class GenerationMixin:
|
|||||||
items.append(item.repeat_interleave(1, dim=0))
|
items.append(item.repeat_interleave(1, dim=0))
|
||||||
else:
|
else:
|
||||||
items.append(item.repeat_interleave(top_k, dim=0))
|
items.append(item.repeat_interleave(top_k, dim=0))
|
||||||
new_key_values.append(items)
|
new_key_values.append(tuple(items))
|
||||||
model_kwargs["past_key_values"] = new_key_values
|
model_kwargs["past_key_values"] = tuple(new_key_values)
|
||||||
|
|
||||||
if sequential:
|
if sequential:
|
||||||
all_outputs = {key: [] for key in outputs} # defined in first loop iteration
|
all_outputs = {key: [] for key in outputs} # defined in first loop iteration
|
||||||
@@ -2330,6 +2410,17 @@ class GenerationMixin:
|
|||||||
streamer.end()
|
streamer.end()
|
||||||
|
|
||||||
if return_dict_in_generate:
|
if return_dict_in_generate:
|
||||||
|
# Contrastive search works by forward looking at the next token, so we need to exclude it from
|
||||||
|
# `past_key_values` to be consistent with the other decoding methods
|
||||||
|
if model_kwargs.get("past_key_values") is not None:
|
||||||
|
past_key_values = []
|
||||||
|
for layer in model_kwargs["past_key_values"]:
|
||||||
|
layer_past_key_values = []
|
||||||
|
for item in layer:
|
||||||
|
layer_past_key_values.append(item[..., :-1, :])
|
||||||
|
past_key_values.append(tuple(layer_past_key_values))
|
||||||
|
model_kwargs["past_key_values"] = tuple(past_key_values)
|
||||||
|
|
||||||
if self.config.is_encoder_decoder:
|
if self.config.is_encoder_decoder:
|
||||||
return ContrastiveSearchEncoderDecoderOutput(
|
return ContrastiveSearchEncoderDecoderOutput(
|
||||||
sequences=input_ids,
|
sequences=input_ids,
|
||||||
@@ -2339,6 +2430,7 @@ class GenerationMixin:
|
|||||||
decoder_attentions=decoder_attentions,
|
decoder_attentions=decoder_attentions,
|
||||||
cross_attentions=cross_attentions,
|
cross_attentions=cross_attentions,
|
||||||
decoder_hidden_states=decoder_hidden_states,
|
decoder_hidden_states=decoder_hidden_states,
|
||||||
|
past_key_values=model_kwargs.get("past_key_values"),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return ContrastiveSearchDecoderOnlyOutput(
|
return ContrastiveSearchDecoderOnlyOutput(
|
||||||
@@ -2346,6 +2438,7 @@ class GenerationMixin:
|
|||||||
scores=scores,
|
scores=scores,
|
||||||
attentions=decoder_attentions,
|
attentions=decoder_attentions,
|
||||||
hidden_states=decoder_hidden_states,
|
hidden_states=decoder_hidden_states,
|
||||||
|
past_key_values=model_kwargs.get("past_key_values"),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return input_ids
|
return input_ids
|
||||||
@@ -2598,6 +2691,7 @@ class GenerationMixin:
|
|||||||
decoder_attentions=decoder_attentions,
|
decoder_attentions=decoder_attentions,
|
||||||
cross_attentions=cross_attentions,
|
cross_attentions=cross_attentions,
|
||||||
decoder_hidden_states=decoder_hidden_states,
|
decoder_hidden_states=decoder_hidden_states,
|
||||||
|
past_key_values=model_kwargs.get("past_key_values"),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return GreedySearchDecoderOnlyOutput(
|
return GreedySearchDecoderOnlyOutput(
|
||||||
@@ -2605,6 +2699,7 @@ class GenerationMixin:
|
|||||||
scores=scores,
|
scores=scores,
|
||||||
attentions=decoder_attentions,
|
attentions=decoder_attentions,
|
||||||
hidden_states=decoder_hidden_states,
|
hidden_states=decoder_hidden_states,
|
||||||
|
past_key_values=model_kwargs.get("past_key_values"),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return input_ids
|
return input_ids
|
||||||
@@ -2880,6 +2975,7 @@ class GenerationMixin:
|
|||||||
decoder_attentions=decoder_attentions,
|
decoder_attentions=decoder_attentions,
|
||||||
cross_attentions=cross_attentions,
|
cross_attentions=cross_attentions,
|
||||||
decoder_hidden_states=decoder_hidden_states,
|
decoder_hidden_states=decoder_hidden_states,
|
||||||
|
past_key_values=model_kwargs.get("past_key_values"),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return SampleDecoderOnlyOutput(
|
return SampleDecoderOnlyOutput(
|
||||||
@@ -2887,6 +2983,7 @@ class GenerationMixin:
|
|||||||
scores=scores,
|
scores=scores,
|
||||||
attentions=decoder_attentions,
|
attentions=decoder_attentions,
|
||||||
hidden_states=decoder_hidden_states,
|
hidden_states=decoder_hidden_states,
|
||||||
|
past_key_values=model_kwargs.get("past_key_values"),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return input_ids
|
return input_ids
|
||||||
@@ -3201,6 +3298,7 @@ class GenerationMixin:
|
|||||||
decoder_attentions=decoder_attentions,
|
decoder_attentions=decoder_attentions,
|
||||||
cross_attentions=cross_attentions,
|
cross_attentions=cross_attentions,
|
||||||
decoder_hidden_states=decoder_hidden_states,
|
decoder_hidden_states=decoder_hidden_states,
|
||||||
|
past_key_values=model_kwargs.get("past_key_values"),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return BeamSearchDecoderOnlyOutput(
|
return BeamSearchDecoderOnlyOutput(
|
||||||
@@ -3210,6 +3308,7 @@ class GenerationMixin:
|
|||||||
beam_indices=sequence_outputs["beam_indices"],
|
beam_indices=sequence_outputs["beam_indices"],
|
||||||
attentions=decoder_attentions,
|
attentions=decoder_attentions,
|
||||||
hidden_states=decoder_hidden_states,
|
hidden_states=decoder_hidden_states,
|
||||||
|
past_key_values=model_kwargs.get("past_key_values"),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return sequence_outputs["sequences"]
|
return sequence_outputs["sequences"]
|
||||||
@@ -3530,6 +3629,7 @@ class GenerationMixin:
|
|||||||
decoder_attentions=decoder_attentions,
|
decoder_attentions=decoder_attentions,
|
||||||
cross_attentions=cross_attentions,
|
cross_attentions=cross_attentions,
|
||||||
decoder_hidden_states=decoder_hidden_states,
|
decoder_hidden_states=decoder_hidden_states,
|
||||||
|
past_key_values=model_kwargs.get("past_key_values"),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return BeamSampleDecoderOnlyOutput(
|
return BeamSampleDecoderOnlyOutput(
|
||||||
@@ -3539,6 +3639,7 @@ class GenerationMixin:
|
|||||||
beam_indices=sequence_outputs["beam_indices"],
|
beam_indices=sequence_outputs["beam_indices"],
|
||||||
attentions=decoder_attentions,
|
attentions=decoder_attentions,
|
||||||
hidden_states=decoder_hidden_states,
|
hidden_states=decoder_hidden_states,
|
||||||
|
past_key_values=model_kwargs.get("past_key_values"),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return sequence_outputs["sequences"]
|
return sequence_outputs["sequences"]
|
||||||
@@ -3909,6 +4010,7 @@ class GenerationMixin:
|
|||||||
decoder_attentions=decoder_attentions,
|
decoder_attentions=decoder_attentions,
|
||||||
cross_attentions=cross_attentions,
|
cross_attentions=cross_attentions,
|
||||||
decoder_hidden_states=decoder_hidden_states,
|
decoder_hidden_states=decoder_hidden_states,
|
||||||
|
past_key_values=model_kwargs.get("past_key_values"),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return BeamSearchDecoderOnlyOutput(
|
return BeamSearchDecoderOnlyOutput(
|
||||||
@@ -3918,6 +4020,7 @@ class GenerationMixin:
|
|||||||
beam_indices=sequence_outputs["beam_indices"],
|
beam_indices=sequence_outputs["beam_indices"],
|
||||||
attentions=decoder_attentions,
|
attentions=decoder_attentions,
|
||||||
hidden_states=decoder_hidden_states,
|
hidden_states=decoder_hidden_states,
|
||||||
|
past_key_values=model_kwargs.get("past_key_values"),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return sequence_outputs["sequences"]
|
return sequence_outputs["sequences"]
|
||||||
@@ -4244,6 +4347,7 @@ class GenerationMixin:
|
|||||||
decoder_attentions=decoder_attentions,
|
decoder_attentions=decoder_attentions,
|
||||||
cross_attentions=cross_attentions,
|
cross_attentions=cross_attentions,
|
||||||
decoder_hidden_states=decoder_hidden_states,
|
decoder_hidden_states=decoder_hidden_states,
|
||||||
|
past_key_values=model_kwargs.get("past_key_values"),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return BeamSearchDecoderOnlyOutput(
|
return BeamSearchDecoderOnlyOutput(
|
||||||
@@ -4253,6 +4357,7 @@ class GenerationMixin:
|
|||||||
beam_indices=sequence_outputs["beam_indices"],
|
beam_indices=sequence_outputs["beam_indices"],
|
||||||
attentions=decoder_attentions,
|
attentions=decoder_attentions,
|
||||||
hidden_states=decoder_hidden_states,
|
hidden_states=decoder_hidden_states,
|
||||||
|
past_key_values=model_kwargs.get("past_key_values"),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return sequence_outputs["sequences"]
|
return sequence_outputs["sequences"]
|
||||||
@@ -4672,6 +4777,7 @@ class GenerationMixin:
|
|||||||
decoder_attentions=decoder_attentions,
|
decoder_attentions=decoder_attentions,
|
||||||
cross_attentions=cross_attentions,
|
cross_attentions=cross_attentions,
|
||||||
decoder_hidden_states=decoder_hidden_states,
|
decoder_hidden_states=decoder_hidden_states,
|
||||||
|
past_key_values=model_kwargs.get("past_key_values"),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return GreedySearchDecoderOnlyOutput(
|
return GreedySearchDecoderOnlyOutput(
|
||||||
@@ -4679,6 +4785,7 @@ class GenerationMixin:
|
|||||||
scores=scores,
|
scores=scores,
|
||||||
attentions=decoder_attentions,
|
attentions=decoder_attentions,
|
||||||
hidden_states=decoder_hidden_states,
|
hidden_states=decoder_hidden_states,
|
||||||
|
past_key_values=model_kwargs.get("past_key_values"),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return input_ids
|
return input_ids
|
||||||
|
|||||||
@@ -1829,6 +1829,85 @@ class GenerationTesterMixin:
|
|||||||
outputs_from_embeds_wo_ids[:, 1:].tolist(),
|
outputs_from_embeds_wo_ids[:, 1:].tolist(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_generate_continue_from_past_key_values(self):
|
||||||
|
# Tests that we can continue generating from past key values, returned from a previous `generate` call
|
||||||
|
for model_class in self.all_generative_model_classes:
|
||||||
|
# won't fix: old models with unique inputs/caches/others
|
||||||
|
if any(model_name in model_class.__name__.lower() for model_name in ["imagegpt"]):
|
||||||
|
return
|
||||||
|
# may fix in the future: needs modeling or test input preparation fixes for compatibility
|
||||||
|
if any(model_name in model_class.__name__.lower() for model_name in ["umt5"]):
|
||||||
|
return
|
||||||
|
|
||||||
|
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
# If it doesn't support cache, pass the test
|
||||||
|
if not hasattr(config, "use_cache"):
|
||||||
|
return
|
||||||
|
|
||||||
|
# Let's make it always:
|
||||||
|
# 1. use cache (for obvious reasons)
|
||||||
|
# 2. generate to max length (which can be achieved by setting the eos token to an invalid value), which
|
||||||
|
# would make the test flaky (e.g. EOS is generated on iteration 1 on both generations, but the
|
||||||
|
# continuation would force it to generate beyond an EOS token)
|
||||||
|
# 3. ignore `token_type_ids` for simplicity
|
||||||
|
# 4. ignore `forced_eos_token_id`, which requires further manipulation of the continuation inputs and is
|
||||||
|
# active by default on some models
|
||||||
|
config.use_cache = True
|
||||||
|
if "token_type_ids" in inputs:
|
||||||
|
del inputs["token_type_ids"]
|
||||||
|
|
||||||
|
model = model_class(config).to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
model.generation_config.pad_token_id = model.generation_config.eos_token_id = -1
|
||||||
|
model.generation_config.forced_eos_token_id = None
|
||||||
|
|
||||||
|
# If "past_key_values" is not returned, pass the test (e.g. RWKV uses a different cache name and format)
|
||||||
|
outputs = model(**inputs)
|
||||||
|
if "past_key_values" not in outputs:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Traditional way of generating text, with `return_dict_in_generate` to return the past key values
|
||||||
|
outputs = model.generate(**inputs, do_sample=False, max_new_tokens=4, return_dict_in_generate=True)
|
||||||
|
|
||||||
|
# Let's generate again, but passing the past key values in between (3 + 1 = 4 tokens). Note that the
|
||||||
|
# inputs may need to be tweaked across `generate` calls (like the attention mask).
|
||||||
|
outputs_cached = model.generate(**inputs, do_sample=False, max_new_tokens=3, return_dict_in_generate=True)
|
||||||
|
|
||||||
|
# Continue from the tokens generated above, preparing the inputs accordingly
|
||||||
|
inputs["past_key_values"] = outputs_cached.past_key_values
|
||||||
|
new_attention_len = outputs_cached.sequences.shape[-1]
|
||||||
|
if config.is_encoder_decoder:
|
||||||
|
inputs["decoder_input_ids"] = outputs_cached.sequences
|
||||||
|
if "decoder_attention_mask" in inputs:
|
||||||
|
inputs["decoder_attention_mask"] = torch.nn.functional.pad(
|
||||||
|
inputs["decoder_attention_mask"],
|
||||||
|
(0, new_attention_len - inputs["decoder_attention_mask"].shape[1]),
|
||||||
|
mode="constant",
|
||||||
|
value=1,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
inputs["input_ids"] = outputs_cached.sequences
|
||||||
|
if "attention_mask" in inputs:
|
||||||
|
inputs["attention_mask"] = torch.nn.functional.pad(
|
||||||
|
inputs["attention_mask"],
|
||||||
|
(0, new_attention_len - inputs["attention_mask"].shape[1]),
|
||||||
|
mode="constant",
|
||||||
|
value=1,
|
||||||
|
)
|
||||||
|
outputs_cached = model.generate(**inputs, do_sample=False, max_new_tokens=1, return_dict_in_generate=True)
|
||||||
|
|
||||||
|
# The two sets of generated text and past kv should be equal to each other
|
||||||
|
self.assertListEqual(outputs.sequences.tolist(), outputs_cached.sequences.tolist())
|
||||||
|
for layer_idx in range(len(outputs_cached.past_key_values)):
|
||||||
|
for kv_idx in range(len(outputs_cached.past_key_values[layer_idx])):
|
||||||
|
self.assertTrue(
|
||||||
|
torch.allclose(
|
||||||
|
outputs.past_key_values[layer_idx][kv_idx],
|
||||||
|
outputs_cached.past_key_values[layer_idx][kv_idx],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
|
def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
|
||||||
batch_size, seq_length = input_ids.shape
|
batch_size, seq_length = input_ids.shape
|
||||||
num_sequences_in_output = batch_size * num_return_sequences
|
num_sequences_in_output = batch_size * num_return_sequences
|
||||||
@@ -1894,6 +1973,24 @@ class GenerationTesterMixin:
|
|||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Past Key Value States -- two notes here:
|
||||||
|
# 1. Its inner sequence length is with respect to the inputs of the latest forward pass, hence the "-1"
|
||||||
|
# 2. Some old models still return `output.past_key_values` even without `use_cache=True`
|
||||||
|
# 3. TODO (joao): A few models have different formats, skipping those until the cache refactor is complete
|
||||||
|
models_without_standard_cache = ("bloom", "ctrl", "fsmt", "gptbigcode", "mega", "reformer")
|
||||||
|
has_standard_cache = not any(
|
||||||
|
model_name in config.__class__.__name__.lower() for model_name in models_without_standard_cache
|
||||||
|
)
|
||||||
|
if use_cache and has_standard_cache:
|
||||||
|
past_key_values = output.past_key_values
|
||||||
|
past_sequence_length = output.sequences.shape[-1] - 1
|
||||||
|
self._check_past_key_values_for_generate(
|
||||||
|
num_sequences_in_output,
|
||||||
|
past_key_values,
|
||||||
|
seq_length=past_sequence_length,
|
||||||
|
config=config,
|
||||||
|
)
|
||||||
|
|
||||||
def _check_scores(self, batch_size, scores, length, config):
|
def _check_scores(self, batch_size, scores, length, config):
|
||||||
expected_shape = (batch_size, config.vocab_size)
|
expected_shape = (batch_size, config.vocab_size)
|
||||||
self.assertIsInstance(scores, tuple)
|
self.assertIsInstance(scores, tuple)
|
||||||
@@ -1959,6 +2056,30 @@ class GenerationTesterMixin:
|
|||||||
[encoder_expected_shape] * len(hidden_states),
|
[encoder_expected_shape] * len(hidden_states),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config, num_beam_groups=1):
|
||||||
|
self.assertIsInstance(past_key_values, tuple)
|
||||||
|
self.assertListEqual(
|
||||||
|
[isinstance(iter_past_key_values, tuple) for iter_past_key_values in past_key_values],
|
||||||
|
[True] * len(past_key_values),
|
||||||
|
)
|
||||||
|
|
||||||
|
# (batch, head, seq_length, head_features)
|
||||||
|
expected_shape = (
|
||||||
|
batch_size * num_beam_groups,
|
||||||
|
config.num_key_value_heads if hasattr(config, "num_key_value_heads") else config.num_attention_heads,
|
||||||
|
seq_length,
|
||||||
|
config.hidden_size // config.num_attention_heads,
|
||||||
|
)
|
||||||
|
# check shape key, value
|
||||||
|
self.assertListEqual(
|
||||||
|
[layer_past_key_values[0].shape for layer_past_key_values in past_key_values],
|
||||||
|
[expected_shape] * len(past_key_values),
|
||||||
|
)
|
||||||
|
self.assertListEqual(
|
||||||
|
[layer_past_key_values[1].shape for layer_past_key_values in past_key_values],
|
||||||
|
[expected_shape] * len(past_key_values),
|
||||||
|
)
|
||||||
|
|
||||||
def _check_sequence_inside_sequence(self, tensor_1, tensor_2):
|
def _check_sequence_inside_sequence(self, tensor_1, tensor_2):
|
||||||
# check if tensor_1 inside tensor_2 or tensor_2 inside tensor_1.
|
# check if tensor_1 inside tensor_2 or tensor_2 inside tensor_1.
|
||||||
# set to same device. we don't care what device.
|
# set to same device. we don't care what device.
|
||||||
|
|||||||
Reference in New Issue
Block a user