Refactor prepare_seq2seq_batch (#9524)
* Add target contextmanager and rework prepare_seq2seq_batch
* Fix tests, treat BART and Barthez
* Add last tokenizers
* Fix test
* Set src token before calling the superclass
* Remove special behavior for T5
* Remove needless imports
* Remove needless asserts
This commit is contained in:
@@ -16,8 +16,7 @@
|
||||
import os
|
||||
from typing import List, Optional
|
||||
|
||||
from ...file_utils import add_start_docstrings
|
||||
from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING, BatchEncoding
|
||||
from ...tokenization_utils_base import BatchEncoding
|
||||
from ...utils import logging
|
||||
from .configuration_rag import RagConfig
|
||||
|
||||
@@ -63,42 +62,18 @@ class RagTokenizer:
|
||||
def batch_decode(self, *args, **kwargs):
|
||||
return self.generator.batch_decode(*args, **kwargs)
|
||||
|
||||
@add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
|
||||
def prepare_seq2seq_batch(
|
||||
self,
|
||||
src_texts: List[str],
|
||||
tgt_texts: Optional[List[str]] = None,
|
||||
max_length: Optional[int] = None,
|
||||
max_target_length: Optional[int] = None,
|
||||
padding: str = "longest",
|
||||
return_tensors: str = None,
|
||||
truncation=True,
|
||||
**kwargs,
|
||||
) -> BatchEncoding:
|
||||
if max_length is None:
|
||||
max_length = self.question_encoder.model_max_length
|
||||
model_inputs: BatchEncoding = self.question_encoder(
|
||||
src_texts,
|
||||
add_special_tokens=True,
|
||||
return_tensors=return_tensors,
|
||||
max_length=max_length,
|
||||
padding=padding,
|
||||
truncation=truncation,
|
||||
**kwargs,
|
||||
)
|
||||
if tgt_texts is None:
|
||||
return model_inputs
|
||||
# Process tgt_texts
|
||||
if max_target_length is None:
|
||||
max_target_length = self.generator.model_max_length
|
||||
labels = self.generator(
|
||||
tgt_texts,
|
||||
add_special_tokens=True,
|
||||
return_tensors=return_tensors,
|
||||
padding=padding,
|
||||
max_length=max_target_length,
|
||||
truncation=truncation,
|
||||
**kwargs,
|
||||
)["input_ids"]
|
||||
model_inputs["labels"] = labels
|
||||
return model_inputs
|
||||
return super().prepare_seq2seq_batch(
|
||||
src_texts, tgt_texts, max_length=max_length, max_target_length=max_target_length, **kwargs
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user