Deprecate prepare_seq2seq_batch (#10287)
* Deprecate prepare_seq2seq_batch * Fix last tests * Apply suggestions from code review Co-authored-by: Lysandre Debut <lysandre@huggingface.co> Co-authored-by: Suraj Patil <surajp815@gmail.com> * More review comments Co-authored-by: Lysandre Debut <lysandre@huggingface.co> Co-authored-by: Suraj Patil <surajp815@gmail.com>
This commit is contained in:
@@ -522,13 +522,14 @@ MARIAN_GENERATION_EXAMPLE = r"""
|
||||
>>> src = 'fr' # source language
|
||||
>>> trg = 'en' # target language
|
||||
>>> sample_text = "où est l'arrêt de bus ?"
|
||||
>>> mname = f'Helsinki-NLP/opus-mt-{src}-{trg}'
|
||||
>>> model_name = f'Helsinki-NLP/opus-mt-{src}-{trg}'
|
||||
|
||||
>>> model = MarianMTModel.from_pretrained(mname)
|
||||
>>> tok = MarianTokenizer.from_pretrained(mname)
|
||||
>>> batch = tok.prepare_seq2seq_batch(src_texts=[sample_text], return_tensors="pt") # don't need tgt_text for inference
|
||||
>>> model = MarianMTModel.from_pretrained(model_name)
|
||||
>>> tokenizer = MarianTokenizer.from_pretrained(model_name)
|
||||
>>> batch = tokenizer([sample_text], return_tensors="pt")
|
||||
>>> gen = model.generate(**batch)
|
||||
>>> words: List[str] = tok.batch_decode(gen, skip_special_tokens=True) # returns "Where is the bus stop ?"
|
||||
>>> tokenizer.batch_decode(gen, skip_special_tokens=True)
|
||||
"Where is the bus stop ?"
|
||||
"""
|
||||
|
||||
MARIAN_INPUTS_DOCSTRING = r"""
|
||||
|
||||
@@ -557,13 +557,14 @@ MARIAN_GENERATION_EXAMPLE = r"""
|
||||
>>> src = 'fr' # source language
|
||||
>>> trg = 'en' # target language
|
||||
>>> sample_text = "où est l'arrêt de bus ?"
|
||||
>>> mname = f'Helsinki-NLP/opus-mt-{src}-{trg}'
|
||||
>>> model_name = f'Helsinki-NLP/opus-mt-{src}-{trg}'
|
||||
|
||||
>>> model = MarianMTModel.from_pretrained(mname)
|
||||
>>> tok = MarianTokenizer.from_pretrained(mname)
|
||||
>>> batch = tok.prepare_seq2seq_batch(src_texts=[sample_text], return_tensors="tf") # don't need tgt_text for inference
|
||||
>>> model = TFMarianMTModel.from_pretrained(model_name)
|
||||
>>> tokenizer = MarianTokenizer.from_pretrained(model_name)
|
||||
>>> batch = tokenizer([sample_text], return_tensors="tf")
|
||||
>>> gen = model.generate(**batch)
|
||||
>>> words: List[str] = tok.batch_decode(gen, skip_special_tokens=True) # returns "Where is the bus stop ?"
|
||||
>>> tokenizer.batch_decode(gen, skip_special_tokens=True)
|
||||
"Where is the bus stop ?"
|
||||
"""
|
||||
|
||||
MARIAN_INPUTS_DOCSTRING = r"""
|
||||
|
||||
@@ -80,12 +80,15 @@ class MarianTokenizer(PreTrainedTokenizer):
|
||||
Examples::
|
||||
|
||||
>>> from transformers import MarianTokenizer
|
||||
>>> tok = MarianTokenizer.from_pretrained('Helsinki-NLP/opus-mt-en-de')
|
||||
>>> tokenizer = MarianTokenizer.from_pretrained('Helsinki-NLP/opus-mt-en-de')
|
||||
>>> src_texts = [ "I am a small frog.", "Tom asked his teacher for advice."]
|
||||
>>> tgt_texts = ["Ich bin ein kleiner Frosch.", "Tom bat seinen Lehrer um Rat."] # optional
|
||||
>>> batch_enc = tok.prepare_seq2seq_batch(src_texts, tgt_texts=tgt_texts, return_tensors="pt")
|
||||
>>> # keys [input_ids, attention_mask, labels].
|
||||
>>> # model(**batch) should work
|
||||
>>> inputs = tokenizer(src_texts, return_tensors="pt", padding=True)
|
||||
>>> with tokenizer.as_target_tokenizer():
|
||||
... labels = tokenizer(tgt_texts, return_tensors="pt", padding=True)
|
||||
>>> inputs["labels"] = labels["input_ids"]
|
||||
>>> # keys [input_ids, attention_mask, labels].
|
||||
>>> outputs = model(**inputs) # should work
|
||||
"""
|
||||
|
||||
vocab_files_names = vocab_files_names
|
||||
|
||||
@@ -59,30 +59,23 @@ class MBartTokenizer(XLMRobertaTokenizer):
|
||||
"""
|
||||
Construct an MBART tokenizer.
|
||||
|
||||
:class:`~transformers.MBartTokenizer` is a subclass of :class:`~transformers.XLMRobertaTokenizer` and adds a new
|
||||
:meth:`~transformers.MBartTokenizer.prepare_seq2seq_batch`
|
||||
|
||||
Refer to superclass :class:`~transformers.XLMRobertaTokenizer` for usage examples and documentation concerning the
|
||||
:class:`~transformers.MBartTokenizer` is a subclass of :class:`~transformers.XLMRobertaTokenizer`. Refer to
|
||||
superclass :class:`~transformers.XLMRobertaTokenizer` for usage examples and documentation concerning the
|
||||
initialization parameters and other methods.
|
||||
|
||||
.. warning::
|
||||
|
||||
``prepare_seq2seq_batch`` should be used to encode inputs. Other tokenizer methods like ``encode`` do not work
|
||||
properly.
|
||||
|
||||
The tokenization method is ``<tokens> <eos> <language code>`` for source language documents, and ``<language code>
|
||||
<tokens> <eos>`` for target language documents.
|
||||
|
||||
Examples::
|
||||
|
||||
>>> from transformers import MBartTokenizer
|
||||
>>> tokenizer = MBartTokenizer.from_pretrained('facebook/mbart-large-en-ro')
|
||||
>>> tokenizer = MBartTokenizer.from_pretrained('facebook/mbart-large-en-ro', src_lang="en_XX", tgt_lang="ro_RO")
|
||||
>>> example_english_phrase = " UN Chief Says There Is No Military Solution in Syria"
|
||||
>>> expected_translation_romanian = "Şeful ONU declară că nu există o soluţie militară în Siria"
|
||||
>>> batch: dict = tokenizer.prepare_seq2seq_batch(
|
||||
... example_english_phrase, src_lang="en_XX", tgt_lang="ro_RO", tgt_texts=expected_translation_romanian, return_tensors="pt"
|
||||
... )
|
||||
|
||||
>>> inputs = tokenizer(example_english_phrase, return_tensors="pt")
|
||||
>>> with tokenizer.as_target_tokenizer():
|
||||
... labels = tokenizer(expected_translation_romanian, return_tensors="pt")
|
||||
>>> inputs["labels"] = labels["input_ids"]
|
||||
"""
|
||||
|
||||
vocab_files_names = {"vocab_file": "sentencepiece.bpe.model"}
|
||||
@@ -92,26 +85,38 @@ class MBartTokenizer(XLMRobertaTokenizer):
|
||||
prefix_tokens: List[int] = []
|
||||
suffix_tokens: List[int] = []
|
||||
|
||||
def __init__(self, *args, tokenizer_file=None, **kwargs):
|
||||
super().__init__(*args, tokenizer_file=tokenizer_file, **kwargs)
|
||||
def __init__(self, *args, tokenizer_file=None, src_lang=None, tgt_lang=None, **kwargs):
|
||||
super().__init__(*args, tokenizer_file=tokenizer_file, src_lang=src_lang, tgt_lang=tgt_lang, **kwargs)
|
||||
|
||||
self.sp_model_size = len(self.sp_model)
|
||||
self.lang_code_to_id = {
|
||||
code: self.sp_model_size + i + self.fairseq_offset for i, code in enumerate(FAIRSEQ_LANGUAGE_CODES)
|
||||
}
|
||||
self.id_to_lang_code = {v: k for k, v in self.lang_code_to_id.items()}
|
||||
self.cur_lang_code = self.lang_code_to_id["en_XX"]
|
||||
self.fairseq_tokens_to_ids["<mask>"] = len(self.sp_model) + len(self.lang_code_to_id) + self.fairseq_offset
|
||||
|
||||
self.fairseq_tokens_to_ids.update(self.lang_code_to_id)
|
||||
self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}
|
||||
self._additional_special_tokens = list(self.lang_code_to_id.keys())
|
||||
self.set_src_lang_special_tokens(kwargs.get("src_lang", "en_XX"))
|
||||
|
||||
self._src_lang = src_lang if src_lang is not None else "en_XX"
|
||||
self.cur_lang_code_id = self.lang_code_to_id[self._src_lang]
|
||||
self.tgt_lang = tgt_lang
|
||||
self.set_src_lang_special_tokens(self._src_lang)
|
||||
|
||||
@property
|
||||
def vocab_size(self):
|
||||
return len(self.sp_model) + len(self.lang_code_to_id) + self.fairseq_offset + 1 # Plus 1 for the mask token
|
||||
|
||||
@property
|
||||
def src_lang(self) -> str:
|
||||
return self._src_lang
|
||||
|
||||
@src_lang.setter
|
||||
def src_lang(self, new_src_lang: str) -> None:
|
||||
self._src_lang = new_src_lang
|
||||
self.set_src_lang_special_tokens(self._src_lang)
|
||||
|
||||
def get_special_tokens_mask(
|
||||
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
|
||||
) -> List[int]:
|
||||
@@ -181,7 +186,6 @@ class MBartTokenizer(XLMRobertaTokenizer):
|
||||
) -> BatchEncoding:
|
||||
self.src_lang = src_lang
|
||||
self.tgt_lang = tgt_lang
|
||||
self.set_src_lang_special_tokens(self.src_lang)
|
||||
return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs)
|
||||
|
||||
@contextmanager
|
||||
|
||||
@@ -70,15 +70,9 @@ class MBartTokenizerFast(XLMRobertaTokenizerFast):
|
||||
Construct a "fast" MBART tokenizer (backed by HuggingFace's `tokenizers` library). Based on `BPE
|
||||
<https://huggingface.co/docs/tokenizers/python/latest/components.html?highlight=BPE#models>`__.
|
||||
|
||||
:class:`~transformers.MBartTokenizerFast` is a subclass of :class:`~transformers.XLMRobertaTokenizerFast` and adds
|
||||
a new :meth:`~transformers.MBartTokenizerFast.prepare_seq2seq_batch`.
|
||||
|
||||
Refer to superclass :class:`~transformers.XLMRobertaTokenizerFast` for usage examples and documentation concerning
|
||||
the initialization parameters and other methods.
|
||||
|
||||
.. warning::
|
||||
``prepare_seq2seq_batch`` should be used to encode inputs. Other tokenizer methods like ``encode`` do not work
|
||||
properly.
|
||||
:class:`~transformers.MBartTokenizerFast` is a subclass of :class:`~transformers.XLMRobertaTokenizerFast`. Refer to
|
||||
superclass :class:`~transformers.XLMRobertaTokenizerFast` for usage examples and documentation concerning the
|
||||
initialization parameters and other methods.
|
||||
|
||||
The tokenization method is ``<tokens> <eos> <language code>`` for source language documents, and ``<language code>
|
||||
<tokens> <eos>`` for target language documents.
|
||||
@@ -86,12 +80,13 @@ class MBartTokenizerFast(XLMRobertaTokenizerFast):
|
||||
Examples::
|
||||
|
||||
>>> from transformers import MBartTokenizerFast
|
||||
>>> tokenizer = MBartTokenizerFast.from_pretrained('facebook/mbart-large-en-ro')
|
||||
>>> tokenizer = MBartTokenizerFast.from_pretrained('facebook/mbart-large-en-ro', src_lang="en_XX", tgt_lang="ro_RO")
|
||||
>>> example_english_phrase = " UN Chief Says There Is No Military Solution in Syria"
|
||||
>>> expected_translation_romanian = "Şeful ONU declară că nu există o soluţie militară în Siria"
|
||||
>>> batch: dict = tokenizer.prepare_seq2seq_batch(
|
||||
... example_english_phrase, src_lang="en_XX", tgt_lang="ro_RO", tgt_texts=expected_translation_romanian, return_tensors="pt"
|
||||
... )
|
||||
>>> inputs = tokenizer(example_english_phrase, return_tensors="pt")
|
||||
>>> with tokenizer.as_target_tokenizer():
|
||||
... labels = tokenizer(expected_translation_romanian, return_tensors="pt")
|
||||
>>> inputs["labels"] = labels["input_ids"]
|
||||
"""
|
||||
|
||||
vocab_files_names = {"vocab_file": "sentencepiece.bpe.model"}
|
||||
@@ -102,14 +97,25 @@ class MBartTokenizerFast(XLMRobertaTokenizerFast):
|
||||
prefix_tokens: List[int] = []
|
||||
suffix_tokens: List[int] = []
|
||||
|
||||
def __init__(self, *args, tokenizer_file=None, **kwargs):
|
||||
super().__init__(*args, tokenizer_file=tokenizer_file, **kwargs)
|
||||
|
||||
self.cur_lang_code = self.convert_tokens_to_ids("en_XX")
|
||||
self.set_src_lang_special_tokens(kwargs.get("src_lang", "en_XX"))
|
||||
def __init__(self, *args, tokenizer_file=None, src_lang=None, tgt_lang=None, **kwargs):
|
||||
super().__init__(*args, tokenizer_file=tokenizer_file, src_lang=src_lang, tgt_lang=tgt_lang, **kwargs)
|
||||
|
||||
self.add_special_tokens({"additional_special_tokens": FAIRSEQ_LANGUAGE_CODES})
|
||||
|
||||
self._src_lang = src_lang if src_lang is not None else "en_XX"
|
||||
self.cur_lang_code = self.convert_tokens_to_ids(self._src_lang)
|
||||
self.tgt_lang = tgt_lang
|
||||
self.set_src_lang_special_tokens(self._src_lang)
|
||||
|
||||
@property
|
||||
def src_lang(self) -> str:
|
||||
return self._src_lang
|
||||
|
||||
@src_lang.setter
|
||||
def src_lang(self, new_src_lang: str) -> None:
|
||||
self._src_lang = new_src_lang
|
||||
self.set_src_lang_special_tokens(self._src_lang)
|
||||
|
||||
def get_special_tokens_mask(
|
||||
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
|
||||
) -> List[int]:
|
||||
@@ -181,7 +187,6 @@ class MBartTokenizerFast(XLMRobertaTokenizerFast):
|
||||
) -> BatchEncoding:
|
||||
self.src_lang = src_lang
|
||||
self.tgt_lang = tgt_lang
|
||||
self.set_src_lang_special_tokens(self.src_lang)
|
||||
return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs)
|
||||
|
||||
@contextmanager
|
||||
|
||||
@@ -31,13 +31,17 @@ class MT5Model(T5Model):
|
||||
alongside usage examples.
|
||||
|
||||
Examples::
|
||||
|
||||
>>> from transformers import MT5Model, T5Tokenizer
|
||||
>>> model = MT5Model.from_pretrained("google/mt5-small")
|
||||
>>> tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
|
||||
>>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
|
||||
>>> summary = "Weiter Verhandlung in Syrien."
|
||||
>>> batch = tokenizer.prepare_seq2seq_batch(src_texts=[article], tgt_texts=[summary], return_tensors="pt")
|
||||
>>> outputs = model(input_ids=batch.input_ids, decoder_input_ids=batch.labels)
|
||||
>>> inputs = tokenizer(article, return_tensors="pt")
|
||||
>>> with tokenizer.as_target_tokenizer():
|
||||
... labels = tokenizer(summary, return_tensors="pt")
|
||||
|
||||
>>> outputs = model(input_ids=inputs["input_ids"], decoder_input_ids=labels["input_ids"])
|
||||
>>> hidden_states = outputs.last_hidden_state
|
||||
"""
|
||||
model_type = "mt5"
|
||||
@@ -59,13 +63,17 @@ class MT5ForConditionalGeneration(T5ForConditionalGeneration):
|
||||
appropriate documentation alongside usage examples.
|
||||
|
||||
Examples::
|
||||
|
||||
>>> from transformers import MT5ForConditionalGeneration, T5Tokenizer
|
||||
>>> model = MT5ForConditionalGeneration.from_pretrained("google/mt5-small")
|
||||
>>> tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
|
||||
>>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
|
||||
>>> summary = "Weiter Verhandlung in Syrien."
|
||||
>>> batch = tokenizer.prepare_seq2seq_batch(src_texts=[article], tgt_texts=[summary], return_tensors="pt")
|
||||
>>> outputs = model(**batch)
|
||||
>>> inputs = tokenizer(article, return_tensors="pt")
|
||||
>>> with tokenizer.as_target_tokenizer():
|
||||
... labels = tokenizer(summary, return_tensors="pt")
|
||||
|
||||
>>> outputs = model(**inputs, labels=labels["input_ids"])
|
||||
>>> loss = outputs.loss
|
||||
"""
|
||||
|
||||
|
||||
@@ -31,15 +31,17 @@ class TFMT5Model(TFT5Model):
|
||||
documentation alongside usage examples.
|
||||
|
||||
Examples::
|
||||
|
||||
>>> from transformers import TFMT5Model, T5Tokenizer
|
||||
>>> model = TFMT5Model.from_pretrained("google/mt5-small")
|
||||
>>> tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
|
||||
>>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
|
||||
>>> summary = "Weiter Verhandlung in Syrien."
|
||||
>>> batch = tokenizer.prepare_seq2seq_batch(src_texts=[article], tgt_texts=[summary], return_tensors="tf")
|
||||
>>> batch["decoder_input_ids"] = batch["labels"]
|
||||
>>> del batch["labels"]
|
||||
>>> outputs = model(batch)
|
||||
>>> inputs = tokenizer(article, return_tensors="tf")
|
||||
>>> with tokenizer.as_target_tokenizer():
|
||||
... labels = tokenizer(summary, return_tensors="tf")
|
||||
|
||||
>>> outputs = model(input_ids=inputs["input_ids"], decoder_input_ids=labels["input_ids"])
|
||||
>>> hidden_states = outputs.last_hidden_state
|
||||
"""
|
||||
model_type = "mt5"
|
||||
@@ -52,13 +54,17 @@ class TFMT5ForConditionalGeneration(TFT5ForConditionalGeneration):
|
||||
appropriate documentation alongside usage examples.
|
||||
|
||||
Examples::
|
||||
|
||||
>>> from transformers import TFMT5ForConditionalGeneration, T5Tokenizer
|
||||
>>> model = TFMT5ForConditionalGeneration.from_pretrained("google/mt5-small")
|
||||
>>> tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
|
||||
>>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
|
||||
>>> summary = "Weiter Verhandlung in Syrien."
|
||||
>>> batch = tokenizer.prepare_seq2seq_batch(src_texts=[article], tgt_texts=[summary], return_tensors="tf")
|
||||
>>> outputs = model(batch)
|
||||
>>> inputs = tokenizer(article, return_tensors="tf")
|
||||
>>> with tokenizer.as_target_tokenizer():
|
||||
... labels = tokenizer(summary, return_tensors="tf")
|
||||
|
||||
>>> outputs = model(**inputs, labels=labels["input_ids"])
|
||||
>>> loss = outputs.loss
|
||||
"""
|
||||
|
||||
|
||||
@@ -550,10 +550,8 @@ class RagModel(RagPreTrainedModel):
|
||||
>>> # initialize with RagRetriever to do everything in one forward call
|
||||
>>> model = RagModel.from_pretrained("facebook/rag-token-base", retriever=retriever)
|
||||
|
||||
>>> input_dict = tokenizer.prepare_seq2seq_batch("How many people live in Paris?", "In Paris, there are 10 million people.", return_tensors="pt")
|
||||
>>> input_ids = input_dict["input_ids"]
|
||||
>>> outputs = model(input_ids=input_ids)
|
||||
|
||||
>>> inputs = tokenizer("How many people live in Paris?", return_tensors="pt")
|
||||
>>> outputs = model(input_ids=inputs["input_ids"])
|
||||
"""
|
||||
n_docs = n_docs if n_docs is not None else self.config.n_docs
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
@@ -752,9 +750,12 @@ class RagSequenceForGeneration(RagPreTrainedModel):
|
||||
>>> # initialize with RagRetriever to do everything in one forward call
|
||||
>>> model = RagSequenceForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever)
|
||||
|
||||
>>> input_dict = tokenizer.prepare_seq2seq_batch("How many people live in Paris?", "In Paris, there are 10 million people.", return_tensors="pt")
|
||||
>>> input_ids = input_dict["input_ids"]
|
||||
>>> outputs = model(input_ids=input_ids, labels=input_dict["labels"])
|
||||
>>> inputs = tokenizer("How many people live in Paris?", return_tensors="pt")
|
||||
>>> with tokenizer.as_target_tokenizer():
|
||||
... targets = tokenizer("In Paris, there are 10 million people.", return_tensors="pt")
|
||||
>>> input_ids = inputs["input_ids"]
|
||||
>>> labels = targets["input_ids"]
|
||||
>>> outputs = model(input_ids=input_ids, labels=labels)
|
||||
|
||||
>>> # or use retriever separately
|
||||
>>> model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", use_dummy_dataset=True)
|
||||
@@ -764,7 +765,7 @@ class RagSequenceForGeneration(RagPreTrainedModel):
|
||||
>>> docs_dict = retriever(input_ids.numpy(), question_hidden_states.detach().numpy(), return_tensors="pt")
|
||||
>>> doc_scores = torch.bmm(question_hidden_states.unsqueeze(1), docs_dict["retrieved_doc_embeds"].float().transpose(1, 2)).squeeze(1)
|
||||
>>> # 3. Forward to generator
|
||||
>>> outputs = model(context_input_ids=docs_dict["context_input_ids"], context_attention_mask=docs_dict["context_attention_mask"], doc_scores=doc_scores, decoder_input_ids=input_dict["labels"])
|
||||
>>> outputs = model(context_input_ids=docs_dict["context_input_ids"], context_attention_mask=docs_dict["context_attention_mask"], doc_scores=doc_scores, decoder_input_ids=labels)
|
||||
"""
|
||||
n_docs = n_docs if n_docs is not None else self.config.n_docs
|
||||
exclude_bos_score = exclude_bos_score if exclude_bos_score is not None else self.config.exclude_bos_score
|
||||
@@ -1203,9 +1204,12 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
||||
>>> # initialize with RagRetriever to do everything in one forward call
|
||||
>>> model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever)
|
||||
|
||||
>>> input_dict = tokenizer.prepare_seq2seq_batch("How many people live in Paris?", "In Paris, there are 10 million people.", return_tensors="pt")
|
||||
>>> input_ids = input_dict["input_ids"]
|
||||
>>> outputs = model(input_ids=input_ids, labels=input_dict["labels"])
|
||||
>>> inputs = tokenizer("How many people live in Paris?", return_tensors="pt")
|
||||
>>> with tokenizer.as_target_tokenizer():
|
||||
... targets = tokenizer("In Paris, there are 10 million people.", return_tensors="pt")
|
||||
>>> input_ids = inputs["input_ids"]
|
||||
>>> labels = targets["input_ids"]
|
||||
>>> outputs = model(input_ids=input_ids, labels=labels)
|
||||
|
||||
>>> # or use retriever separately
|
||||
>>> model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", use_dummy_dataset=True)
|
||||
@@ -1215,7 +1219,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
||||
>>> docs_dict = retriever(input_ids.numpy(), question_hidden_states.detach().numpy(), return_tensors="pt")
|
||||
>>> doc_scores = torch.bmm(question_hidden_states.unsqueeze(1), docs_dict["retrieved_doc_embeds"].float().transpose(1, 2)).squeeze(1)
|
||||
>>> # 3. Forward to generator
|
||||
>>> outputs = model(context_input_ids=docs_dict["context_input_ids"], context_attention_mask=docs_dict["context_attention_mask"], doc_scores=doc_scores, decoder_input_ids=input_dict["labels"])
|
||||
>>> outputs = model(context_input_ids=docs_dict["context_input_ids"], context_attention_mask=docs_dict["context_attention_mask"], doc_scores=doc_scores, decoder_input_ids=labels)
|
||||
|
||||
>>> # or directly generate
|
||||
>>> generated = model.generate(context_input_ids=docs_dict["context_input_ids"], context_attention_mask=docs_dict["context_attention_mask"], doc_scores=doc_scores)
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
"""Tokenization classes for RAG."""
|
||||
import os
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from typing import List, Optional
|
||||
|
||||
@@ -88,6 +89,13 @@ class RagTokenizer:
|
||||
truncation: bool = True,
|
||||
**kwargs,
|
||||
) -> BatchEncoding:
|
||||
warnings.warn(
|
||||
"`prepare_seq2seq_batch` is deprecated and will be removed in version 5 of 🤗 Transformers. Use the "
|
||||
"regular `__call__` method to prepare your inputs and the tokenizer under the `as_target_tokenizer` "
|
||||
"context manager to prepare your targets. See the documentation of your specific tokenizer for more "
|
||||
"details",
|
||||
FutureWarning,
|
||||
)
|
||||
if max_length is None:
|
||||
max_length = self.current_tokenizer.model_max_length
|
||||
model_inputs = self(
|
||||
|
||||
Reference in New Issue
Block a user