Refactor prepare_seq2seq_batch (#9524)
* Add target contextmanager and rework prepare_seq2seq_batch * Fix tests, treat BART and Barthez * Add last tokenizers * Fix test * Set src token before calling the superclass * Remove special behavior for T5 * Remove needless imports * Remove needless asserts
This commit is contained in:
@@ -13,10 +13,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from ...file_utils import add_start_docstrings
|
||||
from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING, BatchEncoding
|
||||
from ...utils import logging
|
||||
from ..roberta.tokenization_roberta import RobertaTokenizer
|
||||
|
||||
@@ -54,45 +50,3 @@ class BartTokenizer(RobertaTokenizer):
|
||||
"vocab_file": {m: vocab_url for m in _all_bart_models},
|
||||
"merges_file": {m: merges_url for m in _all_bart_models},
|
||||
}
|
||||
|
||||
@add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
|
||||
def prepare_seq2seq_batch(
|
||||
self,
|
||||
src_texts: List[str],
|
||||
tgt_texts: Optional[List[str]] = None,
|
||||
max_length: Optional[int] = None,
|
||||
max_target_length: Optional[int] = None,
|
||||
padding: str = "longest",
|
||||
return_tensors: str = None,
|
||||
truncation=True,
|
||||
**kwargs,
|
||||
) -> BatchEncoding:
|
||||
kwargs.pop("src_lang", None)
|
||||
kwargs.pop("tgt_lang", None)
|
||||
if max_length is None:
|
||||
max_length = self.model_max_length
|
||||
model_inputs: BatchEncoding = self(
|
||||
src_texts,
|
||||
add_special_tokens=True,
|
||||
return_tensors=return_tensors,
|
||||
max_length=max_length,
|
||||
padding=padding,
|
||||
truncation=truncation,
|
||||
**kwargs,
|
||||
)
|
||||
if tgt_texts is None:
|
||||
return model_inputs
|
||||
# Process tgt_texts
|
||||
if max_target_length is None:
|
||||
max_target_length = max_length
|
||||
labels = self(
|
||||
tgt_texts,
|
||||
add_special_tokens=True,
|
||||
return_tensors=return_tensors,
|
||||
padding=padding,
|
||||
max_length=max_target_length,
|
||||
truncation=truncation,
|
||||
**kwargs,
|
||||
)["input_ids"]
|
||||
model_inputs["labels"] = labels
|
||||
return model_inputs
|
||||
|
||||
@@ -13,10 +13,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from ...file_utils import add_start_docstrings
|
||||
from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING, BatchEncoding
|
||||
from ...utils import logging
|
||||
from ..roberta.tokenization_roberta_fast import RobertaTokenizerFast
|
||||
from .tokenization_bart import BartTokenizer
|
||||
@@ -49,43 +45,3 @@ class BartTokenizerFast(RobertaTokenizerFast):
|
||||
"tokenizer_file": {m: tokenizer_url for m in _all_bart_models},
|
||||
}
|
||||
slow_tokenizer_class = BartTokenizer
|
||||
|
||||
@add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
|
||||
def prepare_seq2seq_batch(
|
||||
self,
|
||||
src_texts: List[str],
|
||||
tgt_texts: Optional[List[str]] = None,
|
||||
max_length: Optional[int] = None,
|
||||
max_target_length: Optional[int] = None,
|
||||
padding: str = "longest",
|
||||
return_tensors: Optional[str] = None,
|
||||
truncation=True,
|
||||
**kwargs,
|
||||
) -> BatchEncoding:
|
||||
if max_length is None:
|
||||
max_length = self.model_max_length
|
||||
model_inputs: BatchEncoding = self(
|
||||
src_texts,
|
||||
add_special_tokens=True,
|
||||
return_tensors=return_tensors,
|
||||
max_length=max_length,
|
||||
padding=padding,
|
||||
truncation=truncation,
|
||||
**kwargs,
|
||||
)
|
||||
if tgt_texts is None:
|
||||
return model_inputs
|
||||
# Process tgt_texts
|
||||
if max_target_length is None:
|
||||
max_target_length = max_length
|
||||
labels = self(
|
||||
tgt_texts,
|
||||
add_special_tokens=True,
|
||||
return_tensors=return_tensors,
|
||||
padding=padding,
|
||||
max_length=max_target_length,
|
||||
truncation=truncation,
|
||||
**kwargs,
|
||||
)["input_ids"]
|
||||
model_inputs["labels"] = labels
|
||||
return model_inputs
|
||||
|
||||
@@ -21,9 +21,7 @@ from typing import List, Optional, Tuple
|
||||
|
||||
import sentencepiece as spm
|
||||
|
||||
from ...file_utils import add_start_docstrings
|
||||
from ...tokenization_utils import PreTrainedTokenizer
|
||||
from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING, BatchEncoding
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
@@ -264,45 +262,3 @@ class BarthezTokenizer(PreTrainedTokenizer):
|
||||
copyfile(self.vocab_file, out_vocab_file)
|
||||
|
||||
return (out_vocab_file,)
|
||||
|
||||
@add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
|
||||
def prepare_seq2seq_batch(
|
||||
self,
|
||||
src_texts: List[str],
|
||||
tgt_texts: Optional[List[str]] = None,
|
||||
max_length: Optional[int] = None,
|
||||
max_target_length: Optional[int] = None,
|
||||
padding: str = "longest",
|
||||
return_tensors: str = "None",
|
||||
truncation=True,
|
||||
**kwargs,
|
||||
) -> BatchEncoding:
|
||||
kwargs.pop("src_lang", None)
|
||||
kwargs.pop("tgt_lang", None)
|
||||
if max_length is None:
|
||||
max_length = self.model_max_length
|
||||
model_inputs: BatchEncoding = self(
|
||||
src_texts,
|
||||
add_special_tokens=True,
|
||||
return_tensors=return_tensors,
|
||||
max_length=max_length,
|
||||
padding=padding,
|
||||
truncation=truncation,
|
||||
**kwargs,
|
||||
)
|
||||
if tgt_texts is None:
|
||||
return model_inputs
|
||||
# Process tgt_texts
|
||||
if max_target_length is None:
|
||||
max_target_length = max_length
|
||||
labels = self(
|
||||
tgt_texts,
|
||||
add_special_tokens=True,
|
||||
return_tensors=return_tensors,
|
||||
padding=padding,
|
||||
max_length=max_target_length,
|
||||
truncation=truncation,
|
||||
**kwargs,
|
||||
)["input_ids"]
|
||||
model_inputs["labels"] = labels
|
||||
return model_inputs
|
||||
|
||||
@@ -19,8 +19,7 @@ import os
|
||||
from shutil import copyfile
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from ...file_utils import add_start_docstrings, is_sentencepiece_available
|
||||
from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING, BatchEncoding
|
||||
from ...file_utils import is_sentencepiece_available
|
||||
from ...tokenization_utils_fast import PreTrainedTokenizerFast
|
||||
from ...utils import logging
|
||||
|
||||
@@ -228,45 +227,3 @@ class BarthezTokenizerFast(PreTrainedTokenizerFast):
|
||||
copyfile(self.vocab_file, out_vocab_file)
|
||||
|
||||
return (out_vocab_file,)
|
||||
|
||||
@add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
|
||||
def prepare_seq2seq_batch(
|
||||
self,
|
||||
src_texts: List[str],
|
||||
tgt_texts: Optional[List[str]] = None,
|
||||
max_length: Optional[int] = None,
|
||||
max_target_length: Optional[int] = None,
|
||||
padding: str = "longest",
|
||||
return_tensors: str = "None",
|
||||
truncation=True,
|
||||
**kwargs,
|
||||
) -> BatchEncoding:
|
||||
kwargs.pop("src_lang", None)
|
||||
kwargs.pop("tgt_lang", None)
|
||||
if max_length is None:
|
||||
max_length = self.model_max_length
|
||||
model_inputs: BatchEncoding = self(
|
||||
src_texts,
|
||||
add_special_tokens=True,
|
||||
return_tensors=return_tensors,
|
||||
max_length=max_length,
|
||||
padding=padding,
|
||||
truncation=truncation,
|
||||
**kwargs,
|
||||
)
|
||||
if tgt_texts is None:
|
||||
return model_inputs
|
||||
# Process tgt_texts
|
||||
if max_target_length is None:
|
||||
max_target_length = max_length
|
||||
labels = self(
|
||||
tgt_texts,
|
||||
add_special_tokens=True,
|
||||
return_tensors=return_tensors,
|
||||
padding=padding,
|
||||
max_length=max_target_length,
|
||||
truncation=truncation,
|
||||
**kwargs,
|
||||
)["input_ids"]
|
||||
model_inputs["labels"] = labels
|
||||
return model_inputs
|
||||
|
||||
@@ -23,9 +23,7 @@ from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import sacremoses as sm
|
||||
|
||||
from ...file_utils import add_start_docstrings
|
||||
from ...tokenization_utils import BatchEncoding, PreTrainedTokenizer
|
||||
from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING
|
||||
from ...tokenization_utils import PreTrainedTokenizer
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
@@ -484,40 +482,6 @@ class FSMTTokenizer(PreTrainedTokenizer):
|
||||
return len(token_ids_0 + sep) * [0]
|
||||
return len(token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
|
||||
|
||||
@add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
|
||||
def prepare_seq2seq_batch(
|
||||
self,
|
||||
src_texts: List[str],
|
||||
tgt_texts: Optional[List[str]] = None,
|
||||
max_length: Optional[int] = None,
|
||||
max_target_length: Optional[int] = None,
|
||||
return_tensors: Optional[str] = None,
|
||||
truncation=True,
|
||||
padding="longest",
|
||||
**unused,
|
||||
) -> BatchEncoding:
|
||||
if type(src_texts) is not list:
|
||||
raise ValueError("src_texts is expected to be a list")
|
||||
if "" in src_texts:
|
||||
raise ValueError(f"found empty string in src_texts: {src_texts}")
|
||||
|
||||
tokenizer_kwargs = dict(
|
||||
add_special_tokens=True,
|
||||
return_tensors=return_tensors,
|
||||
max_length=max_length,
|
||||
truncation=truncation,
|
||||
padding=padding,
|
||||
)
|
||||
model_inputs: BatchEncoding = self(src_texts, **tokenizer_kwargs)
|
||||
|
||||
if tgt_texts is None:
|
||||
return model_inputs
|
||||
if max_target_length is not None:
|
||||
tokenizer_kwargs["max_length"] = max_target_length
|
||||
|
||||
model_inputs["labels"] = self(tgt_texts, **tokenizer_kwargs)["input_ids"]
|
||||
return model_inputs
|
||||
|
||||
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
||||
if not os.path.isdir(save_directory):
|
||||
logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
|
||||
|
||||
@@ -15,15 +15,14 @@
|
||||
import json
|
||||
import re
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from shutil import copyfile
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import sentencepiece
|
||||
|
||||
from ...file_utils import add_start_docstrings
|
||||
from ...tokenization_utils import BatchEncoding, PreTrainedTokenizer
|
||||
from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING
|
||||
from ...tokenization_utils import PreTrainedTokenizer
|
||||
|
||||
|
||||
vocab_files_names = {
|
||||
@@ -182,40 +181,15 @@ class MarianTokenizer(PreTrainedTokenizer):
|
||||
# We don't expect to process pairs, but leave the pair logic for API consistency
|
||||
return token_ids_0 + token_ids_1 + [self.eos_token_id]
|
||||
|
||||
@add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
|
||||
def prepare_seq2seq_batch(
|
||||
self,
|
||||
src_texts: List[str],
|
||||
tgt_texts: Optional[List[str]] = None,
|
||||
max_length: Optional[int] = None,
|
||||
max_target_length: Optional[int] = None,
|
||||
return_tensors: Optional[str] = None,
|
||||
truncation=True,
|
||||
padding="longest",
|
||||
**unused,
|
||||
) -> BatchEncoding:
|
||||
if "" in src_texts:
|
||||
raise ValueError(f"found empty string in src_texts: {src_texts}")
|
||||
self.current_spm = self.spm_source
|
||||
src_texts = [self.normalize(t) for t in src_texts] # this does not appear to do much
|
||||
tokenizer_kwargs = dict(
|
||||
add_special_tokens=True,
|
||||
return_tensors=return_tensors,
|
||||
max_length=max_length,
|
||||
truncation=truncation,
|
||||
padding=padding,
|
||||
)
|
||||
model_inputs: BatchEncoding = self(src_texts, **tokenizer_kwargs)
|
||||
|
||||
if tgt_texts is None:
|
||||
return model_inputs
|
||||
if max_target_length is not None:
|
||||
tokenizer_kwargs["max_length"] = max_target_length
|
||||
|
||||
@contextmanager
|
||||
def as_target_tokenizer(self):
|
||||
"""
|
||||
Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to
|
||||
sequence-to-sequence models that need a slightly different processing for the labels.
|
||||
"""
|
||||
self.current_spm = self.spm_target
|
||||
model_inputs["labels"] = self(tgt_texts, **tokenizer_kwargs)["input_ids"]
|
||||
yield
|
||||
self.current_spm = self.spm_source
|
||||
return model_inputs
|
||||
|
||||
@property
|
||||
def vocab_size(self) -> int:
|
||||
|
||||
@@ -13,11 +13,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import List, Optional
|
||||
|
||||
from ...file_utils import add_start_docstrings
|
||||
from ...tokenization_utils import BatchEncoding
|
||||
from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING
|
||||
from ...utils import logging
|
||||
from ..xlm_roberta.tokenization_xlm_roberta import XLMRobertaTokenizer
|
||||
|
||||
@@ -172,52 +171,28 @@ class MBartTokenizer(XLMRobertaTokenizer):
|
||||
# We don't expect to process pairs, but leave the pair logic for API consistency
|
||||
return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens
|
||||
|
||||
@add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
|
||||
def prepare_seq2seq_batch(
|
||||
self,
|
||||
src_texts: List[str],
|
||||
src_lang: str = "en_XX",
|
||||
tgt_texts: Optional[List[str]] = None,
|
||||
tgt_lang: str = "ro_RO",
|
||||
max_length: Optional[int] = None,
|
||||
max_target_length: Optional[int] = None,
|
||||
truncation: bool = True,
|
||||
padding: str = "longest",
|
||||
return_tensors: Optional[str] = None,
|
||||
add_prefix_space: bool = False, # ignored
|
||||
**kwargs,
|
||||
) -> BatchEncoding:
|
||||
if max_length is None:
|
||||
max_length = self.model_max_length
|
||||
self.set_src_lang_special_tokens(src_lang)
|
||||
model_inputs: BatchEncoding = self(
|
||||
src_texts,
|
||||
add_special_tokens=True,
|
||||
return_tensors=return_tensors,
|
||||
max_length=max_length,
|
||||
padding=padding,
|
||||
truncation=truncation,
|
||||
**kwargs,
|
||||
)
|
||||
if tgt_texts is None:
|
||||
return model_inputs
|
||||
# Process tgt_texts
|
||||
if max_target_length is None:
|
||||
max_target_length = max_length
|
||||
self.set_tgt_lang_special_tokens(tgt_lang)
|
||||
self.src_lang = src_lang
|
||||
self.tgt_lang = tgt_lang
|
||||
self.set_src_lang_special_tokens(self.src_lang)
|
||||
return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs)
|
||||
|
||||
labels = self(
|
||||
tgt_texts,
|
||||
add_special_tokens=True,
|
||||
return_tensors=return_tensors,
|
||||
padding=padding,
|
||||
max_length=max_target_length,
|
||||
truncation=True,
|
||||
**kwargs,
|
||||
)["input_ids"]
|
||||
model_inputs["labels"] = labels
|
||||
self.set_src_lang_special_tokens(src_lang) # sets to src_lang
|
||||
return model_inputs
|
||||
@contextmanager
|
||||
def as_target_tokenizer(self):
|
||||
"""
|
||||
Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to
|
||||
sequence-to-sequence models that need a slightly different processing for the labels.
|
||||
"""
|
||||
self.set_tgt_lang_special_tokens(self.tgt_lang)
|
||||
yield
|
||||
self.set_src_lang_special_tokens(self.src_lang)
|
||||
|
||||
def set_src_lang_special_tokens(self, src_lang) -> None:
|
||||
"""Reset the special tokens to the source lang setting. No prefix and suffix=[eos, src_lang_code]."""
|
||||
|
||||
@@ -13,13 +13,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import List, Optional
|
||||
|
||||
from tokenizers import processors
|
||||
|
||||
from ...file_utils import add_start_docstrings, is_sentencepiece_available
|
||||
from ...file_utils import is_sentencepiece_available
|
||||
from ...tokenization_utils import BatchEncoding
|
||||
from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING
|
||||
from ...utils import logging
|
||||
from ..xlm_roberta.tokenization_xlm_roberta_fast import XLMRobertaTokenizerFast
|
||||
|
||||
@@ -171,51 +171,28 @@ class MBartTokenizerFast(XLMRobertaTokenizerFast):
|
||||
# We don't expect to process pairs, but leave the pair logic for API consistency
|
||||
return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens
|
||||
|
||||
@add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
|
||||
def prepare_seq2seq_batch(
|
||||
self,
|
||||
src_texts: List[str],
|
||||
src_lang: str = "en_XX",
|
||||
tgt_texts: Optional[List[str]] = None,
|
||||
tgt_lang: str = "ro_RO",
|
||||
max_length: Optional[int] = None,
|
||||
max_target_length: Optional[int] = None,
|
||||
truncation: bool = True,
|
||||
padding: str = "longest",
|
||||
return_tensors: str = None,
|
||||
**kwargs,
|
||||
) -> BatchEncoding:
|
||||
if max_length is None:
|
||||
max_length = self.model_max_length
|
||||
self.set_src_lang_special_tokens(src_lang)
|
||||
model_inputs: BatchEncoding = self(
|
||||
src_texts,
|
||||
add_special_tokens=True,
|
||||
return_tensors=return_tensors,
|
||||
max_length=max_length,
|
||||
padding=padding,
|
||||
truncation=truncation,
|
||||
**kwargs,
|
||||
)
|
||||
if tgt_texts is None:
|
||||
return model_inputs
|
||||
# Process tgt_texts
|
||||
if max_target_length is None:
|
||||
max_target_length = max_length
|
||||
self.set_tgt_lang_special_tokens(tgt_lang)
|
||||
self.src_lang = src_lang
|
||||
self.tgt_lang = tgt_lang
|
||||
self.set_src_lang_special_tokens(self.src_lang)
|
||||
return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs)
|
||||
|
||||
labels = self(
|
||||
tgt_texts,
|
||||
add_special_tokens=True,
|
||||
return_tensors=return_tensors,
|
||||
padding=padding,
|
||||
max_length=max_target_length,
|
||||
truncation=True,
|
||||
**kwargs,
|
||||
)["input_ids"]
|
||||
model_inputs["labels"] = labels
|
||||
self.set_src_lang_special_tokens(src_lang) # sets to src_lang
|
||||
return model_inputs
|
||||
@contextmanager
|
||||
def as_target_tokenizer(self):
|
||||
"""
|
||||
Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to
|
||||
sequence-to-sequence models that need a slightly different processing for the labels.
|
||||
"""
|
||||
self.set_tgt_lang_special_tokens(self.tgt_lang)
|
||||
yield
|
||||
self.set_src_lang_special_tokens(self.src_lang)
|
||||
|
||||
def set_src_lang_special_tokens(self, src_lang) -> None:
|
||||
"""Reset the special tokens to the source lang setting. No prefix and suffix=[eos, src_lang_code]."""
|
||||
|
||||
@@ -18,9 +18,7 @@ from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import sentencepiece as spm
|
||||
|
||||
from ...file_utils import add_start_docstrings
|
||||
from ...tokenization_utils import PreTrainedTokenizer
|
||||
from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING, BatchEncoding
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
@@ -250,36 +248,6 @@ class PegasusTokenizer(PreTrainedTokenizer):
|
||||
# We don't expect to process pairs, but leave the pair logic for API consistency
|
||||
return token_ids_0 + token_ids_1 + [self.eos_token_id]
|
||||
|
||||
@add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
|
||||
def prepare_seq2seq_batch(
|
||||
self,
|
||||
src_texts: List[str],
|
||||
tgt_texts: Optional[List[str]] = None,
|
||||
max_length: Optional[int] = None,
|
||||
max_target_length: Optional[int] = None,
|
||||
return_tensors: str = None,
|
||||
truncation=True,
|
||||
padding="longest",
|
||||
**unused,
|
||||
) -> BatchEncoding:
|
||||
if "" in src_texts:
|
||||
raise ValueError(f"found empty string in src_texts: {src_texts}")
|
||||
tokenizer_kwargs = dict(
|
||||
add_special_tokens=True,
|
||||
return_tensors=return_tensors,
|
||||
max_length=max_length,
|
||||
truncation=truncation,
|
||||
padding=padding,
|
||||
)
|
||||
model_inputs: BatchEncoding = self(src_texts, **tokenizer_kwargs)
|
||||
if tgt_texts is None:
|
||||
return model_inputs
|
||||
if max_target_length is not None:
|
||||
tokenizer_kwargs["max_length"] = max_target_length
|
||||
labels: BatchEncoding = self(tgt_texts, **tokenizer_kwargs)["input_ids"]
|
||||
model_inputs["labels"] = labels
|
||||
return model_inputs
|
||||
|
||||
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
||||
if not os.path.isdir(save_directory):
|
||||
logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
|
||||
|
||||
@@ -19,8 +19,7 @@ import os
|
||||
from shutil import copyfile
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from ...file_utils import add_start_docstrings, is_sentencepiece_available
|
||||
from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING, BatchEncoding
|
||||
from ...file_utils import is_sentencepiece_available
|
||||
from ...tokenization_utils_fast import PreTrainedTokenizerFast
|
||||
from ...utils import logging
|
||||
|
||||
@@ -188,36 +187,6 @@ class PegasusTokenizerFast(PreTrainedTokenizerFast):
|
||||
# We don't expect to process pairs, but leave the pair logic for API consistency
|
||||
return token_ids_0 + token_ids_1 + [self.eos_token_id]
|
||||
|
||||
@add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
|
||||
def prepare_seq2seq_batch(
|
||||
self,
|
||||
src_texts: List[str],
|
||||
tgt_texts: Optional[List[str]] = None,
|
||||
max_length: Optional[int] = None,
|
||||
max_target_length: Optional[int] = None,
|
||||
return_tensors: str = None,
|
||||
truncation=True,
|
||||
padding="longest",
|
||||
**unused,
|
||||
) -> BatchEncoding:
|
||||
if "" in src_texts:
|
||||
raise ValueError(f"found empty string in src_texts: {src_texts}")
|
||||
tokenizer_kwargs = dict(
|
||||
add_special_tokens=True,
|
||||
return_tensors=return_tensors,
|
||||
max_length=max_length,
|
||||
truncation=truncation,
|
||||
padding=padding,
|
||||
)
|
||||
model_inputs: BatchEncoding = self(src_texts, **tokenizer_kwargs)
|
||||
if tgt_texts is None:
|
||||
return model_inputs
|
||||
if max_target_length is not None:
|
||||
tokenizer_kwargs["max_length"] = max_target_length
|
||||
labels: BatchEncoding = self(tgt_texts, **tokenizer_kwargs)["input_ids"]
|
||||
model_inputs["labels"] = labels
|
||||
return model_inputs
|
||||
|
||||
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
||||
if not os.path.isdir(save_directory):
|
||||
logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
|
||||
|
||||
@@ -17,9 +17,7 @@ import collections
|
||||
import os
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from ...file_utils import add_start_docstrings
|
||||
from ...tokenization_utils import BatchEncoding, PreTrainedTokenizer
|
||||
from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING
|
||||
from ...tokenization_utils import PreTrainedTokenizer
|
||||
from ...utils import logging
|
||||
from ..bert.tokenization_bert import BasicTokenizer, WordpieceTokenizer
|
||||
|
||||
@@ -288,43 +286,3 @@ class ProphetNetTokenizer(PreTrainedTokenizer):
|
||||
return token_ids_0 + [self.sep_token_id]
|
||||
sep = [self.sep_token_id]
|
||||
return token_ids_0 + sep + token_ids_1 + sep
|
||||
|
||||
@add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
|
||||
def prepare_seq2seq_batch(
|
||||
self,
|
||||
src_texts: List[str],
|
||||
tgt_texts: Optional[List[str]] = None,
|
||||
max_length: Optional[int] = None,
|
||||
max_target_length: Optional[int] = None,
|
||||
padding: str = "longest",
|
||||
return_tensors: str = None,
|
||||
truncation: bool = True,
|
||||
**kwargs,
|
||||
) -> BatchEncoding:
|
||||
if max_length is None:
|
||||
max_length = self.model_max_length
|
||||
model_inputs = self(
|
||||
src_texts,
|
||||
add_special_tokens=True,
|
||||
return_tensors=return_tensors,
|
||||
max_length=max_length,
|
||||
padding=padding,
|
||||
truncation=truncation,
|
||||
**kwargs,
|
||||
)
|
||||
if tgt_texts is None:
|
||||
return model_inputs
|
||||
# Process tgt_texts
|
||||
if max_target_length is None:
|
||||
max_target_length = max_length
|
||||
labels_and_decoder_mask = self(
|
||||
tgt_texts,
|
||||
add_special_tokens=True,
|
||||
return_tensors=return_tensors,
|
||||
padding=padding,
|
||||
max_length=max_target_length,
|
||||
truncation=truncation,
|
||||
**kwargs,
|
||||
)
|
||||
model_inputs["labels"] = labels_and_decoder_mask["input_ids"]
|
||||
return model_inputs
|
||||
|
||||
@@ -16,8 +16,7 @@
|
||||
import os
|
||||
from typing import List, Optional
|
||||
|
||||
from ...file_utils import add_start_docstrings
|
||||
from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING, BatchEncoding
|
||||
from ...tokenization_utils_base import BatchEncoding
|
||||
from ...utils import logging
|
||||
from .configuration_rag import RagConfig
|
||||
|
||||
@@ -63,42 +62,18 @@ class RagTokenizer:
|
||||
def batch_decode(self, *args, **kwargs):
|
||||
return self.generator.batch_decode(*args, **kwargs)
|
||||
|
||||
@add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
|
||||
def prepare_seq2seq_batch(
|
||||
self,
|
||||
src_texts: List[str],
|
||||
tgt_texts: Optional[List[str]] = None,
|
||||
max_length: Optional[int] = None,
|
||||
max_target_length: Optional[int] = None,
|
||||
padding: str = "longest",
|
||||
return_tensors: str = None,
|
||||
truncation=True,
|
||||
**kwargs,
|
||||
) -> BatchEncoding:
|
||||
if max_length is None:
|
||||
max_length = self.question_encoder.model_max_length
|
||||
model_inputs: BatchEncoding = self.question_encoder(
|
||||
src_texts,
|
||||
add_special_tokens=True,
|
||||
return_tensors=return_tensors,
|
||||
max_length=max_length,
|
||||
padding=padding,
|
||||
truncation=truncation,
|
||||
**kwargs,
|
||||
)
|
||||
if tgt_texts is None:
|
||||
return model_inputs
|
||||
# Process tgt_texts
|
||||
if max_target_length is None:
|
||||
max_target_length = self.generator.model_max_length
|
||||
labels = self.generator(
|
||||
tgt_texts,
|
||||
add_special_tokens=True,
|
||||
return_tensors=return_tensors,
|
||||
padding=padding,
|
||||
max_length=max_target_length,
|
||||
truncation=truncation,
|
||||
**kwargs,
|
||||
)["input_ids"]
|
||||
model_inputs["labels"] = labels
|
||||
return model_inputs
|
||||
return super().prepare_seq2seq_batch(
|
||||
src_texts, tgt_texts, max_length=max_length, max_target_length=max_target_length, **kwargs
|
||||
)
|
||||
|
||||
@@ -23,9 +23,7 @@ from typing import List, Optional, Tuple
|
||||
|
||||
import sentencepiece as spm
|
||||
|
||||
from ...file_utils import add_start_docstrings
|
||||
from ...tokenization_utils import BatchEncoding, PreTrainedTokenizer
|
||||
from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING
|
||||
from ...tokenization_utils import PreTrainedTokenizer
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
@@ -295,43 +293,3 @@ class T5Tokenizer(PreTrainedTokenizer):
|
||||
copyfile(self.vocab_file, out_vocab_file)
|
||||
|
||||
return (out_vocab_file,)
|
||||
|
||||
@add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
|
||||
def prepare_seq2seq_batch(
|
||||
self,
|
||||
src_texts: List[str],
|
||||
tgt_texts: Optional[List[str]] = None,
|
||||
max_length: Optional[int] = None,
|
||||
max_target_length: Optional[int] = None,
|
||||
padding: str = "longest",
|
||||
return_tensors: str = None,
|
||||
truncation: bool = True,
|
||||
**kwargs,
|
||||
) -> BatchEncoding:
|
||||
if max_length is None:
|
||||
max_length = self.model_max_length
|
||||
model_inputs = self(
|
||||
src_texts,
|
||||
add_special_tokens=True,
|
||||
return_tensors=return_tensors,
|
||||
max_length=max_length,
|
||||
padding=padding,
|
||||
truncation=truncation,
|
||||
**kwargs,
|
||||
)
|
||||
if tgt_texts is None:
|
||||
return model_inputs
|
||||
# Process tgt_texts
|
||||
if max_target_length is None:
|
||||
max_target_length = max_length
|
||||
labels_and_decoder_mask = self(
|
||||
tgt_texts,
|
||||
add_special_tokens=True,
|
||||
return_tensors=return_tensors,
|
||||
padding=padding,
|
||||
max_length=max_target_length,
|
||||
truncation=truncation,
|
||||
**kwargs,
|
||||
)
|
||||
model_inputs["labels"] = labels_and_decoder_mask["input_ids"]
|
||||
return model_inputs
|
||||
|
||||
@@ -19,9 +19,7 @@ import os
|
||||
from shutil import copyfile
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from ...file_utils import add_start_docstrings, is_sentencepiece_available
|
||||
from ...tokenization_utils import BatchEncoding
|
||||
from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING
|
||||
from ...file_utils import is_sentencepiece_available
|
||||
from ...tokenization_utils_fast import PreTrainedTokenizerFast
|
||||
from ...utils import logging
|
||||
|
||||
@@ -212,47 +210,3 @@ class T5TokenizerFast(PreTrainedTokenizerFast):
|
||||
if token_ids_1 is None:
|
||||
return len(token_ids_0 + eos) * [0]
|
||||
return len(token_ids_0 + eos + token_ids_1 + eos) * [0]
|
||||
|
||||
@add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
|
||||
def prepare_seq2seq_batch(
|
||||
self,
|
||||
src_texts: List[str],
|
||||
tgt_texts: Optional[List[str]] = None,
|
||||
max_length: Optional[int] = None,
|
||||
max_target_length: Optional[int] = None,
|
||||
padding: str = "longest",
|
||||
return_tensors: str = None,
|
||||
truncation: bool = True,
|
||||
**kwargs,
|
||||
) -> BatchEncoding:
|
||||
if max_length is None:
|
||||
max_length = self.model_max_length
|
||||
self.prefix_tokens = []
|
||||
model_inputs = self(
|
||||
src_texts,
|
||||
add_special_tokens=True,
|
||||
return_tensors=return_tensors,
|
||||
max_length=max_length,
|
||||
padding=padding,
|
||||
truncation=truncation,
|
||||
**kwargs,
|
||||
)
|
||||
if tgt_texts is None:
|
||||
return model_inputs
|
||||
# Process tgt_texts
|
||||
if max_target_length is None:
|
||||
max_target_length = max_length
|
||||
# set prefix_tokens for target text
|
||||
self.prefix_tokens = [self.pad_token_id]
|
||||
labels_and_decoder_mask = self(
|
||||
tgt_texts,
|
||||
add_special_tokens=True,
|
||||
return_tensors=return_tensors,
|
||||
padding=padding,
|
||||
max_length=max_target_length,
|
||||
truncation=truncation,
|
||||
**kwargs,
|
||||
)
|
||||
model_inputs["labels"] = labels_and_decoder_mask["input_ids"]
|
||||
self.prefix_tokens = []
|
||||
return model_inputs
|
||||
|
||||
@@ -738,80 +738,3 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
|
||||
return clean_text
|
||||
else:
|
||||
return text
|
||||
|
||||
def prepare_seq2seq_batch(
|
||||
self,
|
||||
src_texts: List[str],
|
||||
tgt_texts: Optional[List[str]] = None,
|
||||
max_length: Optional[int] = None,
|
||||
max_target_length: Optional[int] = None,
|
||||
padding: str = "longest",
|
||||
return_tensors: str = "None",
|
||||
truncation=True,
|
||||
**kwargs,
|
||||
) -> BatchEncoding:
|
||||
r"""
|
||||
|
||||
Prepare a batch that can be passed directly to an instance of :class:`~transformers.AutoModelForSeq2SeqLM`.
|
||||
|
||||
Args:
|
||||
src_texts: (:obj:`List[str]`):
|
||||
List of documents to summarize or source language texts.
|
||||
tgt_texts: (:obj:`List[str]`, `optional`):
|
||||
List of summaries or target language texts.
|
||||
max_length (:obj:`int`, `optional`):
|
||||
Controls the maximum length for encoder inputs (documents to summarize or source language texts). If
|
||||
left unset or set to :obj:`None`, this will use the predefined model maximum length if a maximum length
|
||||
is required by one of the truncation/padding parameters. If the model has no specific maximum input
|
||||
length (like XLNet) truncation/padding to a maximum length will be deactivated.
|
||||
max_target_length (:obj:`int`, `optional`):
|
||||
Controls the maximum length of decoder inputs (target language texts or summaries). If left unset or
|
||||
set to :obj:`None`, this will use the max_length value.
|
||||
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`False`):
|
||||
Activates and controls padding. Accepts the following values:
|
||||
|
||||
* :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
|
||||
single sequence if provided).
|
||||
* :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
|
||||
maximum acceptable input length for the model if that argument is not provided.
|
||||
* :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
|
||||
different lengths).
|
||||
return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`):
|
||||
If set, will return tensors instead of list of python integers. Acceptable values are:
|
||||
|
||||
* :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
|
||||
* :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
|
||||
* :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
|
||||
truncation (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.TruncationStrategy`, `optional`, defaults to :obj:`True`):
|
||||
Activates and controls truncation. Accepts the following values:
|
||||
|
||||
* :obj:`True` or :obj:`'longest_first'`: Truncate to a maximum length specified with the argument
|
||||
:obj:`max_length` or to the maximum acceptable input length for the model if that argument is not
|
||||
provided. This will truncate token by token, removing a token from the longest sequence in the pair
|
||||
if a pair of sequences (or a batch of pairs) is provided.
|
||||
* :obj:`'only_first'`: Truncate to a maximum length specified with the argument :obj:`max_length` or to
|
||||
the maximum acceptable input length for the model if that argument is not provided. This will only
|
||||
truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
|
||||
* :obj:`'only_second'`: Truncate to a maximum length specified with the argument :obj:`max_length` or
|
||||
to the maximum acceptable input length for the model if that argument is not provided. This will only
|
||||
truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
|
||||
* :obj:`False` or :obj:`'do_not_truncate'` (default): No truncation (i.e., can output batch with
|
||||
sequence lengths greater than the model maximum admissible input size).
|
||||
**kwargs:
|
||||
Additional keyword arguments passed along to :obj:`self.__call__`.
|
||||
|
||||
Returns:
|
||||
:class:`~transformers.BatchEncoding`: A :class:`~transformers.BatchEncoding` with the following fields:
|
||||
|
||||
- **input_ids** -- List of token ids to be fed to the encoder.
|
||||
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model.
|
||||
- **labels** -- List of token ids for tgt_texts
|
||||
|
||||
The full set of keys ``[input_ids, attention_mask, labels]``, will only be returned if tgt_texts is passed.
|
||||
Otherwise, input_ids, attention_mask will be the only keys.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"If your model requires more than input_ids for a typical forward pass, you should implement this method. "
|
||||
"Returned keys should be [input_ids, attention_mask, labels]. See MarianTokenizer or T5Tokenizer for a "
|
||||
"reference implementation."
|
||||
)
|
||||
|
||||
@@ -23,6 +23,7 @@ import json
|
||||
import os
|
||||
import warnings
|
||||
from collections import OrderedDict, UserDict
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
|
||||
@@ -1473,68 +1474,6 @@ INIT_TOKENIZER_DOCSTRING = r"""
|
||||
"""
|
||||
|
||||
|
||||
PREPARE_SEQ2SEQ_BATCH_DOCSTRING = """
|
||||
Prepare model inputs for translation. For best performance, translate one sentence at a time.
|
||||
|
||||
Arguments:
|
||||
src_texts (:obj:`List[str]`):
|
||||
List of documents to summarize or source language texts.
|
||||
tgt_texts (:obj:`list`, `optional`):
|
||||
List of summaries or target language texts.
|
||||
max_length (:obj:`int`, `optional`):
|
||||
Controls the maximum length for encoder inputs (documents to summarize or source language texts) If
|
||||
left unset or set to :obj:`None`, this will use the predefined model maximum length if a maximum length
|
||||
is required by one of the truncation/padding parameters. If the model has no specific maximum input
|
||||
length (like XLNet) truncation/padding to a maximum length will be deactivated.
|
||||
max_target_length (:obj:`int`, `optional`):
|
||||
Controls the maximum length of decoder inputs (target language texts or summaries) If left unset or set
|
||||
to :obj:`None`, this will use the max_length value.
|
||||
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`False`):
|
||||
Activates and controls padding. Accepts the following values:
|
||||
|
||||
* :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
|
||||
single sequence if provided).
|
||||
* :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
|
||||
maximum acceptable input length for the model if that argument is not provided.
|
||||
* :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
|
||||
different lengths).
|
||||
return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`):
|
||||
If set, will return tensors instead of list of python integers. Acceptable values are:
|
||||
|
||||
* :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
|
||||
* :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
|
||||
* :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
|
||||
truncation (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.TruncationStrategy`, `optional`, defaults to :obj:`True`):
|
||||
Activates and controls truncation. Accepts the following values:
|
||||
|
||||
* :obj:`True` or :obj:`'longest_first'`: Truncate to a maximum length specified with the argument
|
||||
:obj:`max_length` or to the maximum acceptable input length for the model if that argument is not
|
||||
provided. This will truncate token by token, removing a token from the longest sequence in the pair
|
||||
if a pair of sequences (or a batch of pairs) is provided.
|
||||
* :obj:`'only_first'`: Truncate to a maximum length specified with the argument :obj:`max_length` or to
|
||||
the maximum acceptable input length for the model if that argument is not provided. This will only
|
||||
truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
|
||||
* :obj:`'only_second'`: Truncate to a maximum length specified with the argument :obj:`max_length` or
|
||||
to the maximum acceptable input length for the model if that argument is not provided. This will only
|
||||
truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
|
||||
* :obj:`False` or :obj:`'do_not_truncate'` (default): No truncation (i.e., can output batch with
|
||||
sequence lengths greater than the model maximum admissible input size).
|
||||
**kwargs:
|
||||
Additional keyword arguments passed along to :obj:`self.__call__`.
|
||||
|
||||
Return:
|
||||
:class:`~transformers.BatchEncoding`: A :class:`~transformers.BatchEncoding` with the following fields:
|
||||
|
||||
- **input_ids** -- List of token ids to be fed to the encoder.
|
||||
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model.
|
||||
- **labels** -- List of token ids for tgt_texts.
|
||||
|
||||
The full set of keys ``[input_ids, attention_mask, labels]``, will only be returned if tgt_texts is passed.
|
||||
Otherwise, input_ids, attention_mask will be the only keys.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
@add_end_docstrings(INIT_TOKENIZER_DOCSTRING)
|
||||
class PreTrainedTokenizerBase(SpecialTokensMixin):
|
||||
"""
|
||||
@@ -3252,3 +3191,113 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
|
||||
"indexing errors".format(len(ids), self.model_max_length)
|
||||
)
|
||||
self.deprecation_warnings["sequence-length-is-longer-than-the-specified-maximum"] = True
|
||||
|
||||
@contextmanager
|
||||
def as_target_tokenizer(self):
|
||||
"""
|
||||
Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to
|
||||
sequence-to-sequence models that need a slightly different processing for the labels.
|
||||
"""
|
||||
yield
|
||||
|
||||
def prepare_seq2seq_batch(
|
||||
self,
|
||||
src_texts: List[str],
|
||||
tgt_texts: Optional[List[str]] = None,
|
||||
max_length: Optional[int] = None,
|
||||
max_target_length: Optional[int] = None,
|
||||
padding: str = "longest",
|
||||
return_tensors: str = None,
|
||||
truncation: bool = True,
|
||||
**kwargs,
|
||||
) -> BatchEncoding:
|
||||
"""
|
||||
Prepare model inputs for translation. For best performance, translate one sentence at a time.
|
||||
|
||||
Arguments:
|
||||
src_texts (:obj:`List[str]`):
|
||||
List of documents to summarize or source language texts.
|
||||
tgt_texts (:obj:`list`, `optional`):
|
||||
List of summaries or target language texts.
|
||||
max_length (:obj:`int`, `optional`):
|
||||
Controls the maximum length for encoder inputs (documents to summarize or source language texts) If
|
||||
left unset or set to :obj:`None`, this will use the predefined model maximum length if a maximum length
|
||||
is required by one of the truncation/padding parameters. If the model has no specific maximum input
|
||||
length (like XLNet) truncation/padding to a maximum length will be deactivated.
|
||||
max_target_length (:obj:`int`, `optional`):
|
||||
Controls the maximum length of decoder inputs (target language texts or summaries) If left unset or set
|
||||
to :obj:`None`, this will use the max_length value.
|
||||
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`False`):
|
||||
Activates and controls padding. Accepts the following values:
|
||||
|
||||
* :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
|
||||
single sequence if provided).
|
||||
* :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
|
||||
maximum acceptable input length for the model if that argument is not provided.
|
||||
* :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
|
||||
different lengths).
|
||||
return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`):
|
||||
If set, will return tensors instead of list of python integers. Acceptable values are:
|
||||
|
||||
* :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
|
||||
* :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
|
||||
* :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
|
||||
truncation (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.TruncationStrategy`, `optional`, defaults to :obj:`True`):
|
||||
Activates and controls truncation. Accepts the following values:
|
||||
|
||||
* :obj:`True` or :obj:`'longest_first'`: Truncate to a maximum length specified with the argument
|
||||
:obj:`max_length` or to the maximum acceptable input length for the model if that argument is not
|
||||
provided. This will truncate token by token, removing a token from the longest sequence in the pair
|
||||
if a pair of sequences (or a batch of pairs) is provided.
|
||||
* :obj:`'only_first'`: Truncate to a maximum length specified with the argument :obj:`max_length` or to
|
||||
the maximum acceptable input length for the model if that argument is not provided. This will only
|
||||
truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
|
||||
* :obj:`'only_second'`: Truncate to a maximum length specified with the argument :obj:`max_length` or
|
||||
to the maximum acceptable input length for the model if that argument is not provided. This will only
|
||||
truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
|
||||
* :obj:`False` or :obj:`'do_not_truncate'` (default): No truncation (i.e., can output batch with
|
||||
sequence lengths greater than the model maximum admissible input size).
|
||||
**kwargs:
|
||||
Additional keyword arguments passed along to :obj:`self.__call__`.
|
||||
|
||||
Return:
|
||||
:class:`~transformers.BatchEncoding`: A :class:`~transformers.BatchEncoding` with the following fields:
|
||||
|
||||
- **input_ids** -- List of token ids to be fed to the encoder.
|
||||
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model.
|
||||
- **labels** -- List of token ids for tgt_texts.
|
||||
|
||||
The full set of keys ``[input_ids, attention_mask, labels]``, will only be returned if tgt_texts is passed.
|
||||
Otherwise, input_ids, attention_mask will be the only keys.
|
||||
"""
|
||||
# mBART-specific kwargs that should be ignored by other models.
|
||||
kwargs.pop("src_lang", None)
|
||||
kwargs.pop("tgt_lang", None)
|
||||
if max_length is None:
|
||||
max_length = self.model_max_length
|
||||
model_inputs = self(
|
||||
src_texts,
|
||||
add_special_tokens=True,
|
||||
return_tensors=return_tensors,
|
||||
max_length=max_length,
|
||||
padding=padding,
|
||||
truncation=truncation,
|
||||
**kwargs,
|
||||
)
|
||||
if tgt_texts is None:
|
||||
return model_inputs
|
||||
# Process tgt_texts
|
||||
if max_target_length is None:
|
||||
max_target_length = max_length
|
||||
with self.as_target_tokenizer():
|
||||
labels = self(
|
||||
tgt_texts,
|
||||
add_special_tokens=True,
|
||||
return_tensors=return_tensors,
|
||||
padding=padding,
|
||||
max_length=max_target_length,
|
||||
truncation=truncation,
|
||||
**kwargs,
|
||||
)
|
||||
model_inputs["labels"] = labels["input_ids"]
|
||||
return model_inputs
|
||||
|
||||
Reference in New Issue
Block a user