Refactor prepare_seq2seq_batch (#9524)

* Add target contextmanager and rework prepare_seq2seq_batch

* Fix tests, treat BART and Barthez

* Add last tokenizers

* Fix test

* Set src token before calling the superclass

* Remove special behavior for T5

* Remove needless imports

* Remove needless asserts
This commit is contained in:
Sylvain Gugger
2021-01-12 18:19:38 -05:00
committed by GitHub
parent e6ecef711e
commit 063d8d27f4
24 changed files with 169 additions and 700 deletions

View File

@@ -16,8 +16,7 @@
import os
from typing import List, Optional
from ...file_utils import add_start_docstrings
from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING, BatchEncoding
from ...tokenization_utils_base import BatchEncoding
from ...utils import logging
from .configuration_rag import RagConfig
@@ -63,42 +62,18 @@ class RagTokenizer:
def batch_decode(self, *args, **kwargs):
return self.generator.batch_decode(*args, **kwargs)
@add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
def prepare_seq2seq_batch(
self,
src_texts: List[str],
tgt_texts: Optional[List[str]] = None,
max_length: Optional[int] = None,
max_target_length: Optional[int] = None,
padding: str = "longest",
return_tensors: str = None,
truncation=True,
**kwargs,
) -> BatchEncoding:
if max_length is None:
max_length = self.question_encoder.model_max_length
model_inputs: BatchEncoding = self.question_encoder(
src_texts,
add_special_tokens=True,
return_tensors=return_tensors,
max_length=max_length,
padding=padding,
truncation=truncation,
**kwargs,
)
if tgt_texts is None:
return model_inputs
# Process tgt_texts
if max_target_length is None:
max_target_length = self.generator.model_max_length
labels = self.generator(
tgt_texts,
add_special_tokens=True,
return_tensors=return_tensors,
padding=padding,
max_length=max_target_length,
truncation=truncation,
**kwargs,
)["input_ids"]
model_inputs["labels"] = labels
return model_inputs
return super().prepare_seq2seq_batch(
src_texts, tgt_texts, max_length=max_length, max_target_length=max_target_length, **kwargs
)