Tokenizers should be framework agnostic (#8599)
* Tokenizers should be framework agnostic * Run the slow tests * Not testing * Fix documentation * Apply suggestions from code review Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
8
.github/workflows/self-push.yml
vendored
8
.github/workflows/self-push.yml
vendored
@@ -16,7 +16,7 @@ on:
|
||||
|
||||
jobs:
|
||||
run_tests_torch_gpu:
|
||||
runs-on: [self-hosted, single-gpu]
|
||||
runs-on: [self-hosted, gpu, single-gpu]
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Python version
|
||||
@@ -86,7 +86,7 @@ jobs:
|
||||
|
||||
|
||||
run_tests_tf_gpu:
|
||||
runs-on: [self-hosted, single-gpu]
|
||||
runs-on: [self-hosted, gpu, single-gpu]
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Python version
|
||||
@@ -154,7 +154,7 @@ jobs:
|
||||
path: reports
|
||||
|
||||
run_tests_torch_multi_gpu:
|
||||
runs-on: [self-hosted, multi-gpu]
|
||||
runs-on: [self-hosted, gpu, multi-gpu]
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Python version
|
||||
@@ -213,7 +213,7 @@ jobs:
|
||||
path: reports
|
||||
|
||||
run_tests_tf_multi_gpu:
|
||||
runs-on: [self-hosted, multi-gpu]
|
||||
runs-on: [self-hosted, gpu, multi-gpu]
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Python version
|
||||
|
||||
9
.github/workflows/self-scheduled.yml
vendored
9
.github/workflows/self-scheduled.yml
vendored
@@ -9,13 +9,14 @@ on:
|
||||
push:
|
||||
branches:
|
||||
- ci_*
|
||||
- framework-agnostic-tokenizers
|
||||
repository_dispatch:
|
||||
schedule:
|
||||
- cron: "0 0 * * *"
|
||||
|
||||
jobs:
|
||||
run_all_tests_torch_gpu:
|
||||
runs-on: [self-hosted, single-gpu]
|
||||
runs-on: [self-hosted, gpu, single-gpu]
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
|
||||
@@ -109,7 +110,7 @@ jobs:
|
||||
|
||||
|
||||
run_all_tests_tf_gpu:
|
||||
runs-on: [self-hosted, single-gpu]
|
||||
runs-on: [self-hosted, gpu, single-gpu]
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
|
||||
@@ -188,7 +189,7 @@ jobs:
|
||||
path: reports
|
||||
|
||||
run_all_tests_torch_multi_gpu:
|
||||
runs-on: [self-hosted, multi-gpu]
|
||||
runs-on: [self-hosted, gpu, multi-gpu]
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
|
||||
@@ -279,7 +280,7 @@ jobs:
|
||||
path: reports
|
||||
|
||||
run_all_tests_tf_multi_gpu:
|
||||
runs-on: [self-hosted, multi-gpu]
|
||||
runs-on: [self-hosted, gpu, multi-gpu]
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
|
||||
|
||||
@@ -78,7 +78,7 @@ require 3 character language codes:
|
||||
tokenizer = MarianTokenizer.from_pretrained(model_name)
|
||||
print(tokenizer.supported_language_codes)
|
||||
model = MarianMTModel.from_pretrained(model_name)
|
||||
translated = model.generate(**tokenizer.prepare_seq2seq_batch(src_text))
|
||||
translated = model.generate(**tokenizer.prepare_seq2seq_batch(src_text, return_tensors="pt"))
|
||||
tgt_text = [tokenizer.decode(t, skip_special_tokens=True) for t in translated]
|
||||
# ["c'est une phrase en anglais que nous voulons traduire en français",
|
||||
# 'Isto deve ir para o português.',
|
||||
@@ -150,7 +150,7 @@ Example of translating english to many romance languages, using old-style 2 char
|
||||
print(tokenizer.supported_language_codes)
|
||||
|
||||
model = MarianMTModel.from_pretrained(model_name)
|
||||
translated = model.generate(**tokenizer.prepare_seq2seq_batch(src_text))
|
||||
translated = model.generate(**tokenizer.prepare_seq2seq_batch(src_text, return_tensors="pt"))
|
||||
tgt_text = [tokenizer.decode(t, skip_special_tokens=True) for t in translated]
|
||||
# ["c'est une phrase en anglais que nous voulons traduire en français", 'Isto deve ir para o português.', 'Y esto al español']
|
||||
|
||||
|
||||
@@ -44,7 +44,7 @@ the sequences for sequence-to-sequence fine-tuning.
|
||||
|
||||
example_english_phrase = "UN Chief Says There Is No Military Solution in Syria"
|
||||
expected_translation_romanian = "Şeful ONU declară că nu există o soluţie militară în Siria"
|
||||
batch = tokenizer.prepare_seq2seq_batch(example_english_phrase, src_lang="en_XX", tgt_lang="ro_RO", tgt_texts=expected_translation_romanian)
|
||||
batch = tokenizer.prepare_seq2seq_batch(example_english_phrase, src_lang="en_XX", tgt_lang="ro_RO", tgt_texts=expected_translation_romanian, return_tensors="pt")
|
||||
model(input_ids=batch['input_ids'], labels=batch['labels']) # forward pass
|
||||
|
||||
- Generation
|
||||
@@ -58,7 +58,7 @@ the sequences for sequence-to-sequence fine-tuning.
|
||||
model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-en-ro")
|
||||
tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro")
|
||||
article = "UN Chief Says There Is No Military Solution in Syria"
|
||||
batch = tokenizer.prepare_seq2seq_batch(src_texts=[article], src_lang="en_XX")
|
||||
batch = tokenizer.prepare_seq2seq_batch(src_texts=[article], src_lang="en_XX", return_tensors="pt")
|
||||
translated_tokens = model.generate(**batch, decoder_start_token_id=tokenizer.lang_code_to_id["ro_RO"])
|
||||
translation = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
|
||||
assert translation == "Şeful ONU declară că nu există o soluţie militară în Siria"
|
||||
|
||||
@@ -78,7 +78,7 @@ Usage Example
|
||||
torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
tokenizer = PegasusTokenizer.from_pretrained(model_name)
|
||||
model = PegasusForConditionalGeneration.from_pretrained(model_name).to(torch_device)
|
||||
batch = tokenizer.prepare_seq2seq_batch(src_text, truncation=True, padding='longest').to(torch_device)
|
||||
batch = tokenizer.prepare_seq2seq_batch(src_text, truncation=True, padding='longest', return_tensors="pt").to(torch_device)
|
||||
translated = model.generate(**batch)
|
||||
tgt_text = tokenizer.batch_decode(translated, skip_special_tokens=True)
|
||||
assert tgt_text[0] == "California's largest electricity provider has turned off power to hundreds of thousands of customers."
|
||||
|
||||
@@ -11,7 +11,7 @@ tokenizer = PegasusTokenizer.from_pretrained(model_name)
|
||||
model = PegasusForConditionalGeneration.from_pretrained(model_name).to(torch_device)
|
||||
|
||||
def get_response(input_text,num_return_sequences):
|
||||
batch = tokenizer.prepare_seq2seq_batch([input_text],truncation=True,padding='longest',max_length=60).to(torch_device)
|
||||
batch = tokenizer.prepare_seq2seq_batch([input_text],truncation=True,padding='longest',max_length=60, return_tensors="pt").to(torch_device)
|
||||
translated = model.generate(**batch,max_length=60,num_beams=10, num_return_sequences=num_return_sequences, temperature=1.5)
|
||||
tgt_text = tokenizer.batch_decode(translated, skip_special_tokens=True)
|
||||
return tgt_text
|
||||
|
||||
@@ -12,7 +12,7 @@ model = PegasusForConditionalGeneration.from_pretrained(model_name).to(torch_dev
|
||||
|
||||
def get_answer(question, context):
|
||||
input_text = "question: %s text: %s" % (question,context)
|
||||
batch = tokenizer.prepare_seq2seq_batch([input_text], truncation=True, padding='longest').to(torch_device)
|
||||
batch = tokenizer.prepare_seq2seq_batch([input_text], truncation=True, padding='longest', return_tensors="pt").to(torch_device)
|
||||
translated = model.generate(**batch)
|
||||
tgt_text = tokenizer.batch_decode(translated, skip_special_tokens=True)
|
||||
return tgt_text[0]
|
||||
|
||||
@@ -58,7 +58,7 @@ tiny_model = FSMTForConditionalGeneration(config)
|
||||
print(f"num of params {tiny_model.num_parameters()}")
|
||||
|
||||
# Test
|
||||
batch = tokenizer.prepare_seq2seq_batch(["Making tiny model"])
|
||||
batch = tokenizer.prepare_seq2seq_batch(["Making tiny model"], return_tensors="pt")
|
||||
outputs = tiny_model(**batch)
|
||||
|
||||
print("test output:", len(outputs.logits[0]))
|
||||
|
||||
@@ -29,7 +29,7 @@ tiny_model = FSMTForConditionalGeneration(config)
|
||||
print(f"num of params {tiny_model.num_parameters()}")
|
||||
|
||||
# Test
|
||||
batch = tokenizer.prepare_seq2seq_batch(["Making tiny model"])
|
||||
batch = tokenizer.prepare_seq2seq_batch(["Making tiny model"], return_tensors="pt")
|
||||
outputs = tiny_model(**batch)
|
||||
|
||||
print("test output:", len(outputs.logits[0]))
|
||||
|
||||
@@ -15,7 +15,9 @@
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from ...tokenization_utils_base import BatchEncoding
|
||||
from transformers import add_start_docstrings
|
||||
|
||||
from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING, BatchEncoding
|
||||
from ...utils import logging
|
||||
from ..roberta.tokenization_roberta import RobertaTokenizer
|
||||
|
||||
@@ -54,6 +56,7 @@ class BartTokenizer(RobertaTokenizer):
|
||||
"merges_file": {m: merges_url for m in _all_bart_models},
|
||||
}
|
||||
|
||||
@add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
|
||||
def prepare_seq2seq_batch(
|
||||
self,
|
||||
src_texts: List[str],
|
||||
@@ -61,70 +64,10 @@ class BartTokenizer(RobertaTokenizer):
|
||||
max_length: Optional[int] = None,
|
||||
max_target_length: Optional[int] = None,
|
||||
padding: str = "longest",
|
||||
return_tensors: str = "None",
|
||||
return_tensors: str = None,
|
||||
truncation=True,
|
||||
**kwargs,
|
||||
) -> BatchEncoding:
|
||||
r"""
|
||||
|
||||
Prepare a batch that can be passed directly to an instance of :class:`~transformers.BartModel`.
|
||||
|
||||
Args:
|
||||
src_texts: (:obj:`List[str]`):
|
||||
List of documents to summarize or source language texts.
|
||||
tgt_texts: (:obj:`List[str]`, `optional`):
|
||||
List of summaries or target language texts.
|
||||
max_length (:obj:`int`, `optional`):
|
||||
Controls the maximum length for encoder inputs (documents to summarize or source language texts). If
|
||||
left unset or set to :obj:`None`, this will use the predefined model maximum length if a maximum length
|
||||
is required by one of the truncation/padding parameters. If the model has no specific maximum input
|
||||
length (like XLNet) truncation/padding to a maximum length will be deactivated.
|
||||
max_target_length (:obj:`int`, `optional`):
|
||||
Controls the maximum length of decoder inputs (target language texts or summaries). If left unset or
|
||||
set to :obj:`None`, this will use the max_length value.
|
||||
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`False`):
|
||||
Activates and controls padding. Accepts the following values:
|
||||
|
||||
* :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
|
||||
single sequence if provided).
|
||||
* :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
|
||||
maximum acceptable input length for the model if that argument is not provided.
|
||||
* :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
|
||||
different lengths).
|
||||
return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`, defaults to "pt"):
|
||||
If set, will return tensors instead of list of python integers. Acceptable values are:
|
||||
|
||||
* :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
|
||||
* :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
|
||||
* :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
|
||||
truncation (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.TruncationStrategy`, `optional`, defaults to :obj:`True`):
|
||||
Activates and controls truncation. Accepts the following values:
|
||||
|
||||
* :obj:`True` or :obj:`'longest_first'`: Truncate to a maximum length specified with the argument
|
||||
:obj:`max_length` or to the maximum acceptable input length for the model if that argument is not
|
||||
provided. This will truncate token by token, removing a token from the longest sequence in the pair
|
||||
if a pair of sequences (or a batch of pairs) is provided.
|
||||
* :obj:`'only_first'`: Truncate to a maximum length specified with the argument :obj:`max_length` or to
|
||||
the maximum acceptable input length for the model if that argument is not provided. This will only
|
||||
truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
|
||||
* :obj:`'only_second'`: Truncate to a maximum length specified with the argument :obj:`max_length` or
|
||||
to the maximum acceptable input length for the model if that argument is not provided. This will only
|
||||
truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
|
||||
* :obj:`False` or :obj:`'do_not_truncate'` (default): No truncation (i.e., can output batch with
|
||||
sequence lengths greater than the model maximum admissible input size).
|
||||
**kwargs:
|
||||
Additional keyword arguments passed along to :obj:`self.__call__`.
|
||||
|
||||
Returns:
|
||||
:class:`~transformers.BatchEncoding`: A :class:`~transformers.BatchEncoding` with the following fields:
|
||||
|
||||
- **input_ids** -- List of token ids to be fed to the encoder.
|
||||
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model.
|
||||
- **labels** -- List of token ids for tgt_texts
|
||||
|
||||
The full set of keys ``[input_ids, attention_mask, labels]``, will only be returned if tgt_texts is passed.
|
||||
Otherwise, input_ids, attention_mask will be the only keys.
|
||||
"""
|
||||
kwargs.pop("src_lang", None)
|
||||
kwargs.pop("tgt_lang", None)
|
||||
if max_length is None:
|
||||
|
||||
@@ -15,7 +15,9 @@
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from ...tokenization_utils_base import BatchEncoding
|
||||
from transformers import add_start_docstrings
|
||||
|
||||
from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING, BatchEncoding
|
||||
from ...utils import logging
|
||||
from ..roberta.tokenization_roberta_fast import RobertaTokenizerFast
|
||||
from .tokenization_bart import BartTokenizer
|
||||
@@ -49,6 +51,7 @@ class BartTokenizerFast(RobertaTokenizerFast):
|
||||
}
|
||||
slow_tokenizer_class = BartTokenizer
|
||||
|
||||
@add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
|
||||
def prepare_seq2seq_batch(
|
||||
self,
|
||||
src_texts: List[str],
|
||||
@@ -56,72 +59,10 @@ class BartTokenizerFast(RobertaTokenizerFast):
|
||||
max_length: Optional[int] = None,
|
||||
max_target_length: Optional[int] = None,
|
||||
padding: str = "longest",
|
||||
return_tensors: str = "None",
|
||||
return_tensors: Optional[str] = None,
|
||||
truncation=True,
|
||||
**kwargs,
|
||||
) -> BatchEncoding:
|
||||
r"""
|
||||
|
||||
Prepare a batch that can be passed directly to an instance of :class:`~transformers.BartModel`.
|
||||
|
||||
Args:
|
||||
src_texts: (:obj:`List[str]`):
|
||||
List of documents to summarize or source language texts.
|
||||
tgt_texts: (:obj:`List[str]`, `optional`):
|
||||
List of summaries or target language texts.
|
||||
max_length (:obj:`int`, `optional`):
|
||||
Controls the maximum length for encoder inputs (documents to summarize or source language texts). If
|
||||
left unset or set to :obj:`None`, this will use the predefined model maximum length if a maximum length
|
||||
is required by one of the truncation/padding parameters. If the model has no specific maximum input
|
||||
length (like XLNet) truncation/padding to a maximum length will be deactivated.
|
||||
max_target_length (:obj:`int`, `optional`):
|
||||
Controls the maximum length of decoder inputs (target language texts or summaries). If left unset or
|
||||
set to :obj:`None`, this will use the max_length value.
|
||||
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`False`):
|
||||
Activates and controls padding. Accepts the following values:
|
||||
|
||||
* :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
|
||||
single sequence if provided).
|
||||
* :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
|
||||
maximum acceptable input length for the model if that argument is not provided.
|
||||
* :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
|
||||
different lengths).
|
||||
return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`, defaults to "pt"):
|
||||
If set, will return tensors instead of list of python integers. Acceptable values are:
|
||||
|
||||
* :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
|
||||
* :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
|
||||
* :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
|
||||
truncation (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.TruncationStrategy`, `optional`, defaults to :obj:`True`):
|
||||
Activates and controls truncation. Accepts the following values:
|
||||
|
||||
* :obj:`True` or :obj:`'longest_first'`: Truncate to a maximum length specified with the argument
|
||||
:obj:`max_length` or to the maximum acceptable input length for the model if that argument is not
|
||||
provided. This will truncate token by token, removing a token from the longest sequence in the pair
|
||||
if a pair of sequences (or a batch of pairs) is provided.
|
||||
* :obj:`'only_first'`: Truncate to a maximum length specified with the argument :obj:`max_length` or to
|
||||
the maximum acceptable input length for the model if that argument is not provided. This will only
|
||||
truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
|
||||
* :obj:`'only_second'`: Truncate to a maximum length specified with the argument :obj:`max_length` or
|
||||
to the maximum acceptable input length for the model if that argument is not provided. This will only
|
||||
truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
|
||||
* :obj:`False` or :obj:`'do_not_truncate'` (default): No truncation (i.e., can output batch with
|
||||
sequence lengths greater than the model maximum admissible input size).
|
||||
**kwargs:
|
||||
Additional keyword arguments passed along to :obj:`self.__call__`.
|
||||
|
||||
Returns:
|
||||
:class:`~transformers.BatchEncoding`: A :class:`~transformers.BatchEncoding` with the following fields:
|
||||
|
||||
- **input_ids** -- List of token ids to be fed to the encoder.
|
||||
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model.
|
||||
- **decoder_input_ids** -- List of token ids to be fed to the decoder.
|
||||
- **decoder_attention_mask** -- List of indices specifying which tokens should be attended to by the
|
||||
decoder. This does not include causal mask, which is built by the model.
|
||||
|
||||
The full set of keys ``[input_ids, attention_mask, decoder_input_ids, decoder_attention_mask]``, will only
|
||||
be returned if tgt_texts is passed. Otherwise, input_ids, attention_mask will be the only keys.
|
||||
"""
|
||||
if max_length is None:
|
||||
max_length = self.model_max_length
|
||||
model_inputs: BatchEncoding = self(
|
||||
|
||||
@@ -491,7 +491,7 @@ class FSMTTokenizer(PreTrainedTokenizer):
|
||||
tgt_texts: Optional[List[str]] = None,
|
||||
max_length: Optional[int] = None,
|
||||
max_target_length: Optional[int] = None,
|
||||
return_tensors: str = "pt",
|
||||
return_tensors: Optional[str] = None,
|
||||
truncation=True,
|
||||
padding="longest",
|
||||
**unused,
|
||||
|
||||
@@ -41,7 +41,7 @@ class MarianMTModel(BartForConditionalGeneration):
|
||||
|
||||
>>> model = MarianMTModel.from_pretrained(mname)
|
||||
>>> tok = MarianTokenizer.from_pretrained(mname)
|
||||
>>> batch = tok.prepare_seq2seq_batch(src_texts=[sample_text]) # don't need tgt_text for inference
|
||||
>>> batch = tok.prepare_seq2seq_batch(src_texts=[sample_text], return_tensors="pt") # don't need tgt_text for inference
|
||||
>>> gen = model.generate(**batch) # for forward pass: model(**batch)
|
||||
>>> words: List[str] = tok.batch_decode(gen, skip_special_tokens=True) # returns "Where is the bus stop ?"
|
||||
|
||||
|
||||
@@ -70,7 +70,7 @@ class MarianTokenizer(PreTrainedTokenizer):
|
||||
>>> tok = MarianTokenizer.from_pretrained('Helsinki-NLP/opus-mt-en-de')
|
||||
>>> src_texts = [ "I am a small frog.", "Tom asked his teacher for advice."]
|
||||
>>> tgt_texts = ["Ich bin ein kleiner Frosch.", "Tom bat seinen Lehrer um Rat."] # optional
|
||||
>>> batch_enc: BatchEncoding = tok.prepare_seq2seq_batch(src_texts, tgt_texts=tgt_texts)
|
||||
>>> batch_enc: BatchEncoding = tok.prepare_seq2seq_batch(src_texts, tgt_texts=tgt_texts, return_tensors="pt")
|
||||
>>> # keys [input_ids, attention_mask, labels].
|
||||
>>> # model(**batch) should work
|
||||
"""
|
||||
@@ -175,7 +175,7 @@ class MarianTokenizer(PreTrainedTokenizer):
|
||||
tgt_texts: Optional[List[str]] = None,
|
||||
max_length: Optional[int] = None,
|
||||
max_target_length: Optional[int] = None,
|
||||
return_tensors: str = "pt",
|
||||
return_tensors: Optional[str] = None,
|
||||
truncation=True,
|
||||
padding="longest",
|
||||
**unused,
|
||||
|
||||
@@ -22,7 +22,7 @@ class MBartForConditionalGeneration(BartForConditionalGeneration):
|
||||
>>> model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-en-ro")
|
||||
>>> tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro")
|
||||
>>> article = "UN Chief Says There Is No Military Solution in Syria"
|
||||
>>> batch = tokenizer.prepare_seq2seq_batch(src_texts=[article])
|
||||
>>> batch = tokenizer.prepare_seq2seq_batch(src_texts=[article], return_tensors="pt")
|
||||
>>> translated_tokens = model.generate(**batch)
|
||||
>>> translation = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
|
||||
>>> assert translation == "Şeful ONU declară că nu există o soluţie militară în Siria"
|
||||
|
||||
@@ -81,7 +81,7 @@ class MBartTokenizer(XLMRobertaTokenizer):
|
||||
>>> example_english_phrase = " UN Chief Says There Is No Military Solution in Syria"
|
||||
>>> expected_translation_romanian = "Şeful ONU declară că nu există o soluţie militară în Siria"
|
||||
>>> batch: dict = tokenizer.prepare_seq2seq_batch(
|
||||
... example_english_phrase, src_lang="en_XX", tgt_lang="ro_RO", tgt_texts=expected_translation_romanian
|
||||
... example_english_phrase, src_lang="en_XX", tgt_lang="ro_RO", tgt_texts=expected_translation_romanian, return_tensors="pt"
|
||||
... )
|
||||
|
||||
"""
|
||||
@@ -183,7 +183,7 @@ class MBartTokenizer(XLMRobertaTokenizer):
|
||||
max_target_length: Optional[int] = None,
|
||||
truncation: bool = True,
|
||||
padding: str = "longest",
|
||||
return_tensors: str = "pt",
|
||||
return_tensors: Optional[str] = None,
|
||||
add_prefix_space: bool = False, # ignored
|
||||
**kwargs,
|
||||
) -> BatchEncoding:
|
||||
|
||||
@@ -89,7 +89,7 @@ class MBartTokenizerFast(XLMRobertaTokenizerFast):
|
||||
>>> example_english_phrase = " UN Chief Says There Is No Military Solution in Syria"
|
||||
>>> expected_translation_romanian = "Şeful ONU declară că nu există o soluţie militară în Siria"
|
||||
>>> batch: dict = tokenizer.prepare_seq2seq_batch(
|
||||
... example_english_phrase, src_lang="en_XX", tgt_lang="ro_RO", tgt_texts=expected_translation_romanian
|
||||
... example_english_phrase, src_lang="en_XX", tgt_lang="ro_RO", tgt_texts=expected_translation_romanian, return_tensors="pt"
|
||||
... )
|
||||
"""
|
||||
|
||||
@@ -181,7 +181,7 @@ class MBartTokenizerFast(XLMRobertaTokenizerFast):
|
||||
max_target_length: Optional[int] = None,
|
||||
truncation: bool = True,
|
||||
padding: str = "longest",
|
||||
return_tensors: str = "pt",
|
||||
return_tensors: str = None,
|
||||
**kwargs,
|
||||
) -> BatchEncoding:
|
||||
if max_length is None:
|
||||
|
||||
@@ -38,7 +38,7 @@ class PegasusForConditionalGeneration(BartForConditionalGeneration):
|
||||
|
||||
>>> model = PegasusForConditionalGeneration.from_pretrained(mname)
|
||||
>>> tok = PegasusTokenizer.from_pretrained(mname)
|
||||
>>> batch = tok.prepare_seq2seq_batch(src_texts=[PGE_ARTICLE]) # don't need tgt_text for inference
|
||||
>>> batch = tok.prepare_seq2seq_batch(src_texts=[PGE_ARTICLE], return_tensors="pt") # don't need tgt_text for inference
|
||||
>>> gen = model.generate(**batch) # for forward pass: model(**batch)
|
||||
>>> summary: List[str] = tok.batch_decode(gen, skip_special_tokens=True)
|
||||
>>> assert summary == "California's largest electricity provider has turned off power to tens of thousands of customers."
|
||||
|
||||
@@ -134,7 +134,7 @@ class PegasusTokenizer(ReformerTokenizer):
|
||||
tgt_texts: Optional[List[str]] = None,
|
||||
max_length: Optional[int] = None,
|
||||
max_target_length: Optional[int] = None,
|
||||
return_tensors: str = "pt",
|
||||
return_tensors: str = None,
|
||||
truncation=True,
|
||||
padding="longest",
|
||||
**unused,
|
||||
|
||||
@@ -95,7 +95,7 @@ class PegasusTokenizerFast(ReformerTokenizerFast):
|
||||
tgt_texts: Optional[List[str]] = None,
|
||||
max_length: Optional[int] = None,
|
||||
max_target_length: Optional[int] = None,
|
||||
return_tensors: str = "pt",
|
||||
return_tensors: str = None,
|
||||
truncation=True,
|
||||
padding="longest",
|
||||
**unused,
|
||||
|
||||
@@ -71,7 +71,7 @@ class RagTokenizer:
|
||||
max_length: Optional[int] = None,
|
||||
max_target_length: Optional[int] = None,
|
||||
padding: str = "longest",
|
||||
return_tensors: str = "np",
|
||||
return_tensors: str = None,
|
||||
truncation=True,
|
||||
**kwargs,
|
||||
) -> BatchEncoding:
|
||||
|
||||
@@ -797,7 +797,7 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
|
||||
maximum acceptable input length for the model if that argument is not provided.
|
||||
* :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
|
||||
different lengths).
|
||||
return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`, defaults to "pt"):
|
||||
return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`):
|
||||
If set, will return tensors instead of list of python integers. Acceptable values are:
|
||||
|
||||
* :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
|
||||
|
||||
@@ -1455,7 +1455,7 @@ PREPARE_SEQ2SEQ_BATCH_DOCSTRING = """
|
||||
maximum acceptable input length for the model if that argument is not provided.
|
||||
* :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
|
||||
different lengths).
|
||||
return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`, defaults to "pt"):
|
||||
return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`):
|
||||
If set, will return tensors instead of list of python integers. Acceptable values are:
|
||||
|
||||
* :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
|
||||
|
||||
@@ -132,9 +132,9 @@ class MarianIntegrationTest(unittest.TestCase):
|
||||
self.assertListEqual(self.expected_text, generated_words)
|
||||
|
||||
def translate_src_text(self, **tokenizer_kwargs):
|
||||
model_inputs = self.tokenizer.prepare_seq2seq_batch(src_texts=self.src_text, **tokenizer_kwargs).to(
|
||||
torch_device
|
||||
)
|
||||
model_inputs = self.tokenizer.prepare_seq2seq_batch(
|
||||
src_texts=self.src_text, return_tensors="pt", **tokenizer_kwargs
|
||||
).to(torch_device)
|
||||
self.assertEqual(self.model.device, model_inputs.input_ids.device)
|
||||
generated_ids = self.model.generate(
|
||||
model_inputs.input_ids, attention_mask=model_inputs.attention_mask, num_beams=2, max_length=128
|
||||
@@ -151,7 +151,9 @@ class TestMarian_EN_DE_More(MarianIntegrationTest):
|
||||
src, tgt = ["I am a small frog"], ["Ich bin ein kleiner Frosch."]
|
||||
expected_ids = [38, 121, 14, 697, 38848, 0]
|
||||
|
||||
model_inputs: dict = self.tokenizer.prepare_seq2seq_batch(src, tgt_texts=tgt).to(torch_device)
|
||||
model_inputs: dict = self.tokenizer.prepare_seq2seq_batch(src, tgt_texts=tgt, return_tensors="pt").to(
|
||||
torch_device
|
||||
)
|
||||
|
||||
self.assertListEqual(expected_ids, model_inputs.input_ids[0].tolist())
|
||||
|
||||
@@ -171,12 +173,16 @@ class TestMarian_EN_DE_More(MarianIntegrationTest):
|
||||
|
||||
def test_unk_support(self):
|
||||
t = self.tokenizer
|
||||
ids = t.prepare_seq2seq_batch(["||"]).to(torch_device).input_ids[0].tolist()
|
||||
ids = t.prepare_seq2seq_batch(["||"], return_tensors="pt").to(torch_device).input_ids[0].tolist()
|
||||
expected = [t.unk_token_id, t.unk_token_id, t.eos_token_id]
|
||||
self.assertEqual(expected, ids)
|
||||
|
||||
def test_pad_not_split(self):
|
||||
input_ids_w_pad = self.tokenizer.prepare_seq2seq_batch(["I am a small frog <pad>"]).input_ids[0].tolist()
|
||||
input_ids_w_pad = (
|
||||
self.tokenizer.prepare_seq2seq_batch(["I am a small frog <pad>"], return_tensors="pt")
|
||||
.input_ids[0]
|
||||
.tolist()
|
||||
)
|
||||
expected_w_pad = [38, 121, 14, 697, 38848, self.tokenizer.pad_token_id, 0] # pad
|
||||
self.assertListEqual(expected_w_pad, input_ids_w_pad)
|
||||
|
||||
@@ -294,7 +300,7 @@ class TestMarian_en_ROMANCE(MarianIntegrationTest):
|
||||
normalized = self.tokenizer.normalize("")
|
||||
self.assertIsInstance(normalized, str)
|
||||
with self.assertRaises(ValueError):
|
||||
self.tokenizer.prepare_seq2seq_batch([""])
|
||||
self.tokenizer.prepare_seq2seq_batch([""], return_tensors="pt")
|
||||
|
||||
@slow
|
||||
def test_pipeline(self):
|
||||
|
||||
@@ -92,7 +92,7 @@ class MBartEnroIntegrationTest(AbstractSeq2SeqIntegrationTest):
|
||||
@slow
|
||||
def test_enro_generate_one(self):
|
||||
batch: BatchEncoding = self.tokenizer.prepare_seq2seq_batch(
|
||||
["UN Chief Says There Is No Military Solution in Syria"]
|
||||
["UN Chief Says There Is No Military Solution in Syria"], return_tensors="pt"
|
||||
).to(torch_device)
|
||||
translated_tokens = self.model.generate(**batch)
|
||||
decoded = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
|
||||
@@ -101,7 +101,9 @@ class MBartEnroIntegrationTest(AbstractSeq2SeqIntegrationTest):
|
||||
|
||||
@slow
|
||||
def test_enro_generate_batch(self):
|
||||
batch: BatchEncoding = self.tokenizer.prepare_seq2seq_batch(self.src_text).to(torch_device)
|
||||
batch: BatchEncoding = self.tokenizer.prepare_seq2seq_batch(self.src_text, return_tensors="pt").to(
|
||||
torch_device
|
||||
)
|
||||
translated_tokens = self.model.generate(**batch)
|
||||
decoded = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
|
||||
assert self.tgt_text == decoded
|
||||
@@ -153,7 +155,7 @@ class MBartCC25IntegrationTest(AbstractSeq2SeqIntegrationTest):
|
||||
|
||||
@unittest.skip("This test is broken, still generates english")
|
||||
def test_cc25_generate(self):
|
||||
inputs = self.tokenizer.prepare_seq2seq_batch([self.src_text[0]]).to(torch_device)
|
||||
inputs = self.tokenizer.prepare_seq2seq_batch([self.src_text[0]], return_tensors="pt").to(torch_device)
|
||||
translated_tokens = self.model.generate(
|
||||
input_ids=inputs["input_ids"].to(torch_device),
|
||||
decoder_start_token_id=self.tokenizer.lang_code_to_id["ro_RO"],
|
||||
@@ -163,7 +165,9 @@ class MBartCC25IntegrationTest(AbstractSeq2SeqIntegrationTest):
|
||||
|
||||
@slow
|
||||
def test_fill_mask(self):
|
||||
inputs = self.tokenizer.prepare_seq2seq_batch(["One of the best <mask> I ever read!"]).to(torch_device)
|
||||
inputs = self.tokenizer.prepare_seq2seq_batch(["One of the best <mask> I ever read!"], return_tensors="pt").to(
|
||||
torch_device
|
||||
)
|
||||
outputs = self.model.generate(
|
||||
inputs["input_ids"], decoder_start_token_id=self.tokenizer.lang_code_to_id["en_XX"], num_beams=1
|
||||
)
|
||||
|
||||
@@ -1794,7 +1794,7 @@ class TokenizerTesterMixin:
|
||||
self.assertEqual(batch.input_ids.shape[1], 3)
|
||||
self.assertEqual(batch.labels.shape[1], 10)
|
||||
# max_target_length will default to max_length if not specified
|
||||
batch = tokenizer.prepare_seq2seq_batch(src_text, tgt_texts=tgt_text, max_length=3)
|
||||
batch = tokenizer.prepare_seq2seq_batch(src_text, tgt_texts=tgt_text, max_length=3, return_tensors="pt")
|
||||
self.assertEqual(batch.input_ids.shape[1], 3)
|
||||
self.assertEqual(batch.labels.shape[1], 3)
|
||||
|
||||
|
||||
@@ -165,7 +165,6 @@ class MBartEnroIntegrationTest(unittest.TestCase):
|
||||
desired_max_length = 10
|
||||
ids = self.tokenizer.prepare_seq2seq_batch(
|
||||
src_text,
|
||||
return_tensors=None,
|
||||
max_length=desired_max_length,
|
||||
).input_ids[0]
|
||||
self.assertEqual(ids[-2], 2)
|
||||
@@ -203,9 +202,7 @@ class MBartEnroIntegrationTest(unittest.TestCase):
|
||||
@require_torch
|
||||
def test_enro_tokenizer_prepare_seq2seq_batch(self):
|
||||
batch = self.tokenizer.prepare_seq2seq_batch(
|
||||
self.src_text,
|
||||
tgt_texts=self.tgt_text,
|
||||
max_length=len(self.expected_src_tokens),
|
||||
self.src_text, tgt_texts=self.tgt_text, max_length=len(self.expected_src_tokens), return_tensors="pt"
|
||||
)
|
||||
batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
|
||||
self.assertIsInstance(batch, BatchEncoding)
|
||||
@@ -221,13 +218,15 @@ class MBartEnroIntegrationTest(unittest.TestCase):
|
||||
|
||||
def test_seq2seq_max_target_length(self):
|
||||
batch = self.tokenizer.prepare_seq2seq_batch(
|
||||
self.src_text, tgt_texts=self.tgt_text, max_length=3, max_target_length=10
|
||||
self.src_text, tgt_texts=self.tgt_text, max_length=3, max_target_length=10, return_tensors="pt"
|
||||
)
|
||||
batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
|
||||
self.assertEqual(batch.input_ids.shape[1], 3)
|
||||
self.assertEqual(batch.decoder_input_ids.shape[1], 10)
|
||||
# max_target_length will default to max_length if not specified
|
||||
batch = self.tokenizer.prepare_seq2seq_batch(self.src_text, tgt_texts=self.tgt_text, max_length=3)
|
||||
batch = self.tokenizer.prepare_seq2seq_batch(
|
||||
self.src_text, tgt_texts=self.tgt_text, max_length=3, return_tensors="pt"
|
||||
)
|
||||
batch["decoder_input_ids"] = shift_tokens_right(batch.labels, self.tokenizer.pad_token_id)
|
||||
self.assertEqual(batch.input_ids.shape[1], 3)
|
||||
self.assertEqual(batch.decoder_input_ids.shape[1], 3)
|
||||
|
||||
@@ -61,7 +61,9 @@ class PegasusTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
def test_pegasus_large_seq2seq_truncation(self):
|
||||
src_texts = ["This is going to be way too long." * 150, "short example"]
|
||||
tgt_texts = ["not super long but more than 5 tokens", "tiny"]
|
||||
batch = self.pegasus_large_tokenizer.prepare_seq2seq_batch(src_texts, tgt_texts=tgt_texts, max_target_length=5)
|
||||
batch = self.pegasus_large_tokenizer.prepare_seq2seq_batch(
|
||||
src_texts, tgt_texts=tgt_texts, max_target_length=5, return_tensors="pt"
|
||||
)
|
||||
assert batch.input_ids.shape == (2, 1024)
|
||||
assert batch.attention_mask.shape == (2, 1024)
|
||||
assert "labels" in batch # because tgt_texts was specified
|
||||
|
||||
Reference in New Issue
Block a user