diff --git a/docs/source/model_doc/bart.rst b/docs/source/model_doc/bart.rst index 21416045f0..81d138232d 100644 --- a/docs/source/model_doc/bart.rst +++ b/docs/source/model_doc/bart.rst @@ -53,7 +53,7 @@ MBartTokenizer ~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: transformers.MBartTokenizer - :members: build_inputs_with_special_tokens, prepare_translation_batch + :members: build_inputs_with_special_tokens, prepare_seq2seq_batch diff --git a/docs/source/model_doc/marian.rst b/docs/source/model_doc/marian.rst index c4e64d61ee..8052d14372 100644 --- a/docs/source/model_doc/marian.rst +++ b/docs/source/model_doc/marian.rst @@ -48,7 +48,7 @@ Example of translating english to many romance languages, using language codes: tokenizer = MarianTokenizer.from_pretrained(model_name) print(tokenizer.supported_language_codes) model = MarianMTModel.from_pretrained(model_name) - translated = model.generate(**tokenizer.prepare_translation_batch(src_text)) + translated = model.generate(**tokenizer.prepare_seq2seq_batch(src_text)) tgt_text = [tokenizer.decode(t, skip_special_tokens=True) for t in translated] # ["c'est une phrase en anglais que nous voulons traduire en français", # 'Isto deve ir para o português.', @@ -86,6 +86,14 @@ Code to see available pretrained models: suffix = [x.split('/')[1] for x in model_ids] multi_models = [f'{org}/{s}' for s in suffix if s != s.lower()] +MarianMTModel +~~~~~~~~~~~~~ + +Pytorch version of marian-nmt's transformer.h (c++). Designed for the OPUS-NMT translation checkpoints. +Model API is identical to BartForConditionalGeneration. +Available models are listed at `Model List `__ +This class inherits nearly all functionality from ``BartForConditionalGeneration``, see that page for method signatures. + MarianConfig ~~~~~~~~~~~~~~~~~~~ .. autoclass:: transformers.MarianConfig @@ -96,16 +104,8 @@ MarianTokenizer ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: transformers.MarianTokenizer - :members: prepare_translation_batch + :members: prepare_seq2seq_batch -MarianMTModel -~~~~~~~~~~~~~ -Pytorch version of marian-nmt's transformer.h (c++). Designed for the OPUS-NMT translation checkpoints. -Model API is identical to BartForConditionalGeneration. -Available models are listed at `Model List `__ -This class inherits all functionality from ``BartForConditionalGeneration``, see that page for method signatures. -.. autoclass:: transformers.MarianMTModel - :members: diff --git a/examples/seq2seq/README.md b/examples/seq2seq/README.md index 63b5b07820..3a37cc5e5d 100644 --- a/examples/seq2seq/README.md +++ b/examples/seq2seq/README.md @@ -63,7 +63,7 @@ Summarization Tips: (It rarely makes sense to start from `bart-large` unless you are a researching finetuning methods). **Update 2018-07-18** -Datasets: `Seq2SeqDataset` should be used for all tokenizers without a `prepare_translation_batch` method. For those who do (like Marian, MBart), `TranslationDataset` should be used.** +Datasets: `Seq2SeqDataset` should be used for all tokenizers without a `prepare_seq2seq_batch` method. For those who do (like Marian, MBart), `TranslationDataset` should be used.** A new dataset is needed to support multilingual tasks. diff --git a/examples/seq2seq/utils.py b/examples/seq2seq/utils.py index 20440e3379..8c8b3005c4 100644 --- a/examples/seq2seq/utils.py +++ b/examples/seq2seq/utils.py @@ -145,7 +145,7 @@ class Seq2SeqDataset(Dataset): class TranslationDataset(Seq2SeqDataset): - """A dataset that calls prepare_translation_batch.""" + """A dataset that calls prepare_seq2seq_batch.""" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -167,7 +167,7 @@ class TranslationDataset(Seq2SeqDataset): } def collate_fn(self, batch) -> Dict[str, torch.Tensor]: - batch_encoding = self.tokenizer.prepare_translation_batch( + batch_encoding = self.tokenizer.prepare_seq2seq_batch( [x["src_texts"] for x in batch], src_lang=self.src_lang, tgt_texts=[x["tgt_texts"] for x in batch], diff --git a/src/transformers/modeling_marian.py b/src/transformers/modeling_marian.py index 0007641e03..bde0c62788 100644 --- a/src/transformers/modeling_marian.py +++ b/src/transformers/modeling_marian.py @@ -40,7 +40,7 @@ class MarianMTModel(BartForConditionalGeneration): >>> model = MarianMTModel.from_pretrained(mname) >>> tok = MarianTokenizer.from_pretrained(mname) - >>> batch = tok.prepare_translation_batch(src_texts=[sample_text]) # don't need tgt_text for inference + >>> batch = tok.prepare_seq2seq_batch(src_texts=[sample_text]) # don't need tgt_text for inference >>> gen = model.generate(**batch) # for forward pass: model(**batch) >>> words: List[str] = tok.batch_decode(gen, skip_special_tokens=True) # returns "Where is the the bus stop ?" diff --git a/src/transformers/tokenization_bart.py b/src/transformers/tokenization_bart.py index 499895e0bd..2348ce86d6 100644 --- a/src/transformers/tokenization_bart.py +++ b/src/transformers/tokenization_bart.py @@ -16,8 +16,10 @@ import logging from typing import List, Optional +from .file_utils import add_start_docstrings_to_callable from .tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast from .tokenization_utils import BatchEncoding +from .tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING from .tokenization_xlm_roberta import XLMRobertaTokenizer @@ -89,7 +91,7 @@ FAIRSEQ_LANGUAGE_CODES = [ class MBartTokenizer(XLMRobertaTokenizer): """ - This inherits from XLMRobertaTokenizer. ``prepare_translation_batch`` should be used to encode inputs. + This inherits from XLMRobertaTokenizer. ``prepare_seq2seq_batch`` should be used to encode inputs. Other tokenizer methods like ``encode`` do not work properly. The tokenization method is `` `` for source language documents, and `` ``` for target language documents. @@ -100,7 +102,7 @@ class MBartTokenizer(XLMRobertaTokenizer): >>> tokenizer = MBartTokenizer.from_pretrained('facebook/mbart-large-en-ro') >>> example_english_phrase = " UN Chief Says There Is No Military Solution in Syria" >>> expected_translation_romanian = "Şeful ONU declară că nu există o soluţie militară în Siria" - >>> batch: dict = tokenizer.prepare_translation_batch( + >>> batch: dict = tokenizer.prepare_seq2seq_batch( ... example_english_phrase, src_lang="en_XX", tgt_lang="ro_RO", tgt_texts=expected_translation_romanian ... ) @@ -187,7 +189,8 @@ class MBartTokenizer(XLMRobertaTokenizer): return prefix_ones + ([0] * len(token_ids_0)) + suffix_ones return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones - def prepare_translation_batch( + @add_start_docstrings_to_callable(PREPARE_SEQ2SEQ_BATCH_DOCSTRING) + def prepare_seq2seq_batch( self, src_texts: List[str], src_lang: str = "en_XX", @@ -195,22 +198,73 @@ class MBartTokenizer(XLMRobertaTokenizer): tgt_lang: str = "ro_RO", max_length: Optional[int] = None, max_target_length: Optional[int] = None, + truncation: bool = True, padding: str = "longest", return_tensors: str = "pt", **kwargs, ) -> BatchEncoding: """Prepare a batch that can be passed directly to an instance of MBartModel. - Arguments: - src_texts: list of src language texts - src_lang: default en_XX (english), the language we are translating from - tgt_texts: list of tgt language texts - tgt_lang: default ro_RO (romanian), the language we are translating to - max_length: (default=None, which defers to the config value of 1024 for facebook/mbart-large* - padding: strategy for padding input_ids and decoder_input_ids. Should be max_length or longest. - **kwargs: passed to self.__call__ - Returns: - :obj:`BatchEncoding`: with keys input_ids, attention_mask, decoder_input_ids, decoder_attention_mask. + Arguments: + src_texts: (:obj:`list`): + list of documents to summarize or source language texts + src_lang: (:obj:`str`, `optional`, default='en_XX'): + default en_XX (english), the language we are translating from + tgt_texts: (:obj:`list`, `optional`): + list of tgt language texts or summaries. + tgt_lang: (:obj:`str`, `optional`, default='ro_RO'): + default ro_RO (romanian), the language we are translating to + max_length (:obj:`int`, `optional`): + Controls the maximum length for encoder inputs (documents to summarize or source language texts) + If left unset or set to :obj:`None`, this will use the predefined model maximum length if a maximum + length is required by one of the truncation/padding parameters. If the model has no specific maximum + input length (like XLNet) truncation/padding to a maximum length will be deactivated. + max_target_length (:obj:`int`, `optional`): + Controls the maximum length of decoder inputs (target language texts or summaries) + If left unset or set to :obj:`None`, this will use the max_length value. + padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`False`): + Activates and controls padding. Accepts the following values: + + * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a + single sequence if provided). + * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the + maximum acceptable input length for the model if that argument is not provided. + * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of + different lengths). + return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`, defaults to "pt"): + If set, will return tensors instead of list of python integers. Acceptable values are: + + * :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects. + * :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects. + * :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects. + truncation (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.TruncationStrategy`, `optional`, defaults to :obj:`True`): + Activates and controls truncation. Accepts the following values: + + * :obj:`True` or :obj:`'longest_first'`: Truncate to a maximum length specified with the argument + :obj:`max_length` or to the maximum acceptable input length for the model if that argument is not + provided. This will truncate token by token, removing a token from the longest sequence in the pair + if a pair of sequences (or a batch of pairs) is provided. + * :obj:`'only_first'`: Truncate to a maximum length specified with the argument :obj:`max_length` or to + the maximum acceptable input length for the model if that argument is not provided. This will only + truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided. + * :obj:`'only_second'`: Truncate to a maximum length specified with the argument :obj:`max_length` or + to the maximum acceptable input length for the model if that argument is not provided. This will only + truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided. + * :obj:`False` or :obj:`'do_not_truncate'` (default): No truncation (i.e., can output batch with + sequence lengths greater than the model maximum admissible input size). + + Return: + :class:`~transformers.BatchEncoding`: A :class:`~transformers.BatchEncoding` with the following fields: + + - **input_ids** -- List of token ids to be fed to the encoder. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model. + - **decoder_input_ids** -- List of token ids to be fed to the decoder. + - **decoder_attention_mask** -- List of indices specifying which tokens should be attended to by the decoder. + This does not include causal mask, which is built by the model. + + The full set of keys ``[input_ids, attention_mask, decoder_input_ids, decoder_attention_mask]``, + will only be returned if tgt_texts is passed. Otherwise, input_ids, attention_mask will be the only keys. + """ if max_length is None: max_length = self.max_len @@ -221,7 +275,7 @@ class MBartTokenizer(XLMRobertaTokenizer): return_tensors=return_tensors, max_length=max_length, padding=padding, - truncation=True, + truncation=truncation, **kwargs, ) if tgt_texts is None: diff --git a/src/transformers/tokenization_marian.py b/src/transformers/tokenization_marian.py index 211dfda8a2..7584531c72 100644 --- a/src/transformers/tokenization_marian.py +++ b/src/transformers/tokenization_marian.py @@ -7,7 +7,9 @@ from typing import Dict, List, Optional, Tuple, Union import sentencepiece +from .file_utils import add_start_docstrings_to_callable from .tokenization_utils import BatchEncoding, PreTrainedTokenizer +from .tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING vocab_files_names = { @@ -21,7 +23,8 @@ vocab_files_names = { class MarianTokenizer(PreTrainedTokenizer): """Sentencepiece tokenizer for marian. Source and target languages have different SPM models. - The logic is use the relevant source_spm or target_spm to encode txt as pieces, then look up each piece in a vocab dictionary. + The logic is use the relevant source_spm or target_spm to encode txt as pieces, then look up each piece in a + vocab dictionary. Examples:: @@ -29,7 +32,7 @@ class MarianTokenizer(PreTrainedTokenizer): >>> tok = MarianTokenizer.from_pretrained('Helsinki-NLP/opus-mt-en-de') >>> src_texts = [ "I am a small frog.", "Tom asked his teacher for advice."] >>> tgt_texts = ["Ich bin ein kleiner Frosch.", "Tom bat seinen Lehrer um Rat."] # optional - >>> batch_enc: BatchEncoding = tok.prepare_translation_batch(src_texts, tgt_texts=tgt_texts) + >>> batch_enc: BatchEncoding = tok.prepare_seq2seq_batch(src_texts, tgt_texts=tgt_texts) >>> # keys [input_ids, attention_mask, decoder_input_ids, decoder_attention_mask]. >>> # model(**batch) should work """ @@ -122,30 +125,20 @@ class MarianTokenizer(PreTrainedTokenizer): # We don't expect to process pairs, but leave the pair logic for API consistency return token_ids_0 + token_ids_1 + [self.eos_token_id] - def prepare_translation_batch( + @add_start_docstrings_to_callable(PREPARE_SEQ2SEQ_BATCH_DOCSTRING) + def prepare_seq2seq_batch( self, src_texts: List[str], tgt_texts: Optional[List[str]] = None, max_length: Optional[int] = None, max_target_length: Optional[int] = None, - pad_to_max_length: bool = True, return_tensors: str = "pt", - truncation_strategy="only_first", + truncation=True, padding="longest", **unused, ) -> BatchEncoding: """Prepare model inputs for translation. For best performance, translate one sentence at a time. - Arguments: - src_texts: list of src language texts - tgt_texts: list of tgt language texts - max_length: (None) defer to config (1024 for mbart-large-en-ro) - pad_to_max_length: (bool) - return_tensors: (str) default "pt" returns pytorch tensors, pass None to return lists. - Returns: - BatchEncoding: with keys [input_ids, attention_mask, decoder_input_ids, decoder_attention_mask] - all shaped bs, seq_len. (BatchEncoding is a dict of string -> tensor or lists). - If no tgt_text is specified, the only keys will be input_ids and attention_mask. """ if "" in src_texts: raise ValueError(f"found empty string in src_texts: {src_texts}") @@ -155,14 +148,15 @@ class MarianTokenizer(PreTrainedTokenizer): add_special_tokens=True, return_tensors=return_tensors, max_length=max_length, - pad_to_max_length=pad_to_max_length, - truncation_strategy=truncation_strategy, + truncation=truncation, padding=padding, ) model_inputs: BatchEncoding = self(src_texts, **tokenizer_kwargs) if tgt_texts is None: return model_inputs + if max_target_length is not None: + tokenizer_kwargs["max_length"] = max_target_length if max_target_length is not None: tokenizer_kwargs["max_length"] = max_target_length diff --git a/src/transformers/tokenization_pegasus.py b/src/transformers/tokenization_pegasus.py index 94e5810791..e553ad456d 100644 --- a/src/transformers/tokenization_pegasus.py +++ b/src/transformers/tokenization_pegasus.py @@ -16,7 +16,8 @@ from typing import Dict, List, Optional from transformers.tokenization_reformer import ReformerTokenizer -from .tokenization_utils_base import BatchEncoding +from .file_utils import add_start_docstrings_to_callable +from .tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING, BatchEncoding class PegasusTokenizer(ReformerTokenizer): @@ -103,6 +104,7 @@ class PegasusTokenizer(ReformerTokenizer): # We don't expect to process pairs, but leave the pair logic for API consistency return token_ids_0 + token_ids_1 + [self.eos_token_id] + @add_start_docstrings_to_callable(PREPARE_SEQ2SEQ_BATCH_DOCSTRING) def prepare_seq2seq_batch( self, src_texts: List[str], @@ -116,62 +118,6 @@ class PegasusTokenizer(ReformerTokenizer): """ Prepare model inputs for summarization or translation. - Arguments: - src_texts: (:obj:`list`): - list of documents to summarize or source language texts - tgt_texts: (:obj:`list`, `optional`): - list of tgt language texts or summaries. - max_length (:obj:`int`, `optional`): - Controls the maximum length for encoder inputs (documents to summarize or source language texts) - If left unset or set to :obj:`None`, this will use the predefined model maximum length if a maximum - length is required by one of the truncation/padding parameters. If the model has no specific maximum - input length (like XLNet) truncation/padding to a maximum length will be deactivated. - max_target_length (:obj:`int`, `optional`): - Controls the maximum length of decoder inputs (target language texts or summaries) - If left unset or set to :obj:`None`, this will use the max_length value. - padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`False`): - Activates and controls padding. Accepts the following values: - - * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a - single sequence if provided). - * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the - maximum acceptable input length for the model if that argument is not provided. - * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of - different lengths). - return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`, defaults to "pt"): - If set, will return tensors instead of list of python integers. Acceptable values are: - - * :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects. - * :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects. - * :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects. - truncation (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.TruncationStrategy`, `optional`, defaults to :obj:`True`): - Activates and controls truncation. Accepts the following values: - - * :obj:`True` or :obj:`'longest_first'`: Truncate to a maximum length specified with the argument - :obj:`max_length` or to the maximum acceptable input length for the model if that argument is not - provided. This will truncate token by token, removing a token from the longest sequence in the pair - if a pair of sequences (or a batch of pairs) is provided. - * :obj:`'only_first'`: Truncate to a maximum length specified with the argument :obj:`max_length` or to - the maximum acceptable input length for the model if that argument is not provided. This will only - truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided. - * :obj:`'only_second'`: Truncate to a maximum length specified with the argument :obj:`max_length` or - to the maximum acceptable input length for the model if that argument is not provided. This will only - truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided. - * :obj:`False` or :obj:`'do_not_truncate'` (default): No truncation (i.e., can output batch with - sequence lengths greater than the model maximum admissible input size). - - Return: - :class:`~transformers.BatchEncoding`: A :class:`~transformers.BatchEncoding` with the following fields: - - - **input_ids** -- List of token ids to be fed to the encoder. - - **attention_mask** -- List of indices specifying which tokens should be attended to by the model. - - **decoder_input_ids** -- List of token ids to be fed to the decoder. - - **decoder_attention_mask** -- List of indices specifying which tokens should be attended to by the decoder. - This does not include causal mask, which is built by the model. - - The full set of keys ``[input_ids, attention_mask, decoder_input_ids, decoder_attention_mask]``, - will only be returned if tgt_texts is passed. Otherwise, input_ids, attention_mask will be the only keys. - """ if "" in src_texts: raise ValueError(f"found empty string in src_texts: {src_texts}") diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index cf12f2d720..b2951904ef 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -1249,6 +1249,67 @@ INIT_TOKENIZER_DOCSTRING = r""" """ +PREPARE_SEQ2SEQ_BATCH_DOCSTRING = """ + + Arguments: + src_texts: (:obj:`list`): + list of documents to summarize or source language texts + tgt_texts: (:obj:`list`, `optional`): + list of tgt language texts or summaries. + max_length (:obj:`int`, `optional`): + Controls the maximum length for encoder inputs (documents to summarize or source language texts) + If left unset or set to :obj:`None`, this will use the predefined model maximum length if a maximum + length is required by one of the truncation/padding parameters. If the model has no specific maximum + input length (like XLNet) truncation/padding to a maximum length will be deactivated. + max_target_length (:obj:`int`, `optional`): + Controls the maximum length of decoder inputs (target language texts or summaries) + If left unset or set to :obj:`None`, this will use the max_length value. + padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`False`): + Activates and controls padding. Accepts the following values: + + * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a + single sequence if provided). + * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the + maximum acceptable input length for the model if that argument is not provided. + * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of + different lengths). + return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`, defaults to "pt"): + If set, will return tensors instead of list of python integers. Acceptable values are: + + * :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects. + * :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects. + * :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects. + truncation (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.TruncationStrategy`, `optional`, defaults to :obj:`True`): + Activates and controls truncation. Accepts the following values: + + * :obj:`True` or :obj:`'longest_first'`: Truncate to a maximum length specified with the argument + :obj:`max_length` or to the maximum acceptable input length for the model if that argument is not + provided. This will truncate token by token, removing a token from the longest sequence in the pair + if a pair of sequences (or a batch of pairs) is provided. + * :obj:`'only_first'`: Truncate to a maximum length specified with the argument :obj:`max_length` or to + the maximum acceptable input length for the model if that argument is not provided. This will only + truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided. + * :obj:`'only_second'`: Truncate to a maximum length specified with the argument :obj:`max_length` or + to the maximum acceptable input length for the model if that argument is not provided. This will only + truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided. + * :obj:`False` or :obj:`'do_not_truncate'` (default): No truncation (i.e., can output batch with + sequence lengths greater than the model maximum admissible input size). + + Return: + :class:`~transformers.BatchEncoding`: A :class:`~transformers.BatchEncoding` with the following fields: + + - **input_ids** -- List of token ids to be fed to the encoder. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model. + - **decoder_input_ids** -- List of token ids to be fed to the decoder. + - **decoder_attention_mask** -- List of indices specifying which tokens should be attended to by the decoder. + This does not include causal mask, which is built by the model. + + The full set of keys ``[input_ids, attention_mask, decoder_input_ids, decoder_attention_mask]``, + will only be returned if tgt_texts is passed. Otherwise, input_ids, attention_mask will be the only keys. + +""" + + @add_end_docstrings(INIT_TOKENIZER_DOCSTRING) class PreTrainedTokenizerBase(SpecialTokensMixin): """ diff --git a/tests/test_modeling_marian.py b/tests/test_modeling_marian.py index 4c936af503..4b49a8c470 100644 --- a/tests/test_modeling_marian.py +++ b/tests/test_modeling_marian.py @@ -97,7 +97,7 @@ class MarianIntegrationTest(unittest.TestCase): self.assertListEqual(self.expected_text, generated_words) def translate_src_text(self, **tokenizer_kwargs): - model_inputs = self.tokenizer.prepare_translation_batch(src_texts=self.src_text, **tokenizer_kwargs).to( + model_inputs = self.tokenizer.prepare_seq2seq_batch(src_texts=self.src_text, **tokenizer_kwargs).to( torch_device ) self.assertEqual(self.model.device, model_inputs.input_ids.device) @@ -114,7 +114,7 @@ class TestMarian_EN_DE_More(MarianIntegrationTest): src, tgt = ["I am a small frog"], ["Ich bin ein kleiner Frosch."] expected_ids = [38, 121, 14, 697, 38848, 0] - model_inputs: dict = self.tokenizer.prepare_translation_batch(src, tgt_texts=tgt).to(torch_device) + model_inputs: dict = self.tokenizer.prepare_seq2seq_batch(src, tgt_texts=tgt).to(torch_device) self.assertListEqual(expected_ids, model_inputs.input_ids[0].tolist()) desired_keys = { @@ -131,12 +131,12 @@ class TestMarian_EN_DE_More(MarianIntegrationTest): def test_unk_support(self): t = self.tokenizer - ids = t.prepare_translation_batch(["||"]).to(torch_device).input_ids[0].tolist() + ids = t.prepare_seq2seq_batch(["||"]).to(torch_device).input_ids[0].tolist() expected = [t.unk_token_id, t.unk_token_id, t.eos_token_id] self.assertEqual(expected, ids) def test_pad_not_split(self): - input_ids_w_pad = self.tokenizer.prepare_translation_batch(["I am a small frog "]).input_ids[0].tolist() + input_ids_w_pad = self.tokenizer.prepare_seq2seq_batch(["I am a small frog "]).input_ids[0].tolist() expected_w_pad = [38, 121, 14, 697, 38848, self.tokenizer.pad_token_id, 0] # pad self.assertListEqual(expected_w_pad, input_ids_w_pad) @@ -229,7 +229,7 @@ class TestMarian_en_ROMANCE(MarianIntegrationTest): normalized = self.tokenizer.normalize("") self.assertIsInstance(normalized, str) with self.assertRaises(ValueError): - self.tokenizer.prepare_translation_batch([""]) + self.tokenizer.prepare_seq2seq_batch([""]) def test_pipeline(self): device = 0 if torch_device == "cuda" else -1 diff --git a/tests/test_modeling_mbart.py b/tests/test_modeling_mbart.py index 602514b158..bf7bb75c0e 100644 --- a/tests/test_modeling_mbart.py +++ b/tests/test_modeling_mbart.py @@ -82,7 +82,7 @@ class MBartEnroIntegrationTest(AbstractSeq2SeqIntegrationTest): @slow def test_enro_generate(self): - batch: BatchEncoding = self.tokenizer.prepare_translation_batch(self.src_text).to(torch_device) + batch: BatchEncoding = self.tokenizer.prepare_seq2seq_batch(self.src_text).to(torch_device) translated_tokens = self.model.generate(**batch) decoded = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True) self.assertEqual(self.tgt_text[0], decoded[0]) @@ -134,7 +134,7 @@ class MBartCC25IntegrationTest(AbstractSeq2SeqIntegrationTest): @unittest.skip("This test is broken, still generates english") def test_cc25_generate(self): - inputs = self.tokenizer.prepare_translation_batch([self.src_text[0]]).to(torch_device) + inputs = self.tokenizer.prepare_seq2seq_batch([self.src_text[0]]).to(torch_device) translated_tokens = self.model.generate( input_ids=inputs["input_ids"].to(torch_device), decoder_start_token_id=self.tokenizer.lang_code_to_id["ro_RO"], @@ -144,7 +144,7 @@ class MBartCC25IntegrationTest(AbstractSeq2SeqIntegrationTest): @slow def test_fill_mask(self): - inputs = self.tokenizer.prepare_translation_batch(["One of the best I ever read!"]).to(torch_device) + inputs = self.tokenizer.prepare_seq2seq_batch(["One of the best I ever read!"]).to(torch_device) outputs = self.model.generate( inputs["input_ids"], decoder_start_token_id=self.tokenizer.lang_code_to_id["en_XX"], num_beams=1 ) diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py index 4b841f850e..3b485fc72d 100644 --- a/tests/test_tokenization_common.py +++ b/tests/test_tokenization_common.py @@ -1522,3 +1522,37 @@ class TokenizerTesterMixin: if batch_encoded_sequence_fast is None: raise ValueError("Cannot convert list to numpy tensor on batch_encode_plus() (fast)") + + @require_torch + def test_prepare_seq2seq_batch(self): + tokenizer = self.get_tokenizer() + + if not hasattr(tokenizer, "prepare_seq2seq_batch"): + return + # Longer text that will definitely require truncation. + src_text = [ + " UN Chief Says There Is No Military Solution in Syria", + " Secretary-General Ban Ki-moon says his response to Russia's stepped up military support for Syria is that 'there is no military solution' to the nearly five-year conflict and more weapons will only worsen the violence and misery for millions of people.", + ] + tgt_text = [ + "Şeful ONU declară că nu există o soluţie militară în Siria", + "Secretarul General Ban Ki-moon declară că răspunsul său la intensificarea sprijinului militar al Rusiei " + 'pentru Siria este că "nu există o soluţie militară" la conflictul de aproape cinci ani şi că noi arme nu ' + "vor face decât să înrăutăţească violenţele şi mizeria pentru milioane de oameni.", + ] + batch = tokenizer.prepare_seq2seq_batch( + src_texts=src_text, tgt_texts=tgt_text, max_length=3, max_target_length=10, return_tensors="pt" + ) + self.assertEqual(batch.input_ids.shape[1], 3) + self.assertEqual(batch.decoder_input_ids.shape[1], 10) + # max_target_length will default to max_length if not specified + batch = tokenizer.prepare_seq2seq_batch(src_text, tgt_texts=tgt_text, max_length=3) + self.assertEqual(batch.input_ids.shape[1], 3) + self.assertEqual(batch.decoder_input_ids.shape[1], 3) + + batch_encoder_only = tokenizer.prepare_seq2seq_batch( + src_texts=src_text, max_length=3, max_target_length=10, return_tensors="pt" + ) + self.assertEqual(batch_encoder_only.input_ids.shape[1], 3) + self.assertEqual(batch_encoder_only.attention_mask.shape[1], 3) + self.assertNotIn("decoder_input_ids", batch_encoder_only) diff --git a/tests/test_tokenization_marian.py b/tests/test_tokenization_marian.py index 693314bfc4..4948dffb18 100644 --- a/tests/test_tokenization_marian.py +++ b/tests/test_tokenization_marian.py @@ -64,7 +64,7 @@ class MarianTokenizationTest(TokenizerTesterMixin, unittest.TestCase): def test_tokenizer_equivalence_en_de(self): en_de_tokenizer = MarianTokenizer.from_pretrained(f"{ORG_NAME}opus-mt-en-de") - batch = en_de_tokenizer.prepare_translation_batch(["I am a small frog"], return_tensors=None) + batch = en_de_tokenizer.prepare_seq2seq_batch(["I am a small frog"], return_tensors=None) self.assertIsInstance(batch, BatchEncoding) expected = [38, 121, 14, 697, 38848, 0] self.assertListEqual(expected, batch.input_ids[0]) @@ -78,16 +78,12 @@ class MarianTokenizationTest(TokenizerTesterMixin, unittest.TestCase): def test_outputs_not_longer_than_maxlen(self): tok = self.get_tokenizer() - batch = tok.prepare_translation_batch( - ["I am a small frog" * 1000, "I am a small frog"], return_tensors=FRAMEWORK - ) + batch = tok.prepare_seq2seq_batch(["I am a small frog" * 1000, "I am a small frog"], return_tensors=FRAMEWORK) self.assertIsInstance(batch, BatchEncoding) self.assertEqual(batch.input_ids.shape, (2, 512)) def test_outputs_can_be_shorter(self): tok = self.get_tokenizer() - batch_smaller = tok.prepare_translation_batch( - ["I am a tiny frog", "I am a small frog"], return_tensors=FRAMEWORK - ) + batch_smaller = tok.prepare_seq2seq_batch(["I am a tiny frog", "I am a small frog"], return_tensors=FRAMEWORK) self.assertIsInstance(batch_smaller, BatchEncoding) self.assertEqual(batch_smaller.input_ids.shape, (2, 10)) diff --git a/tests/test_tokenization_mbart.py b/tests/test_tokenization_mbart.py index 74bfd5b5bf..d8b1ae18f4 100644 --- a/tests/test_tokenization_mbart.py +++ b/tests/test_tokenization_mbart.py @@ -123,8 +123,8 @@ class MBartEnroIntegrationTest(unittest.TestCase): self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["en_EN"], 250004) self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["ro_RO"], 250020) - def test_enro_tokenizer_prepare_translation_batch(self): - batch = self.tokenizer.prepare_translation_batch( + def test_enro_tokenizer_prepare_seq2seq_batch(self): + batch = self.tokenizer.prepare_seq2seq_batch( self.src_text, tgt_texts=self.tgt_text, max_length=len(self.expected_src_tokens), ) self.assertIsInstance(batch, BatchEncoding) @@ -140,13 +140,13 @@ class MBartEnroIntegrationTest(unittest.TestCase): def test_max_target_length(self): - batch = self.tokenizer.prepare_translation_batch( + batch = self.tokenizer.prepare_seq2seq_batch( self.src_text, tgt_texts=self.tgt_text, max_length=3, max_target_length=10 ) self.assertEqual(batch.input_ids.shape[1], 3) self.assertEqual(batch.decoder_input_ids.shape[1], 10) # max_target_length will default to max_length if not specified - batch = self.tokenizer.prepare_translation_batch(self.src_text, tgt_texts=self.tgt_text, max_length=3) + batch = self.tokenizer.prepare_seq2seq_batch(self.src_text, tgt_texts=self.tgt_text, max_length=3) self.assertEqual(batch.input_ids.shape[1], 3) self.assertEqual(batch.decoder_input_ids.shape[1], 3) @@ -166,7 +166,7 @@ class MBartEnroIntegrationTest(unittest.TestCase): src_text = ["this is gunna be a long sentence " * 20] assert isinstance(src_text[0], str) desired_max_length = 10 - ids = self.tokenizer.prepare_translation_batch( + ids = self.tokenizer.prepare_seq2seq_batch( src_text, return_tensors=None, max_length=desired_max_length ).input_ids[0] self.assertEqual(ids[-2], 2)