From 9e89390ce1e785e72452207139a334cd3bf745ff Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Mon, 14 Sep 2020 20:33:08 -0400 Subject: [PATCH] [QOL] add signature for prepare_seq2seq_batch (#7108) --- src/transformers/tokenization_bart.py | 4 +- src/transformers/tokenization_marian.py | 4 +- src/transformers/tokenization_mbart.py | 6 +- src/transformers/tokenization_t5.py | 7 +-- src/transformers/tokenization_utils.py | 77 +++++++++++++++++++++++++ tests/test_tokenization_common.py | 19 +++--- 6 files changed, 96 insertions(+), 21 deletions(-) diff --git a/src/transformers/tokenization_bart.py b/src/transformers/tokenization_bart.py index 030917c3c3..22a836b025 100644 --- a/src/transformers/tokenization_bart.py +++ b/src/transformers/tokenization_bart.py @@ -111,9 +111,7 @@ class BartTokenizer(RobertaTokenizer): - **input_ids** -- List of token ids to be fed to the encoder. - **attention_mask** -- List of indices specifying which tokens should be attended to by the model. - - **decoder_input_ids** -- List of token ids to be fed to the decoder. - - **decoder_attention_mask** -- List of indices specifying which tokens should be attended to by the decoder. - This does not include causal mask, which is built by the model. + - **labels** -- List of token ids for tgt_texts The full set of keys ``[input_ids, attention_mask, decoder_input_ids, decoder_attention_mask]``, will only be returned if tgt_texts is passed. Otherwise, input_ids, attention_mask will be the only keys. diff --git a/src/transformers/tokenization_marian.py b/src/transformers/tokenization_marian.py index 3f06092b53..e1ce86ca02 100644 --- a/src/transformers/tokenization_marian.py +++ b/src/transformers/tokenization_marian.py @@ -33,12 +33,12 @@ class MarianTokenizer(PreTrainedTokenizer): >>> src_texts = [ "I am a small frog.", "Tom asked his teacher for advice."] >>> tgt_texts = ["Ich bin ein kleiner Frosch.", "Tom bat seinen Lehrer um Rat."] # optional >>> batch_enc: BatchEncoding = tok.prepare_seq2seq_batch(src_texts, tgt_texts=tgt_texts) - >>> # keys [input_ids, attention_mask, decoder_input_ids, decoder_attention_mask]. + >>> # keys [input_ids, attention_mask, labels]. >>> # model(**batch) should work """ vocab_files_names = vocab_files_names - model_input_names = ["attention_mask"] # actually attention_mask, decoder_attention_mask + model_input_names = ["attention_mask"] language_code_re = re.compile(">>.+<<") # type: re.Pattern def __init__( diff --git a/src/transformers/tokenization_mbart.py b/src/transformers/tokenization_mbart.py index dc78065931..a5c72576aa 100644 --- a/src/transformers/tokenization_mbart.py +++ b/src/transformers/tokenization_mbart.py @@ -225,11 +225,9 @@ class MBartTokenizer(XLMRobertaTokenizer): - **input_ids** -- List of token ids to be fed to the encoder. - **attention_mask** -- List of indices specifying which tokens should be attended to by the model. - - **decoder_input_ids** -- List of token ids to be fed to the decoder. - - **decoder_attention_mask** -- List of indices specifying which tokens should be attended to by the decoder. - This does not include causal mask, which is built by the model. + - **labels** -- List of token ids for tgt_texts - The full set of keys ``[input_ids, attention_mask, decoder_input_ids, decoder_attention_mask]``, + The full set of keys ``[input_ids, attention_mask, decoder_input_ids, labels]``, will only be returned if tgt_texts is passed. Otherwise, input_ids, attention_mask will be the only keys. """ diff --git a/src/transformers/tokenization_t5.py b/src/transformers/tokenization_t5.py index 6a41da017e..a5569b6739 100644 --- a/src/transformers/tokenization_t5.py +++ b/src/transformers/tokenization_t5.py @@ -333,10 +333,9 @@ class T5Tokenizer(PreTrainedTokenizer): :class:`~transformers.BatchEncoding`: A :class:`~transformers.BatchEncoding` with the following fields: - **input_ids** -- List of token ids to be fed to the encoder. - **attention_mask** -- List of indices specifying which tokens should be attended to by the model. - - **decoder_input_ids** -- List of token ids to be fed to the decoder. - - **decoder_attention_mask** -- List of indices specifying which tokens should be attended to by the decoder. - This does not include causal mask, which is built by the model. - The full set of keys ``[input_ids, attention_mask, decoder_input_ids, decoder_attention_mask]``, + - **labels** -- List of token ids for tgt_texts + + The full set of keys ``[input_ids, attention_mask, decoder_input_ids, labels]``, will only be returned if tgt_texts is passed. Otherwise, input_ids, attention_mask will be the only keys. """ if max_length is None: diff --git a/src/transformers/tokenization_utils.py b/src/transformers/tokenization_utils.py index 9eaf5bfc3d..b19fd90e32 100644 --- a/src/transformers/tokenization_utils.py +++ b/src/transformers/tokenization_utils.py @@ -777,3 +777,80 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase): A tuple of :obj:`str`: The files saved. """ raise NotImplementedError + + def prepare_seq2seq_batch( + self, + src_texts: List[str], + tgt_texts: Optional[List[str]] = None, + max_length: Optional[int] = None, + max_target_length: Optional[int] = None, + padding: str = "longest", + return_tensors: str = "None", + truncation=True, + **kwargs, + ) -> BatchEncoding: + r""" + + Prepare a batch that can be passed directly to an instance of :class:`~transformers.AutoModelForSeq2SeqLM`. + + Args: + src_texts: (:obj:`List[str]`): + List of documents to summarize or source language texts. + tgt_texts: (:obj:`List[str]`, `optional`): + List of summaries or target language texts. + max_length (:obj:`int`, `optional`): + Controls the maximum length for encoder inputs (documents to summarize or source language texts). + If left unset or set to :obj:`None`, this will use the predefined model maximum length if a maximum + length is required by one of the truncation/padding parameters. If the model has no specific maximum + input length (like XLNet) truncation/padding to a maximum length will be deactivated. + max_target_length (:obj:`int`, `optional`): + Controls the maximum length of decoder inputs (target language texts or summaries). + If left unset or set to :obj:`None`, this will use the max_length value. + padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`False`): + Activates and controls padding. Accepts the following values: + + * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a + single sequence if provided). + * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the + maximum acceptable input length for the model if that argument is not provided. + * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of + different lengths). + return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`, defaults to "pt"): + If set, will return tensors instead of list of python integers. Acceptable values are: + + * :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects. + * :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects. + * :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects. + truncation (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.TruncationStrategy`, `optional`, defaults to :obj:`True`): + Activates and controls truncation. Accepts the following values: + + * :obj:`True` or :obj:`'longest_first'`: Truncate to a maximum length specified with the argument + :obj:`max_length` or to the maximum acceptable input length for the model if that argument is not + provided. This will truncate token by token, removing a token from the longest sequence in the pair + if a pair of sequences (or a batch of pairs) is provided. + * :obj:`'only_first'`: Truncate to a maximum length specified with the argument :obj:`max_length` or to + the maximum acceptable input length for the model if that argument is not provided. This will only + truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided. + * :obj:`'only_second'`: Truncate to a maximum length specified with the argument :obj:`max_length` or + to the maximum acceptable input length for the model if that argument is not provided. This will only + truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided. + * :obj:`False` or :obj:`'do_not_truncate'` (default): No truncation (i.e., can output batch with + sequence lengths greater than the model maximum admissible input size). + **kwargs: + Additional keyword arguments passed along to :obj:`self.__call__`. + + Returns: + :class:`~transformers.BatchEncoding`: A :class:`~transformers.BatchEncoding` with the following fields: + + - **input_ids** -- List of token ids to be fed to the encoder. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model. + - **labels** -- List of token ids for tgt_texts + + The full set of keys ``[input_ids, attention_mask, labels]``, + will only be returned if tgt_texts is passed. Otherwise, input_ids, attention_mask will be the only keys. + """ + raise NotImplementedError( + "If your model requires more than input_ids for a typical forward pass, you should implement this method. " + "Returned keys should be [input_ids, attention_mask, labels]. See MarianTokenizer or T5Tokenizer for a " + "reference implementation." + ) diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py index f502428ce8..2972a80212 100644 --- a/tests/test_tokenization_common.py +++ b/tests/test_tokenization_common.py @@ -1566,14 +1566,17 @@ class TokenizerTesterMixin: 'pentru Siria este că "nu există o soluţie militară" la conflictul de aproape cinci ani şi că noi arme nu ' "vor face decât să înrăutăţească violenţele şi mizeria pentru milioane de oameni.", ] - batch = tokenizer.prepare_seq2seq_batch( - src_texts=src_text, - tgt_texts=tgt_text, - max_length=3, - max_target_length=10, - return_tensors="pt", - src_lang="en_XX", # this should be ignored (for all but mbart) but not cause an error - ) + try: + batch = tokenizer.prepare_seq2seq_batch( + src_texts=src_text, + tgt_texts=tgt_text, + max_length=3, + max_target_length=10, + return_tensors="pt", + src_lang="en_XX", # this should be ignored (for all but mbart) but not cause an error + ) + except NotImplementedError: + return self.assertEqual(batch.input_ids.shape[1], 3) self.assertEqual(batch.labels.shape[1], 10) # max_target_length will default to max_length if not specified