Deprecate old past arguments (#5671)

This commit is contained in:
Sylvain Gugger
2020-07-10 17:25:52 -04:00
committed by GitHub
parent cdf4cd7068
commit df983b7483
11 changed files with 153 additions and 70 deletions

View File

@@ -17,6 +17,7 @@
import logging
import warnings
import numpy as np
import torch
@@ -246,20 +247,22 @@ CTRL_START_DOCSTRING = r"""
CTRL_INPUTS_DOCSTRING = r"""
Args:
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, input_ids_length)`):
:obj:`input_ids_length` = ``sequence_length`` if ``past`` is ``None`` else ``past[0].shape[-2]`` (``sequence_length`` of input past key value states).
:obj:`input_ids_length` = ``sequence_length`` if ``past_key_values`` is ``None`` else
``past_key_values[0].shape[-2]`` (``sequence_length`` of input past key value states).
Indices of input sequence tokens in the vocabulary.
If `past` is used, only input_ids that do not have their past calculated should be passed as input_ids.
If ``past_key_values`` is used, only input_ids that do not have their past calculated should be passed as
``input_ids``.
Indices can be obtained using :class:`transformers.CTRLTokenizer`.
See :func:`transformers.PreTrainedTokenizer.encode` and
:func:`transformers.PreTrainedTokenizer.__call__` for details.
`What are input IDs? <../glossary.html#input-ids>`__
past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
past_key_values (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
(see `past` output below). Can be used to speed up sequential decoding.
The input_ids which have their past given to this model should not be passed as input ids as they have already been computed.
(see ``past_key_values`` output below). Can be used to speed up sequential decoding.
The ``input_ids`` which have their past given to this model should not be passed as input ids as they have already been computed.
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Mask to avoid performing attention on padding token indices.
Mask values selected in ``[0, 1]``:
@@ -284,10 +287,10 @@ CTRL_INPUTS_DOCSTRING = r"""
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
If `past` is used, optionally only the last `inputs_embeds` have to be input (see `past`).
If ``past_key_values`` is used, optionally only the last `inputs_embeds` have to be input (see ``past_key_values``).
use_cache (:obj:`bool`):
If `use_cache` is True, `past` key value states are returned and
can be used to speed up decoding (see `past`). Defaults to `True`.
If `use_cache` is True, ``past_key_values`` key value states are returned and
can be used to speed up decoding (see ``past_key_values``). Defaults to `True`.
output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`):
@@ -343,7 +346,7 @@ class CTRLModel(CTRLPreTrainedModel):
def forward(
self,
input_ids=None,
past=None,
past_key_values=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
@@ -353,7 +356,16 @@ class CTRLModel(CTRLPreTrainedModel):
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
**kwargs,
):
if "past" in kwargs:
warnings.warn(
"The `past` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
FutureWarning,
)
past_key_values = kwargs.pop("past")
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
use_cache = use_cache if use_cache is not None else self.config.use_cache
output_hidden_states = (
@@ -373,11 +385,11 @@ class CTRLModel(CTRLPreTrainedModel):
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if past is None:
if past_key_values is None:
past_length = 0
past = [None] * len(self.h)
past_key_values = [None] * len(self.h)
else:
past_length = past[0][0].size(-2)
past_length = past_key_values[0][0].size(-2)
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
@@ -431,7 +443,7 @@ class CTRLModel(CTRLPreTrainedModel):
presents = () if use_cache else None
all_hidden_states = () if output_hidden_states else None
all_attentions = [] if output_attentions else None
for i, (h, layer_past) in enumerate(zip(self.h, past)):
for i, (h, layer_past) in enumerate(zip(self.h, past_key_values)):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)
outputs = h(
@@ -492,7 +504,7 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
if past:
input_ids = input_ids[:, -1].unsqueeze(-1)
return {"input_ids": input_ids, "past": past, "use_cache": kwargs["use_cache"]}
return {"input_ids": input_ids, "past_key_values": past, "use_cache": kwargs["use_cache"]}
@add_start_docstrings_to_callable(CTRL_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
@@ -504,7 +516,7 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
def forward(
self,
input_ids=None,
past=None,
past_key_values=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
@@ -515,6 +527,7 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
**kwargs,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
@@ -524,11 +537,18 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
All labels set to ``-100`` are ignored (masked), the loss is only
computed for labels in ``[0, ..., config.vocab_size]``
"""
if "past" in kwargs:
warnings.warn(
"The `past` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
FutureWarning,
)
past_key_values = kwargs.pop("past")
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
transformer_outputs = self.transformer(
input_ids,
past=past,
past_key_values=past_key_values,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,