Replace as_target context managers by direct calls (#18325)

* Preliminary work on tokenizers

* Quality + fix tests

* Treat processors

* Fix pad

* Remove all uses of  in tests, docs and examples

* Replace all as_target_tokenizer

* Fix tests

* Fix quality

* Update examples/flax/image-captioning/run_image_captioning_flax.py

Co-authored-by: amyeroberts <amy@huggingface.co>

* Style

Co-authored-by: amyeroberts <amy@huggingface.co>
This commit is contained in:
Sylvain Gugger
2022-07-29 08:09:09 -04:00
committed by GitHub
parent a64bcb564d
commit 986526a0e4
80 changed files with 725 additions and 550 deletions

View File

@@ -55,9 +55,7 @@ tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M", src_lang="en
src_text = "Life is like a box of chocolates."
tgt_text = "La vie est comme une boîte de chocolat."
model_inputs = tokenizer(src_text, return_tensors="pt")
with tokenizer.as_target_tokenizer():
labels = tokenizer(tgt_text, return_tensors="pt").input_ids
model_inputs = tokenizer(src_text, text_target=tgt_text, return_tensors="pt")
loss = model(**model_inputs, labels=labels) # forward pass
```

View File

@@ -155,7 +155,7 @@ Example of translating english to many romance languages, using old-style 2 char
## MarianTokenizer
[[autodoc]] MarianTokenizer
- as_target_tokenizer
- build_inputs_with_special_tokens
## MarianModel

View File

@@ -34,8 +34,8 @@ model is multilingual it expects the sequences in a different format. A special
source and target text. The source text format is `X [eos, src_lang_code]` where `X` is the source text. The
target text format is `[tgt_lang_code] X [eos]`. `bos` is never used.
The regular [`~MBartTokenizer.__call__`] will encode source text format, and it should be wrapped
inside the context manager [`~MBartTokenizer.as_target_tokenizer`] to encode target text format.
The regular [`~MBartTokenizer.__call__`] will encode source text format passed as first argument or with the `text`
keyword, and target text format passed with the `text_label` keyword argument.
- Supervised training
@@ -46,13 +46,11 @@ inside the context manager [`~MBartTokenizer.as_target_tokenizer`] to encode tar
>>> example_english_phrase = "UN Chief Says There Is No Military Solution in Syria"
>>> expected_translation_romanian = "Şeful ONU declară că nu există o soluţie militară în Siria"
>>> inputs = tokenizer(example_english_phrase, return_tensors="pt")
>>> with tokenizer.as_target_tokenizer():
... labels = tokenizer(expected_translation_romanian, return_tensors="pt")
>>> inputs = tokenizer(example_english_phrase, text_target=expected_translation_romanian, return_tensors="pt")
>>> model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-en-ro")
>>> # forward pass
>>> model(**inputs, labels=batch["labels"])
>>> model(**inputs)
```
- Generation
@@ -108,11 +106,9 @@ tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50", src_
src_text = " UN Chief Says There Is No Military Solution in Syria"
tgt_text = "Şeful ONU declară că nu există o soluţie militară în Siria"
model_inputs = tokenizer(src_text, return_tensors="pt")
with tokenizer.as_target_tokenizer():
labels = tokenizer(tgt_text, return_tensors="pt").input_ids
model_inputs = tokenizer(src_text, text_target=tgt_text, return_tensors="pt")
model(**model_inputs, labels=labels) # forward pass
model(**model_inputs) # forward pass
```
- Generation
@@ -154,7 +150,6 @@ tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
## MBartTokenizer
[[autodoc]] MBartTokenizer
- as_target_tokenizer
- build_inputs_with_special_tokens
## MBartTokenizerFast

View File

@@ -48,7 +48,6 @@ This model was contributed by [cwkeam](https://huggingface.co/cwkeam). The origi
- save_pretrained
- batch_decode
- decode
- as_target_processor
## MCTCTModel

View File

@@ -91,7 +91,6 @@ UN-Chef sagt, es gibt keine militärische Lösung in Syrien
## NllbTokenizer
[[autodoc]] NllbTokenizer
- as_target_tokenizer
- build_inputs_with_special_tokens
## NllbTokenizerFast

View File

@@ -45,8 +45,9 @@ target text format is `[tgt_lang_code] X [eos]`. `bos` is never used.
However, for fine-tuning, in some cases no language token is provided in cases where a single language is used. Please refer to [the paper](https://arxiv.org/abs/2103.06333) to learn more about this.
In cases where the language code is needed, The regular [`~PLBartTokenizer.__call__`] will encode source text format, and it should be wrapped
inside the context manager [`~PLBartTokenizer.as_target_tokenizer`] to encode target text format.
In cases where the language code is needed, the regular [`~PLBartTokenizer.__call__`] will encode source text format
when you pass texts as the first argument or with the keyword argument `text`, and will encode target text format if
it's passed with the `text_target` keyword argument.
- Supervised training
@@ -56,11 +57,7 @@ inside the context manager [`~PLBartTokenizer.as_target_tokenizer`] to encode ta
>>> tokenizer = PLBartTokenizer.from_pretrained("uclanlp/plbart-base", src_lang="en_XX", tgt_lang="python")
>>> example_python_phrase = "def maximum(a,b,c):NEW_LINE_INDENTreturn max([a,b,c])"
>>> expected_translation_english = "Returns the maximum value of a b c."
>>> inputs = tokenizer(example_python_phrase, return_tensors="pt")
>>> with tokenizer.as_target_tokenizer():
... labels = tokenizer(expected_translation_english, return_tensors="pt")
>>> inputs["labels"] = labels["input_ids"]
>>> # forward pass
>>> inputs = tokenizer(example_python_phrase, text_target=expected_translation_english, return_tensors="pt")
>>> model(**inputs)
```
@@ -88,7 +85,6 @@ inside the context manager [`~PLBartTokenizer.as_target_tokenizer`] to encode ta
## PLBartTokenizer
[[autodoc]] PLBartTokenizer
- as_target_tokenizer
- build_inputs_with_special_tokens
## PLBartModel

View File

@@ -107,7 +107,7 @@ speech inputs) and `labels` (which are the `input_ids` of the encoded target seq
>>> labels = tokenizer(ds[0]["text"], return_tensors="pt").input_ids
>>> # the forward function automatically creates the correct decoder_input_ids
>>> loss = model(input_values, labels=labels).loss
>>> loss = model(**input_features).loss
>>> loss.backward()
```

View File

@@ -120,7 +120,6 @@ See the [model hub](https://huggingface.co/models?filter=speech_to_text) to look
- save_pretrained
- batch_decode
- decode
- as_target_processor
## Speech2TextModel

View File

@@ -114,7 +114,6 @@ See [model hub](https://huggingface.co/models?filter=speech2text2) to look for S
- save_pretrained
- batch_decode
- decode
- as_target_processor
## Speech2Text2ForCausalLM

View File

@@ -94,7 +94,6 @@ See the [model hub](https://huggingface.co/models?filter=trocr) to look for TrOC
- save_pretrained
- batch_decode
- decode
- as_target_processor
## TrOCRForCausalLM

View File

@@ -62,7 +62,6 @@ This model was contributed by [patrickvonplaten](https://huggingface.co/patrickv
- save_pretrained
- batch_decode
- decode
- as_target_processor
## Wav2Vec2ProcessorWithLM
@@ -73,7 +72,6 @@ This model was contributed by [patrickvonplaten](https://huggingface.co/patrickv
- save_pretrained
- batch_decode
- decode
- as_target_processor
## Wav2Vec2 specific outputs