Deprecate old past arguments (#5671)
This commit is contained in:
@@ -690,7 +690,7 @@ class AlbertForPreTraining(AlbertPreTrainedModel):
|
|||||||
if "masked_lm_labels" in kwargs:
|
if "masked_lm_labels" in kwargs:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
|
"The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
|
||||||
DeprecationWarning,
|
FutureWarning,
|
||||||
)
|
)
|
||||||
labels = kwargs.pop("masked_lm_labels")
|
labels = kwargs.pop("masked_lm_labels")
|
||||||
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
|
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
|
||||||
|
|||||||
@@ -111,6 +111,15 @@ BART_INPUTS_DOCSTRING = r"""
|
|||||||
Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default.
|
Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default.
|
||||||
If you want to change padding behavior, you should read :func:`~transformers.modeling_bart._prepare_decoder_inputs` and modify.
|
If you want to change padding behavior, you should read :func:`~transformers.modeling_bart._prepare_decoder_inputs` and modify.
|
||||||
See diagram 1 in the paper for more info on the default strategy
|
See diagram 1 in the paper for more info on the default strategy
|
||||||
|
decoder_past_key_value_states (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||||
|
Contains pre-computed key and value hidden-states of the attention blocks.
|
||||||
|
Can be used to speed up decoding.
|
||||||
|
If ``decoder_past_key_value_states`` are used, the user can optionally input only the last
|
||||||
|
``decoder_input_ids`` (those that don't have their past key value states given to this model) of shape
|
||||||
|
:obj:`(batch_size, 1)` instead of all ``decoder_input_ids`` of shape :obj:`(batch_size, sequence_length)`.
|
||||||
|
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||||
|
If `use_cache` is True, ``decoder_past_key_values`` are returned and can be used to speed up decoding (see
|
||||||
|
``decoder_past_key_values``).
|
||||||
output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
|
output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
|
||||||
If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
|
If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
|
||||||
output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`):
|
output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`):
|
||||||
@@ -482,7 +491,7 @@ class BartDecoder(nn.Module):
|
|||||||
encoder_padding_mask,
|
encoder_padding_mask,
|
||||||
decoder_padding_mask,
|
decoder_padding_mask,
|
||||||
decoder_causal_mask,
|
decoder_causal_mask,
|
||||||
decoder_cached_states=None,
|
decoder_past_key_values=None,
|
||||||
use_cache=False,
|
use_cache=False,
|
||||||
output_attentions=False,
|
output_attentions=False,
|
||||||
output_hidden_states=False,
|
output_hidden_states=False,
|
||||||
@@ -499,7 +508,7 @@ class BartDecoder(nn.Module):
|
|||||||
encoder_hidden_states: output from the encoder, used for
|
encoder_hidden_states: output from the encoder, used for
|
||||||
encoder-side attention
|
encoder-side attention
|
||||||
encoder_padding_mask: for ignoring pad tokens
|
encoder_padding_mask: for ignoring pad tokens
|
||||||
decoder_cached_states (dict or None): dictionary used for storing state during generation
|
decoder_past_key_values (dict or None): dictionary used for storing state during generation
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
BaseModelOutputWithPast or tuple:
|
BaseModelOutputWithPast or tuple:
|
||||||
@@ -508,6 +517,13 @@ class BartDecoder(nn.Module):
|
|||||||
- hidden states
|
- hidden states
|
||||||
- attentions
|
- attentions
|
||||||
"""
|
"""
|
||||||
|
if "decoder_cached_states" in unused:
|
||||||
|
warnings.warn(
|
||||||
|
"The `decoder_cached_states` argument is deprecated and will be removed in a future version, use `decoder_past_key_values` instead.",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
|
decoder_past_key_values = unused.pop("decoder_cached_states")
|
||||||
|
|
||||||
# check attention mask and invert
|
# check attention mask and invert
|
||||||
if encoder_padding_mask is not None:
|
if encoder_padding_mask is not None:
|
||||||
encoder_padding_mask = invert_mask(encoder_padding_mask)
|
encoder_padding_mask = invert_mask(encoder_padding_mask)
|
||||||
@@ -541,7 +557,7 @@ class BartDecoder(nn.Module):
|
|||||||
if self.training and (dropout_probability < self.layerdrop):
|
if self.training and (dropout_probability < self.layerdrop):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
layer_state = decoder_cached_states[idx] if decoder_cached_states is not None else None
|
layer_state = decoder_past_key_values[idx] if decoder_past_key_values is not None else None
|
||||||
|
|
||||||
x, layer_self_attn, layer_past = decoder_layer(
|
x, layer_self_attn, layer_past = decoder_layer(
|
||||||
x,
|
x,
|
||||||
@@ -854,11 +870,12 @@ class BartModel(PretrainedBartModel):
|
|||||||
decoder_input_ids=None,
|
decoder_input_ids=None,
|
||||||
encoder_outputs: Optional[Tuple] = None,
|
encoder_outputs: Optional[Tuple] = None,
|
||||||
decoder_attention_mask=None,
|
decoder_attention_mask=None,
|
||||||
decoder_cached_states=None,
|
decoder_past_key_values=None,
|
||||||
use_cache=None,
|
use_cache=None,
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
return_tuple=None,
|
return_tuple=None,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
|
||||||
if decoder_input_ids is None:
|
if decoder_input_ids is None:
|
||||||
@@ -908,7 +925,7 @@ class BartModel(PretrainedBartModel):
|
|||||||
attention_mask,
|
attention_mask,
|
||||||
decoder_padding_mask,
|
decoder_padding_mask,
|
||||||
decoder_causal_mask=causal_mask,
|
decoder_causal_mask=causal_mask,
|
||||||
decoder_cached_states=decoder_cached_states,
|
decoder_past_key_values=decoder_past_key_values,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
@@ -977,7 +994,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
|||||||
encoder_outputs=None,
|
encoder_outputs=None,
|
||||||
decoder_input_ids=None,
|
decoder_input_ids=None,
|
||||||
decoder_attention_mask=None,
|
decoder_attention_mask=None,
|
||||||
decoder_cached_states=None,
|
decoder_past_key_values=None,
|
||||||
labels=None,
|
labels=None,
|
||||||
use_cache=None,
|
use_cache=None,
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
@@ -1015,9 +1032,15 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
|||||||
if "lm_labels" in unused:
|
if "lm_labels" in unused:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"The `lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
|
"The `lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
|
||||||
DeprecationWarning,
|
FutureWarning,
|
||||||
)
|
)
|
||||||
labels = unused.pop("lm_labels")
|
labels = unused.pop("lm_labels")
|
||||||
|
if "decoder_cached_states" in unused:
|
||||||
|
warnings.warn(
|
||||||
|
"The `decoder_cached_states` argument is deprecated and will be removed in a future version, use `decoder_past_key_values` instead.",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
|
decoder_past_key_values = unused.pop("decoder_cached_states")
|
||||||
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
|
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
|
||||||
|
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
@@ -1029,7 +1052,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
|||||||
decoder_input_ids=decoder_input_ids,
|
decoder_input_ids=decoder_input_ids,
|
||||||
encoder_outputs=encoder_outputs,
|
encoder_outputs=encoder_outputs,
|
||||||
decoder_attention_mask=decoder_attention_mask,
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
decoder_cached_states=decoder_cached_states,
|
decoder_past_key_values=decoder_past_key_values,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
@@ -1061,11 +1084,11 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
|||||||
def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, use_cache, **kwargs):
|
def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, use_cache, **kwargs):
|
||||||
assert past is not None, "past has to be defined for encoder_outputs"
|
assert past is not None, "past has to be defined for encoder_outputs"
|
||||||
|
|
||||||
encoder_outputs, decoder_cached_states = past
|
encoder_outputs, decoder_past_key_values = past
|
||||||
return {
|
return {
|
||||||
"input_ids": None, # encoder_outputs is defined. input_ids not needed
|
"input_ids": None, # encoder_outputs is defined. input_ids not needed
|
||||||
"encoder_outputs": encoder_outputs,
|
"encoder_outputs": encoder_outputs,
|
||||||
"decoder_cached_states": decoder_cached_states,
|
"decoder_past_key_values": decoder_past_key_values,
|
||||||
"decoder_input_ids": decoder_input_ids,
|
"decoder_input_ids": decoder_input_ids,
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
||||||
@@ -1092,9 +1115,9 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _reorder_cache(past, beam_idx):
|
def _reorder_cache(past, beam_idx):
|
||||||
((enc_out, enc_mask), decoder_cached_states) = past
|
((enc_out, enc_mask), decoder_past_key_values) = past
|
||||||
reordered_past = []
|
reordered_past = []
|
||||||
for layer_past in decoder_cached_states:
|
for layer_past in decoder_past_key_values:
|
||||||
# get the correct batch idx from decoder layer's batch dim for cross and self-attn
|
# get the correct batch idx from decoder layer's batch dim for cross and self-attn
|
||||||
layer_past_new = {
|
layer_past_new = {
|
||||||
attn_key: _reorder_buffer(attn_cache, beam_idx) for attn_key, attn_cache in layer_past.items()
|
attn_key: _reorder_buffer(attn_cache, beam_idx) for attn_key, attn_cache in layer_past.items()
|
||||||
|
|||||||
@@ -879,7 +879,7 @@ class BertForPreTraining(BertPreTrainedModel):
|
|||||||
if "masked_lm_labels" in kwargs:
|
if "masked_lm_labels" in kwargs:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
|
"The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
|
||||||
DeprecationWarning,
|
FutureWarning,
|
||||||
)
|
)
|
||||||
labels = kwargs.pop("masked_lm_labels")
|
labels = kwargs.pop("masked_lm_labels")
|
||||||
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
|
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
|
||||||
@@ -1076,7 +1076,7 @@ class BertForMaskedLM(BertPreTrainedModel):
|
|||||||
if "masked_lm_labels" in kwargs:
|
if "masked_lm_labels" in kwargs:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
|
"The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
|
||||||
DeprecationWarning,
|
FutureWarning,
|
||||||
)
|
)
|
||||||
labels = kwargs.pop("masked_lm_labels")
|
labels = kwargs.pop("masked_lm_labels")
|
||||||
assert "lm_labels" not in kwargs, "Use `BertWithLMHead` for autoregressive language modeling task."
|
assert "lm_labels" not in kwargs, "Use `BertWithLMHead` for autoregressive language modeling task."
|
||||||
|
|||||||
@@ -17,6 +17,7 @@
|
|||||||
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import warnings
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -246,20 +247,22 @@ CTRL_START_DOCSTRING = r"""
|
|||||||
CTRL_INPUTS_DOCSTRING = r"""
|
CTRL_INPUTS_DOCSTRING = r"""
|
||||||
Args:
|
Args:
|
||||||
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, input_ids_length)`):
|
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, input_ids_length)`):
|
||||||
:obj:`input_ids_length` = ``sequence_length`` if ``past`` is ``None`` else ``past[0].shape[-2]`` (``sequence_length`` of input past key value states).
|
:obj:`input_ids_length` = ``sequence_length`` if ``past_key_values`` is ``None`` else
|
||||||
|
``past_key_values[0].shape[-2]`` (``sequence_length`` of input past key value states).
|
||||||
Indices of input sequence tokens in the vocabulary.
|
Indices of input sequence tokens in the vocabulary.
|
||||||
|
|
||||||
If `past` is used, only input_ids that do not have their past calculated should be passed as input_ids.
|
If ``past_key_values`` is used, only input_ids that do not have their past calculated should be passed as
|
||||||
|
``input_ids``.
|
||||||
|
|
||||||
Indices can be obtained using :class:`transformers.CTRLTokenizer`.
|
Indices can be obtained using :class:`transformers.CTRLTokenizer`.
|
||||||
See :func:`transformers.PreTrainedTokenizer.encode` and
|
See :func:`transformers.PreTrainedTokenizer.encode` and
|
||||||
:func:`transformers.PreTrainedTokenizer.__call__` for details.
|
:func:`transformers.PreTrainedTokenizer.__call__` for details.
|
||||||
|
|
||||||
`What are input IDs? <../glossary.html#input-ids>`__
|
`What are input IDs? <../glossary.html#input-ids>`__
|
||||||
past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
|
past_key_values (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
|
||||||
Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
|
Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
|
||||||
(see `past` output below). Can be used to speed up sequential decoding.
|
(see ``past_key_values`` output below). Can be used to speed up sequential decoding.
|
||||||
The input_ids which have their past given to this model should not be passed as input ids as they have already been computed.
|
The ``input_ids`` which have their past given to this model should not be passed as input ids as they have already been computed.
|
||||||
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||||
Mask to avoid performing attention on padding token indices.
|
Mask to avoid performing attention on padding token indices.
|
||||||
Mask values selected in ``[0, 1]``:
|
Mask values selected in ``[0, 1]``:
|
||||||
@@ -284,10 +287,10 @@ CTRL_INPUTS_DOCSTRING = r"""
|
|||||||
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
|
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
|
||||||
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
||||||
than the model's internal embedding lookup matrix.
|
than the model's internal embedding lookup matrix.
|
||||||
If `past` is used, optionally only the last `inputs_embeds` have to be input (see `past`).
|
If ``past_key_values`` is used, optionally only the last `inputs_embeds` have to be input (see ``past_key_values``).
|
||||||
use_cache (:obj:`bool`):
|
use_cache (:obj:`bool`):
|
||||||
If `use_cache` is True, `past` key value states are returned and
|
If `use_cache` is True, ``past_key_values`` key value states are returned and
|
||||||
can be used to speed up decoding (see `past`). Defaults to `True`.
|
can be used to speed up decoding (see ``past_key_values``). Defaults to `True`.
|
||||||
output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
|
output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
|
||||||
If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
|
If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
|
||||||
output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`):
|
output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`):
|
||||||
@@ -343,7 +346,7 @@ class CTRLModel(CTRLPreTrainedModel):
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
input_ids=None,
|
||||||
past=None,
|
past_key_values=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
token_type_ids=None,
|
token_type_ids=None,
|
||||||
position_ids=None,
|
position_ids=None,
|
||||||
@@ -353,7 +356,16 @@ class CTRLModel(CTRLPreTrainedModel):
|
|||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
return_tuple=None,
|
return_tuple=None,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
if "past" in kwargs:
|
||||||
|
warnings.warn(
|
||||||
|
"The `past` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
|
past_key_values = kwargs.pop("past")
|
||||||
|
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
|
||||||
|
|
||||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||||
output_hidden_states = (
|
output_hidden_states = (
|
||||||
@@ -373,11 +385,11 @@ class CTRLModel(CTRLPreTrainedModel):
|
|||||||
else:
|
else:
|
||||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||||
|
|
||||||
if past is None:
|
if past_key_values is None:
|
||||||
past_length = 0
|
past_length = 0
|
||||||
past = [None] * len(self.h)
|
past_key_values = [None] * len(self.h)
|
||||||
else:
|
else:
|
||||||
past_length = past[0][0].size(-2)
|
past_length = past_key_values[0][0].size(-2)
|
||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||||
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
|
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
|
||||||
@@ -431,7 +443,7 @@ class CTRLModel(CTRLPreTrainedModel):
|
|||||||
presents = () if use_cache else None
|
presents = () if use_cache else None
|
||||||
all_hidden_states = () if output_hidden_states else None
|
all_hidden_states = () if output_hidden_states else None
|
||||||
all_attentions = [] if output_attentions else None
|
all_attentions = [] if output_attentions else None
|
||||||
for i, (h, layer_past) in enumerate(zip(self.h, past)):
|
for i, (h, layer_past) in enumerate(zip(self.h, past_key_values)):
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)
|
all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)
|
||||||
outputs = h(
|
outputs = h(
|
||||||
@@ -492,7 +504,7 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
|
|||||||
if past:
|
if past:
|
||||||
input_ids = input_ids[:, -1].unsqueeze(-1)
|
input_ids = input_ids[:, -1].unsqueeze(-1)
|
||||||
|
|
||||||
return {"input_ids": input_ids, "past": past, "use_cache": kwargs["use_cache"]}
|
return {"input_ids": input_ids, "past_key_values": past, "use_cache": kwargs["use_cache"]}
|
||||||
|
|
||||||
@add_start_docstrings_to_callable(CTRL_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_callable(CTRL_INPUTS_DOCSTRING)
|
||||||
@add_code_sample_docstrings(
|
@add_code_sample_docstrings(
|
||||||
@@ -504,7 +516,7 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
input_ids=None,
|
||||||
past=None,
|
past_key_values=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
token_type_ids=None,
|
token_type_ids=None,
|
||||||
position_ids=None,
|
position_ids=None,
|
||||||
@@ -515,6 +527,7 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
|
|||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
return_tuple=None,
|
return_tuple=None,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||||
@@ -524,11 +537,18 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
|
|||||||
All labels set to ``-100`` are ignored (masked), the loss is only
|
All labels set to ``-100`` are ignored (masked), the loss is only
|
||||||
computed for labels in ``[0, ..., config.vocab_size]``
|
computed for labels in ``[0, ..., config.vocab_size]``
|
||||||
"""
|
"""
|
||||||
|
if "past" in kwargs:
|
||||||
|
warnings.warn(
|
||||||
|
"The `past` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
|
past_key_values = kwargs.pop("past")
|
||||||
|
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
|
||||||
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
|
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
|
||||||
|
|
||||||
transformer_outputs = self.transformer(
|
transformer_outputs = self.transformer(
|
||||||
input_ids,
|
input_ids,
|
||||||
past=past,
|
past_key_values=past_key_values,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
token_type_ids=token_type_ids,
|
token_type_ids=token_type_ids,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
|
|||||||
@@ -531,7 +531,7 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel):
|
|||||||
if "masked_lm_labels" in kwargs:
|
if "masked_lm_labels" in kwargs:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
|
"The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
|
||||||
DeprecationWarning,
|
FutureWarning,
|
||||||
)
|
)
|
||||||
labels = kwargs.pop("masked_lm_labels")
|
labels = kwargs.pop("masked_lm_labels")
|
||||||
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
|
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
|
||||||
|
|||||||
@@ -622,7 +622,7 @@ class ElectraForMaskedLM(ElectraPreTrainedModel):
|
|||||||
if "masked_lm_labels" in kwargs:
|
if "masked_lm_labels" in kwargs:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
|
"The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
|
||||||
DeprecationWarning,
|
FutureWarning,
|
||||||
)
|
)
|
||||||
labels = kwargs.pop("masked_lm_labels")
|
labels = kwargs.pop("masked_lm_labels")
|
||||||
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
|
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
|
||||||
|
|||||||
@@ -347,10 +347,12 @@ GPT2_START_DOCSTRING = r"""
|
|||||||
GPT2_INPUTS_DOCSTRING = r"""
|
GPT2_INPUTS_DOCSTRING = r"""
|
||||||
Args:
|
Args:
|
||||||
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, input_ids_length)`):
|
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, input_ids_length)`):
|
||||||
:obj:`input_ids_length` = ``sequence_length`` if ``past`` is ``None`` else ``past[0].shape[-2]`` (``sequence_length`` of input past key value states).
|
:obj:`input_ids_length` = ``sequence_length`` if ``past_key_values`` is ``None`` else
|
||||||
|
``past_key_values[0].shape[-2]`` (``sequence_length`` of input past key value states).
|
||||||
Indices of input sequence tokens in the vocabulary.
|
Indices of input sequence tokens in the vocabulary.
|
||||||
|
|
||||||
If `past` is used, only `input_ids` that do not have their past calculated should be passed as `input_ids`.
|
If ``past_key_values`` is used, only ``input_ids`` that do not have their past calculated should be passed
|
||||||
|
as ``input_ids``.
|
||||||
|
|
||||||
Indices can be obtained using :class:`transformers.GPT2Tokenizer`.
|
Indices can be obtained using :class:`transformers.GPT2Tokenizer`.
|
||||||
See :func:`transformers.PreTrainedTokenizer.encode` and
|
See :func:`transformers.PreTrainedTokenizer.encode` and
|
||||||
@@ -358,10 +360,10 @@ GPT2_INPUTS_DOCSTRING = r"""
|
|||||||
|
|
||||||
`What are input IDs? <../glossary.html#input-ids>`__
|
`What are input IDs? <../glossary.html#input-ids>`__
|
||||||
|
|
||||||
past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
|
past_key_values (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
|
||||||
Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
|
Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
|
||||||
(see `past` output below). Can be used to speed up sequential decoding.
|
(see ``past_key_values`` output below). Can be used to speed up sequential decoding.
|
||||||
The `input_ids` which have their past given to this model should not be passed as `input_ids` as they have already been computed.
|
The ``input_ids`` which have their past given to this model should not be passed as ``input_ids`` as they have already been computed.
|
||||||
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||||
Mask to avoid performing attention on padding token indices.
|
Mask to avoid performing attention on padding token indices.
|
||||||
Mask values selected in ``[0, 1]``:
|
Mask values selected in ``[0, 1]``:
|
||||||
@@ -386,9 +388,9 @@ GPT2_INPUTS_DOCSTRING = r"""
|
|||||||
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
|
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
|
||||||
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
||||||
than the model's internal embedding lookup matrix.
|
than the model's internal embedding lookup matrix.
|
||||||
If `past` is used, optionally only the last `inputs_embeds` have to be input (see `past`).
|
If ``past_key_values`` is used, optionally only the last `inputs_embeds` have to be input (see ``past_key_values``).
|
||||||
use_cache (:obj:`bool`):
|
use_cache (:obj:`bool`):
|
||||||
If `use_cache` is True, `past` key value states are returned and can be used to speed up decoding (see `past`). Defaults to `True`.
|
If `use_cache` is True, ``past_key_values`` key value states are returned and can be used to speed up decoding (see ``past_key_values``). Defaults to `True`.
|
||||||
output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
|
output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
|
||||||
If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
|
If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
|
||||||
output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`):
|
output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`):
|
||||||
@@ -437,7 +439,7 @@ class GPT2Model(GPT2PreTrainedModel):
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
input_ids=None,
|
||||||
past=None,
|
past_key_values=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
token_type_ids=None,
|
token_type_ids=None,
|
||||||
position_ids=None,
|
position_ids=None,
|
||||||
@@ -447,7 +449,16 @@ class GPT2Model(GPT2PreTrainedModel):
|
|||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
return_tuple=None,
|
return_tuple=None,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
if "past" in kwargs:
|
||||||
|
warnings.warn(
|
||||||
|
"The `past` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
|
past_key_values = kwargs.pop("past")
|
||||||
|
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
|
||||||
|
|
||||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
output_hidden_states = (
|
output_hidden_states = (
|
||||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
@@ -472,11 +483,11 @@ class GPT2Model(GPT2PreTrainedModel):
|
|||||||
if position_ids is not None:
|
if position_ids is not None:
|
||||||
position_ids = position_ids.view(-1, input_shape[-1])
|
position_ids = position_ids.view(-1, input_shape[-1])
|
||||||
|
|
||||||
if past is None:
|
if past_key_values is None:
|
||||||
past_length = 0
|
past_length = 0
|
||||||
past = [None] * len(self.h)
|
past_key_values = [None] * len(self.h)
|
||||||
else:
|
else:
|
||||||
past_length = past[0][0].size(-2)
|
past_length = past_key_values[0][0].size(-2)
|
||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||||
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
|
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
|
||||||
@@ -522,7 +533,7 @@ class GPT2Model(GPT2PreTrainedModel):
|
|||||||
presents = () if use_cache else None
|
presents = () if use_cache else None
|
||||||
all_attentions = () if output_attentions else None
|
all_attentions = () if output_attentions else None
|
||||||
all_hidden_states = () if output_hidden_states else None
|
all_hidden_states = () if output_hidden_states else None
|
||||||
for i, (block, layer_past) in enumerate(zip(self.h, past)):
|
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)
|
all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)
|
||||||
|
|
||||||
@@ -581,7 +592,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
|
|||||||
if past:
|
if past:
|
||||||
input_ids = input_ids[:, -1].unsqueeze(-1)
|
input_ids = input_ids[:, -1].unsqueeze(-1)
|
||||||
|
|
||||||
return {"input_ids": input_ids, "past": past, "use_cache": kwargs["use_cache"]}
|
return {"input_ids": input_ids, "past_key_values": past, "use_cache": kwargs["use_cache"]}
|
||||||
|
|
||||||
@add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING)
|
||||||
@add_code_sample_docstrings(
|
@add_code_sample_docstrings(
|
||||||
@@ -593,7 +604,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
input_ids=None,
|
||||||
past=None,
|
past_key_values=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
token_type_ids=None,
|
token_type_ids=None,
|
||||||
position_ids=None,
|
position_ids=None,
|
||||||
@@ -604,6 +615,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
|
|||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
return_tuple=None,
|
return_tuple=None,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||||
@@ -613,11 +625,18 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
|
|||||||
All labels set to ``-100`` are ignored (masked), the loss is only
|
All labels set to ``-100`` are ignored (masked), the loss is only
|
||||||
computed for labels in ``[0, ..., config.vocab_size]``
|
computed for labels in ``[0, ..., config.vocab_size]``
|
||||||
"""
|
"""
|
||||||
|
if "past" in kwargs:
|
||||||
|
warnings.warn(
|
||||||
|
"The `past` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
|
past_key_values = kwargs.pop("past")
|
||||||
|
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
|
||||||
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
|
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
|
||||||
|
|
||||||
transformer_outputs = self.transformer(
|
transformer_outputs = self.transformer(
|
||||||
input_ids,
|
input_ids,
|
||||||
past=past,
|
past_key_values=past_key_values,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
token_type_ids=token_type_ids,
|
token_type_ids=token_type_ids,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
@@ -680,7 +699,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
input_ids=None,
|
||||||
past=None,
|
past_key_values=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
token_type_ids=None,
|
token_type_ids=None,
|
||||||
position_ids=None,
|
position_ids=None,
|
||||||
@@ -693,7 +712,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
|
|||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
return_tuple=None,
|
return_tuple=None,
|
||||||
**kwargs
|
**kwargs,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
mc_token_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, num_choices)`, `optional`, default to index of the last token of the input)
|
mc_token_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, num_choices)`, `optional`, default to index of the last token of the input)
|
||||||
@@ -741,15 +760,21 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
|
|||||||
if "lm_labels" in kwargs:
|
if "lm_labels" in kwargs:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"The `lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
|
"The `lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
|
||||||
DeprecationWarning,
|
FutureWarning,
|
||||||
)
|
)
|
||||||
labels = kwargs.pop("lm_labels")
|
labels = kwargs.pop("lm_labels")
|
||||||
|
if "past" in kwargs:
|
||||||
|
warnings.warn(
|
||||||
|
"The `past` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
|
past_key_values = kwargs.pop("past")
|
||||||
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
|
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
|
||||||
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
|
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
|
||||||
|
|
||||||
transformer_outputs = self.transformer(
|
transformer_outputs = self.transformer(
|
||||||
input_ids,
|
input_ids,
|
||||||
past=past,
|
past_key_values=past_key_values,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
token_type_ids=token_type_ids,
|
token_type_ids=token_type_ids,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
|
|||||||
@@ -1094,7 +1094,7 @@ class LongformerForMaskedLM(LongformerPreTrainedModel):
|
|||||||
if "masked_lm_labels" in kwargs:
|
if "masked_lm_labels" in kwargs:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
|
"The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
|
||||||
DeprecationWarning,
|
FutureWarning,
|
||||||
)
|
)
|
||||||
labels = kwargs.pop("masked_lm_labels")
|
labels = kwargs.pop("masked_lm_labels")
|
||||||
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
|
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
|
||||||
|
|||||||
@@ -665,7 +665,7 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
|
|||||||
if "lm_labels" in kwargs:
|
if "lm_labels" in kwargs:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"The `lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
|
"The `lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
|
||||||
DeprecationWarning,
|
FutureWarning,
|
||||||
)
|
)
|
||||||
labels = kwargs.pop("lm_labels")
|
labels = kwargs.pop("lm_labels")
|
||||||
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
|
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
|
||||||
|
|||||||
@@ -223,7 +223,7 @@ class RobertaForMaskedLM(BertPreTrainedModel):
|
|||||||
if "masked_lm_labels" in kwargs:
|
if "masked_lm_labels" in kwargs:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
|
"The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
|
||||||
DeprecationWarning,
|
FutureWarning,
|
||||||
)
|
)
|
||||||
labels = kwargs.pop("masked_lm_labels")
|
labels = kwargs.pop("masked_lm_labels")
|
||||||
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
|
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
|
||||||
|
|||||||
@@ -836,27 +836,27 @@ T5_INPUTS_DOCSTRING = r"""
|
|||||||
Used in the cross-attention of the decoder.
|
Used in the cross-attention of the decoder.
|
||||||
decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`, defaults to :obj:`None`):
|
decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||||
Provide for sequence to sequence training. T5 uses the pad_token_id as the starting token for decoder_input_ids generation.
|
Provide for sequence to sequence training. T5 uses the pad_token_id as the starting token for decoder_input_ids generation.
|
||||||
If `decoder_past_key_value_states` is used, optionally only the last `decoder_input_ids` have to be input (see `decoder_past_key_value_states`).
|
If `decoder_past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see `decoder_past_key_values`).
|
||||||
To know more on how to prepare :obj:`decoder_input_ids` for pre-training take a look at
|
To know more on how to prepare :obj:`decoder_input_ids` for pre-training take a look at
|
||||||
`T5 Training <./t5.html#training>`__. If decoder_input_ids and decoder_inputs_embeds are both None,
|
`T5 Training <./t5.html#training>`__. If decoder_input_ids and decoder_inputs_embeds are both None,
|
||||||
decoder_input_ids takes the value of input_ids.
|
decoder_input_ids takes the value of input_ids.
|
||||||
decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`, defaults to :obj:`None`):
|
decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`, defaults to :obj:`None`):
|
||||||
Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default.
|
Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default.
|
||||||
decoder_past_key_value_states (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
decoder_past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||||
Contains pre-computed key and value hidden-states of the attention blocks.
|
Contains pre-computed key and value hidden-states of the attention blocks.
|
||||||
Can be used to speed up decoding.
|
Can be used to speed up decoding.
|
||||||
If `decoder_past_key_value_states` are used, the user can optionally input only the last `decoder_input_ids`
|
If `decoder_past_key_values` are used, the user can optionally input only the last `decoder_input_ids`
|
||||||
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
||||||
instead of all `decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
instead of all `decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
||||||
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||||
If `use_cache` is True, `decoder_past_key_value_states` are returned and can be used to speed up decoding (see `decoder_past_key_value_states`).
|
If `use_cache` is True, `decoder_past_key_values` are returned and can be used to speed up decoding (see `decoder_past_key_values`).
|
||||||
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
|
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
|
||||||
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
|
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
|
||||||
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
||||||
than the model's internal embedding lookup matrix.
|
than the model's internal embedding lookup matrix.
|
||||||
decoder_inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, target_sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
|
decoder_inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, target_sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
|
||||||
Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded representation.
|
Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded representation.
|
||||||
If `decoder_past_key_value_states` is used, optionally only the last `decoder_inputs_embeds` have to be input (see `decoder_past_key_value_states`).
|
If `decoder_past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be input (see `decoder_past_key_values`).
|
||||||
This is useful if you want more control over how to convert `decoder_input_ids` indices into associated vectors
|
This is useful if you want more control over how to convert `decoder_input_ids` indices into associated vectors
|
||||||
than the model's internal embedding lookup matrix. If decoder_input_ids and decoder_inputs_embeds are both None,
|
than the model's internal embedding lookup matrix. If decoder_input_ids and decoder_inputs_embeds are both None,
|
||||||
decoder_inputs_embeds takes the value of inputs_embeds.
|
decoder_inputs_embeds takes the value of inputs_embeds.
|
||||||
@@ -923,7 +923,7 @@ class T5Model(T5PreTrainedModel):
|
|||||||
encoder_outputs=None,
|
encoder_outputs=None,
|
||||||
decoder_input_ids=None,
|
decoder_input_ids=None,
|
||||||
decoder_attention_mask=None,
|
decoder_attention_mask=None,
|
||||||
decoder_past_key_value_states=None,
|
decoder_past_key_values=None,
|
||||||
use_cache=None,
|
use_cache=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
decoder_inputs_embeds=None,
|
decoder_inputs_embeds=None,
|
||||||
@@ -931,6 +931,7 @@ class T5Model(T5PreTrainedModel):
|
|||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
return_tuple=None,
|
return_tuple=None,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
Returns:
|
Returns:
|
||||||
@@ -947,6 +948,14 @@ class T5Model(T5PreTrainedModel):
|
|||||||
|
|
||||||
>>> last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
|
>>> last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
|
||||||
"""
|
"""
|
||||||
|
if "decoder_past_key_value_states" in kwargs:
|
||||||
|
warnings.warn(
|
||||||
|
"The `decoder_past_key_value_states` argument is deprecated and will be removed in a future version, use `decoder_past_key_values` instead.",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
|
decoder_past_key_values = kwargs.pop("decoder_past_key_value_states")
|
||||||
|
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
|
||||||
|
|
||||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||||
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
|
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
|
||||||
|
|
||||||
@@ -978,7 +987,7 @@ class T5Model(T5PreTrainedModel):
|
|||||||
|
|
||||||
# If decoding with past key value states, only the last tokens
|
# If decoding with past key value states, only the last tokens
|
||||||
# should be given as an input
|
# should be given as an input
|
||||||
if decoder_past_key_value_states is not None:
|
if decoder_past_key_values is not None:
|
||||||
if decoder_input_ids is not None:
|
if decoder_input_ids is not None:
|
||||||
decoder_input_ids = decoder_input_ids[:, -1:]
|
decoder_input_ids = decoder_input_ids[:, -1:]
|
||||||
if decoder_inputs_embeds is not None:
|
if decoder_inputs_embeds is not None:
|
||||||
@@ -989,7 +998,7 @@ class T5Model(T5PreTrainedModel):
|
|||||||
input_ids=decoder_input_ids,
|
input_ids=decoder_input_ids,
|
||||||
attention_mask=decoder_attention_mask,
|
attention_mask=decoder_attention_mask,
|
||||||
inputs_embeds=decoder_inputs_embeds,
|
inputs_embeds=decoder_inputs_embeds,
|
||||||
past_key_value_states=decoder_past_key_value_states,
|
past_key_value_states=decoder_past_key_values,
|
||||||
encoder_hidden_states=hidden_states,
|
encoder_hidden_states=hidden_states,
|
||||||
encoder_attention_mask=attention_mask,
|
encoder_attention_mask=attention_mask,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
@@ -1062,7 +1071,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
|
|||||||
encoder_outputs=None,
|
encoder_outputs=None,
|
||||||
decoder_input_ids=None,
|
decoder_input_ids=None,
|
||||||
decoder_attention_mask=None,
|
decoder_attention_mask=None,
|
||||||
decoder_past_key_value_states=None,
|
decoder_past_key_values=None,
|
||||||
use_cache=None,
|
use_cache=None,
|
||||||
labels=None,
|
labels=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
@@ -1071,7 +1080,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
|
|||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
return_tuple=None,
|
return_tuple=None,
|
||||||
**kwargs
|
**kwargs,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
|
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
|
||||||
@@ -1103,9 +1112,15 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
|
|||||||
if "lm_labels" in kwargs:
|
if "lm_labels" in kwargs:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"The `lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
|
"The `lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
|
||||||
DeprecationWarning,
|
FutureWarning,
|
||||||
)
|
)
|
||||||
labels = kwargs.pop("lm_labels")
|
labels = kwargs.pop("lm_labels")
|
||||||
|
if "decoder_past_key_value_states" in kwargs:
|
||||||
|
warnings.warn(
|
||||||
|
"The `decoder_past_key_value_states` argument is deprecated and will be removed in a future version, use `decoder_past_key_values` instead.",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
|
decoder_past_key_values = kwargs.pop("decoder_past_key_value_states")
|
||||||
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
|
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
|
||||||
|
|
||||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||||
@@ -1138,7 +1153,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
|
|||||||
|
|
||||||
# If decoding with past key value states, only the last tokens
|
# If decoding with past key value states, only the last tokens
|
||||||
# should be given as an input
|
# should be given as an input
|
||||||
if decoder_past_key_value_states is not None:
|
if decoder_past_key_values is not None:
|
||||||
assert labels is None, "Decoder should not use cached key value states when training."
|
assert labels is None, "Decoder should not use cached key value states when training."
|
||||||
if decoder_input_ids is not None:
|
if decoder_input_ids is not None:
|
||||||
decoder_input_ids = decoder_input_ids[:, -1:]
|
decoder_input_ids = decoder_input_ids[:, -1:]
|
||||||
@@ -1150,7 +1165,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
|
|||||||
input_ids=decoder_input_ids,
|
input_ids=decoder_input_ids,
|
||||||
attention_mask=decoder_attention_mask,
|
attention_mask=decoder_attention_mask,
|
||||||
inputs_embeds=decoder_inputs_embeds,
|
inputs_embeds=decoder_inputs_embeds,
|
||||||
past_key_value_states=decoder_past_key_value_states,
|
past_key_value_states=decoder_past_key_values,
|
||||||
encoder_hidden_states=hidden_states,
|
encoder_hidden_states=hidden_states,
|
||||||
encoder_attention_mask=attention_mask,
|
encoder_attention_mask=attention_mask,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
@@ -1193,11 +1208,11 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
|
|||||||
def prepare_inputs_for_generation(self, input_ids, past, attention_mask, use_cache, **kwargs):
|
def prepare_inputs_for_generation(self, input_ids, past, attention_mask, use_cache, **kwargs):
|
||||||
assert past is not None, "past has to be defined for encoder_outputs"
|
assert past is not None, "past has to be defined for encoder_outputs"
|
||||||
|
|
||||||
encoder_outputs, decoder_past_key_value_states = past
|
encoder_outputs, decoder_past_key_values = past
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"decoder_input_ids": input_ids,
|
"decoder_input_ids": input_ids,
|
||||||
"decoder_past_key_value_states": decoder_past_key_value_states,
|
"decoder_past_key_values": decoder_past_key_values,
|
||||||
"encoder_outputs": encoder_outputs,
|
"encoder_outputs": encoder_outputs,
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
"use_cache": use_cache,
|
"use_cache": use_cache,
|
||||||
|
|||||||
Reference in New Issue
Block a user