Deprecate prepare_seq2seq_batch (#10287)
* Deprecate prepare_seq2seq_batch * Fix last tests * Apply suggestions from code review Co-authored-by: Lysandre Debut <lysandre@huggingface.co> Co-authored-by: Suraj Patil <surajp815@gmail.com> * More review comments Co-authored-by: Lysandre Debut <lysandre@huggingface.co> Co-authored-by: Suraj Patil <surajp815@gmail.com>
This commit is contained in:
@@ -56,7 +56,7 @@ FSMTTokenizer
|
||||
|
||||
.. autoclass:: transformers.FSMTTokenizer
|
||||
:members: build_inputs_with_special_tokens, get_special_tokens_mask,
|
||||
create_token_type_ids_from_sequences, prepare_seq2seq_batch, save_vocabulary
|
||||
create_token_type_ids_from_sequences, save_vocabulary
|
||||
|
||||
|
||||
FSMTModel
|
||||
|
||||
@@ -76,27 +76,29 @@ require 3 character language codes:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from transformers import MarianMTModel, MarianTokenizer
|
||||
src_text = [
|
||||
'>>fra<< this is a sentence in english that we want to translate to french',
|
||||
'>>por<< This should go to portuguese',
|
||||
'>>esp<< And this to Spanish'
|
||||
]
|
||||
>>> from transformers import MarianMTModel, MarianTokenizer
|
||||
>>> src_text = [
|
||||
... '>>fra<< this is a sentence in english that we want to translate to french',
|
||||
... '>>por<< This should go to portuguese',
|
||||
... '>>esp<< And this to Spanish'
|
||||
>>> ]
|
||||
|
||||
model_name = 'Helsinki-NLP/opus-mt-en-roa'
|
||||
tokenizer = MarianTokenizer.from_pretrained(model_name)
|
||||
print(tokenizer.supported_language_codes)
|
||||
model = MarianMTModel.from_pretrained(model_name)
|
||||
translated = model.generate(**tokenizer.prepare_seq2seq_batch(src_text, return_tensors="pt"))
|
||||
tgt_text = [tokenizer.decode(t, skip_special_tokens=True) for t in translated]
|
||||
# ["c'est une phrase en anglais que nous voulons traduire en français",
|
||||
# 'Isto deve ir para o português.',
|
||||
# 'Y esto al español']
|
||||
>>> model_name = 'Helsinki-NLP/opus-mt-en-roa'
|
||||
>>> tokenizer = MarianTokenizer.from_pretrained(model_name)
|
||||
>>> print(tokenizer.supported_language_codes)
|
||||
['>>zlm_Latn<<', '>>mfe<<', '>>hat<<', '>>pap<<', '>>ast<<', '>>cat<<', '>>ind<<', '>>glg<<', '>>wln<<', '>>spa<<', '>>fra<<', '>>ron<<', '>>por<<', '>>ita<<', '>>oci<<', '>>arg<<', '>>min<<']
|
||||
|
||||
>>> model = MarianMTModel.from_pretrained(model_name)
|
||||
>>> translated = model.generate(**tokenizer(src_text, return_tensors="pt", padding=True))
|
||||
>>> [tokenizer.decode(t, skip_special_tokens=True) for t in translated]
|
||||
["c'est une phrase en anglais que nous voulons traduire en français",
|
||||
'Isto deve ir para o português.',
|
||||
'Y esto al español']
|
||||
|
||||
|
||||
|
||||
|
||||
Code to see available pretrained models:
|
||||
Here is the code to see all available pretrained models on the hub:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@@ -147,21 +149,22 @@ Example of translating english to many romance languages, using old-style 2 char
|
||||
|
||||
.. code-block::python
|
||||
|
||||
from transformers import MarianMTModel, MarianTokenizer
|
||||
src_text = [
|
||||
'>>fr<< this is a sentence in english that we want to translate to french',
|
||||
'>>pt<< This should go to portuguese',
|
||||
'>>es<< And this to Spanish'
|
||||
]
|
||||
>>> from transformers import MarianMTModel, MarianTokenizer
|
||||
>>> src_text = [
|
||||
... '>>fr<< this is a sentence in english that we want to translate to french',
|
||||
... '>>pt<< This should go to portuguese',
|
||||
... '>>es<< And this to Spanish'
|
||||
>>> ]
|
||||
|
||||
model_name = 'Helsinki-NLP/opus-mt-en-ROMANCE'
|
||||
tokenizer = MarianTokenizer.from_pretrained(model_name)
|
||||
print(tokenizer.supported_language_codes)
|
||||
>>> model_name = 'Helsinki-NLP/opus-mt-en-ROMANCE'
|
||||
>>> tokenizer = MarianTokenizer.from_pretrained(model_name)
|
||||
|
||||
model = MarianMTModel.from_pretrained(model_name)
|
||||
translated = model.generate(**tokenizer.prepare_seq2seq_batch(src_text, return_tensors="pt"))
|
||||
tgt_text = [tokenizer.decode(t, skip_special_tokens=True) for t in translated]
|
||||
# ["c'est une phrase en anglais que nous voulons traduire en français", 'Isto deve ir para o português.', 'Y esto al español']
|
||||
>>> model = MarianMTModel.from_pretrained(model_name)
|
||||
>>> translated = model.generate(**tokenizer(src_text, return_tensors="pt", padding=True))
|
||||
>>> tgt_text = [tokenizer.decode(t, skip_special_tokens=True) for t in translated]
|
||||
["c'est une phrase en anglais que nous voulons traduire en français",
|
||||
'Isto deve ir para o português.',
|
||||
'Y esto al español']
|
||||
|
||||
|
||||
|
||||
@@ -176,7 +179,7 @@ MarianTokenizer
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.MarianTokenizer
|
||||
:members: prepare_seq2seq_batch
|
||||
:members: as_target_tokenizer
|
||||
|
||||
|
||||
MarianModel
|
||||
|
||||
@@ -34,22 +34,31 @@ The Authors' code can be found `here <https://github.com/pytorch/fairseq/tree/ma
|
||||
Training of MBart
|
||||
_______________________________________________________________________________________________________________________
|
||||
|
||||
MBart is a multilingual encoder-decoder (seq-to-seq) model primarily intended for translation task. As the model is
|
||||
multilingual it expects the sequences in a different format. A special language id token is added in both the source
|
||||
and target text. The source text format is :obj:`X [eos, src_lang_code]` where :obj:`X` is the source text. The target
|
||||
text format is :obj:`[tgt_lang_code] X [eos]`. :obj:`bos` is never used.
|
||||
MBart is a multilingual encoder-decoder (sequence-to-sequence) model primarily intended for translation task. As the
|
||||
model is multilingual it expects the sequences in a different format. A special language id token is added in both the
|
||||
source and target text. The source text format is :obj:`X [eos, src_lang_code]` where :obj:`X` is the source text. The
|
||||
target text format is :obj:`[tgt_lang_code] X [eos]`. :obj:`bos` is never used.
|
||||
|
||||
The :meth:`~transformers.MBartTokenizer.prepare_seq2seq_batch` handles this automatically and should be used to encode
|
||||
the sequences for sequence-to-sequence fine-tuning.
|
||||
The regular :meth:`~transformers.MBartTokenizer.__call__` will encode source text format, and it should be wrapped
|
||||
inside the context manager :meth:`~transformers.MBartTokenizer.as_target_tokenizer` to encode target text format.
|
||||
|
||||
- Supervised training
|
||||
|
||||
.. code-block::
|
||||
|
||||
example_english_phrase = "UN Chief Says There Is No Military Solution in Syria"
|
||||
expected_translation_romanian = "Şeful ONU declară că nu există o soluţie militară în Siria"
|
||||
batch = tokenizer.prepare_seq2seq_batch(example_english_phrase, src_lang="en_XX", tgt_lang="ro_RO", tgt_texts=expected_translation_romanian, return_tensors="pt")
|
||||
model(input_ids=batch['input_ids'], labels=batch['labels']) # forward pass
|
||||
>>> from transformers import MBartForConditionalGeneration, MBartTokenizer
|
||||
|
||||
>>> tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro")
|
||||
>>> example_english_phrase = "UN Chief Says There Is No Military Solution in Syria"
|
||||
>>> expected_translation_romanian = "Şeful ONU declară că nu există o soluţie militară în Siria"
|
||||
|
||||
>>> inputs = tokenizer(example_english_phrase, return_tensors="pt", src_lang="en_XX", tgt_lang="ro_RO")
|
||||
>>> with tokenizer.as_target_tokenizer():
|
||||
... labels = tokenizer(expected_translation_romanian, return_tensors="pt")
|
||||
|
||||
>>> model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-en-ro")
|
||||
>>> # forward pass
|
||||
>>> model(**inputs, labels=batch['labels'])
|
||||
|
||||
- Generation
|
||||
|
||||
@@ -58,14 +67,14 @@ the sequences for sequence-to-sequence fine-tuning.
|
||||
|
||||
.. code-block::
|
||||
|
||||
from transformers import MBartForConditionalGeneration, MBartTokenizer
|
||||
model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-en-ro")
|
||||
tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro")
|
||||
article = "UN Chief Says There Is No Military Solution in Syria"
|
||||
batch = tokenizer.prepare_seq2seq_batch(src_texts=[article], src_lang="en_XX", return_tensors="pt")
|
||||
translated_tokens = model.generate(**batch, decoder_start_token_id=tokenizer.lang_code_to_id["ro_RO"])
|
||||
translation = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
|
||||
assert translation == "Şeful ONU declară că nu există o soluţie militară în Siria"
|
||||
>>> from transformers import MBartForConditionalGeneration, MBartTokenizer
|
||||
|
||||
>>> tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro", src_lang="en_XX")
|
||||
>>> article = "UN Chief Says There Is No Military Solution in Syria"
|
||||
>>> inputs = tokenizer(article, return_tensors="pt")
|
||||
>>> translated_tokens = model.generate(**inputs, decoder_start_token_id=tokenizer.lang_code_to_id["ro_RO"])
|
||||
>>> tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
|
||||
"Şeful ONU declară că nu există o soluţie militară în Siria"
|
||||
|
||||
|
||||
Overview of MBart-50
|
||||
@@ -160,7 +169,7 @@ MBartTokenizer
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.MBartTokenizer
|
||||
:members: build_inputs_with_special_tokens, prepare_seq2seq_batch
|
||||
:members: as_target_tokenizer, build_inputs_with_special_tokens
|
||||
|
||||
|
||||
MBartTokenizerFast
|
||||
|
||||
@@ -78,20 +78,20 @@ Usage Example
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from transformers import PegasusForConditionalGeneration, PegasusTokenizer
|
||||
import torch
|
||||
src_text = [
|
||||
""" PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."""
|
||||
]
|
||||
>>> from transformers import PegasusForConditionalGeneration, PegasusTokenizer
|
||||
>>> import torch
|
||||
>>> src_text = [
|
||||
... """ PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."""
|
||||
>>> ]
|
||||
|
||||
model_name = 'google/pegasus-xsum'
|
||||
torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
tokenizer = PegasusTokenizer.from_pretrained(model_name)
|
||||
model = PegasusForConditionalGeneration.from_pretrained(model_name).to(torch_device)
|
||||
batch = tokenizer.prepare_seq2seq_batch(src_text, truncation=True, padding='longest', return_tensors="pt").to(torch_device)
|
||||
translated = model.generate(**batch)
|
||||
tgt_text = tokenizer.batch_decode(translated, skip_special_tokens=True)
|
||||
assert tgt_text[0] == "California's largest electricity provider has turned off power to hundreds of thousands of customers."
|
||||
>>> model_name = 'google/pegasus-xsum'
|
||||
>>> device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
>>> tokenizer = PegasusTokenizer.from_pretrained(model_name)
|
||||
>>> model = PegasusForConditionalGeneration.from_pretrained(model_name).to(device)
|
||||
>>> batch = tokenizer(src_text, truncation=True, padding='longest', return_tensors="pt").to(torch_device)
|
||||
>>> translated = model.generate(**batch)
|
||||
>>> tgt_text = tokenizer.batch_decode(translated, skip_special_tokens=True)
|
||||
>>> assert tgt_text[0] == "California's largest electricity provider has turned off power to hundreds of thousands of customers."
|
||||
|
||||
|
||||
|
||||
@@ -107,7 +107,7 @@ PegasusTokenizer
|
||||
warning: ``add_tokens`` does not work at the moment.
|
||||
|
||||
.. autoclass:: transformers.PegasusTokenizer
|
||||
:members: __call__, prepare_seq2seq_batch
|
||||
:members:
|
||||
|
||||
|
||||
PegasusTokenizerFast
|
||||
|
||||
@@ -56,7 +56,7 @@ RagTokenizer
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.RagTokenizer
|
||||
:members: prepare_seq2seq_batch
|
||||
:members:
|
||||
|
||||
|
||||
Rag specific outputs
|
||||
|
||||
@@ -104,7 +104,7 @@ T5Tokenizer
|
||||
|
||||
.. autoclass:: transformers.T5Tokenizer
|
||||
:members: build_inputs_with_special_tokens, get_special_tokens_mask,
|
||||
create_token_type_ids_from_sequences, prepare_seq2seq_batch, save_vocabulary
|
||||
create_token_type_ids_from_sequences, save_vocabulary
|
||||
|
||||
|
||||
T5TokenizerFast
|
||||
|
||||
Reference in New Issue
Block a user