Refactor prepare_seq2seq_batch (#9524)
* Add target contextmanager and rework prepare_seq2seq_batch
* Fix tests, treat BART and Barthez
* Add last tokenizers
* Fix test
* Set src token before calling the superclass
* Remove special behavior for T5
* Remove needless imports
* Remove needless asserts
This commit is contained in:
@@ -13,10 +13,6 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
from ...file_utils import add_start_docstrings
|
|
||||||
from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING, BatchEncoding
|
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
from ..roberta.tokenization_roberta import RobertaTokenizer
|
from ..roberta.tokenization_roberta import RobertaTokenizer
|
||||||
|
|
||||||
@@ -54,45 +50,3 @@ class BartTokenizer(RobertaTokenizer):
|
|||||||
"vocab_file": {m: vocab_url for m in _all_bart_models},
|
"vocab_file": {m: vocab_url for m in _all_bart_models},
|
||||||
"merges_file": {m: merges_url for m in _all_bart_models},
|
"merges_file": {m: merges_url for m in _all_bart_models},
|
||||||
}
|
}
|
||||||
|
|
||||||
@add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
|
|
||||||
def prepare_seq2seq_batch(
|
|
||||||
self,
|
|
||||||
src_texts: List[str],
|
|
||||||
tgt_texts: Optional[List[str]] = None,
|
|
||||||
max_length: Optional[int] = None,
|
|
||||||
max_target_length: Optional[int] = None,
|
|
||||||
padding: str = "longest",
|
|
||||||
return_tensors: str = None,
|
|
||||||
truncation=True,
|
|
||||||
**kwargs,
|
|
||||||
) -> BatchEncoding:
|
|
||||||
kwargs.pop("src_lang", None)
|
|
||||||
kwargs.pop("tgt_lang", None)
|
|
||||||
if max_length is None:
|
|
||||||
max_length = self.model_max_length
|
|
||||||
model_inputs: BatchEncoding = self(
|
|
||||||
src_texts,
|
|
||||||
add_special_tokens=True,
|
|
||||||
return_tensors=return_tensors,
|
|
||||||
max_length=max_length,
|
|
||||||
padding=padding,
|
|
||||||
truncation=truncation,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
if tgt_texts is None:
|
|
||||||
return model_inputs
|
|
||||||
# Process tgt_texts
|
|
||||||
if max_target_length is None:
|
|
||||||
max_target_length = max_length
|
|
||||||
labels = self(
|
|
||||||
tgt_texts,
|
|
||||||
add_special_tokens=True,
|
|
||||||
return_tensors=return_tensors,
|
|
||||||
padding=padding,
|
|
||||||
max_length=max_target_length,
|
|
||||||
truncation=truncation,
|
|
||||||
**kwargs,
|
|
||||||
)["input_ids"]
|
|
||||||
model_inputs["labels"] = labels
|
|
||||||
return model_inputs
|
|
||||||
|
|||||||
@@ -13,10 +13,6 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
from ...file_utils import add_start_docstrings
|
|
||||||
from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING, BatchEncoding
|
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
from ..roberta.tokenization_roberta_fast import RobertaTokenizerFast
|
from ..roberta.tokenization_roberta_fast import RobertaTokenizerFast
|
||||||
from .tokenization_bart import BartTokenizer
|
from .tokenization_bart import BartTokenizer
|
||||||
@@ -49,43 +45,3 @@ class BartTokenizerFast(RobertaTokenizerFast):
|
|||||||
"tokenizer_file": {m: tokenizer_url for m in _all_bart_models},
|
"tokenizer_file": {m: tokenizer_url for m in _all_bart_models},
|
||||||
}
|
}
|
||||||
slow_tokenizer_class = BartTokenizer
|
slow_tokenizer_class = BartTokenizer
|
||||||
|
|
||||||
@add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
|
|
||||||
def prepare_seq2seq_batch(
|
|
||||||
self,
|
|
||||||
src_texts: List[str],
|
|
||||||
tgt_texts: Optional[List[str]] = None,
|
|
||||||
max_length: Optional[int] = None,
|
|
||||||
max_target_length: Optional[int] = None,
|
|
||||||
padding: str = "longest",
|
|
||||||
return_tensors: Optional[str] = None,
|
|
||||||
truncation=True,
|
|
||||||
**kwargs,
|
|
||||||
) -> BatchEncoding:
|
|
||||||
if max_length is None:
|
|
||||||
max_length = self.model_max_length
|
|
||||||
model_inputs: BatchEncoding = self(
|
|
||||||
src_texts,
|
|
||||||
add_special_tokens=True,
|
|
||||||
return_tensors=return_tensors,
|
|
||||||
max_length=max_length,
|
|
||||||
padding=padding,
|
|
||||||
truncation=truncation,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
if tgt_texts is None:
|
|
||||||
return model_inputs
|
|
||||||
# Process tgt_texts
|
|
||||||
if max_target_length is None:
|
|
||||||
max_target_length = max_length
|
|
||||||
labels = self(
|
|
||||||
tgt_texts,
|
|
||||||
add_special_tokens=True,
|
|
||||||
return_tensors=return_tensors,
|
|
||||||
padding=padding,
|
|
||||||
max_length=max_target_length,
|
|
||||||
truncation=truncation,
|
|
||||||
**kwargs,
|
|
||||||
)["input_ids"]
|
|
||||||
model_inputs["labels"] = labels
|
|
||||||
return model_inputs
|
|
||||||
|
|||||||
@@ -21,9 +21,7 @@ from typing import List, Optional, Tuple
|
|||||||
|
|
||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
|
|
||||||
from ...file_utils import add_start_docstrings
|
|
||||||
from ...tokenization_utils import PreTrainedTokenizer
|
from ...tokenization_utils import PreTrainedTokenizer
|
||||||
from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING, BatchEncoding
|
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
|
|
||||||
|
|
||||||
@@ -264,45 +262,3 @@ class BarthezTokenizer(PreTrainedTokenizer):
|
|||||||
copyfile(self.vocab_file, out_vocab_file)
|
copyfile(self.vocab_file, out_vocab_file)
|
||||||
|
|
||||||
return (out_vocab_file,)
|
return (out_vocab_file,)
|
||||||
|
|
||||||
@add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
|
|
||||||
def prepare_seq2seq_batch(
|
|
||||||
self,
|
|
||||||
src_texts: List[str],
|
|
||||||
tgt_texts: Optional[List[str]] = None,
|
|
||||||
max_length: Optional[int] = None,
|
|
||||||
max_target_length: Optional[int] = None,
|
|
||||||
padding: str = "longest",
|
|
||||||
return_tensors: str = "None",
|
|
||||||
truncation=True,
|
|
||||||
**kwargs,
|
|
||||||
) -> BatchEncoding:
|
|
||||||
kwargs.pop("src_lang", None)
|
|
||||||
kwargs.pop("tgt_lang", None)
|
|
||||||
if max_length is None:
|
|
||||||
max_length = self.model_max_length
|
|
||||||
model_inputs: BatchEncoding = self(
|
|
||||||
src_texts,
|
|
||||||
add_special_tokens=True,
|
|
||||||
return_tensors=return_tensors,
|
|
||||||
max_length=max_length,
|
|
||||||
padding=padding,
|
|
||||||
truncation=truncation,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
if tgt_texts is None:
|
|
||||||
return model_inputs
|
|
||||||
# Process tgt_texts
|
|
||||||
if max_target_length is None:
|
|
||||||
max_target_length = max_length
|
|
||||||
labels = self(
|
|
||||||
tgt_texts,
|
|
||||||
add_special_tokens=True,
|
|
||||||
return_tensors=return_tensors,
|
|
||||||
padding=padding,
|
|
||||||
max_length=max_target_length,
|
|
||||||
truncation=truncation,
|
|
||||||
**kwargs,
|
|
||||||
)["input_ids"]
|
|
||||||
model_inputs["labels"] = labels
|
|
||||||
return model_inputs
|
|
||||||
|
|||||||
@@ -19,8 +19,7 @@ import os
|
|||||||
from shutil import copyfile
|
from shutil import copyfile
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
from ...file_utils import add_start_docstrings, is_sentencepiece_available
|
from ...file_utils import is_sentencepiece_available
|
||||||
from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING, BatchEncoding
|
|
||||||
from ...tokenization_utils_fast import PreTrainedTokenizerFast
|
from ...tokenization_utils_fast import PreTrainedTokenizerFast
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
|
|
||||||
@@ -228,45 +227,3 @@ class BarthezTokenizerFast(PreTrainedTokenizerFast):
|
|||||||
copyfile(self.vocab_file, out_vocab_file)
|
copyfile(self.vocab_file, out_vocab_file)
|
||||||
|
|
||||||
return (out_vocab_file,)
|
return (out_vocab_file,)
|
||||||
|
|
||||||
@add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
|
|
||||||
def prepare_seq2seq_batch(
|
|
||||||
self,
|
|
||||||
src_texts: List[str],
|
|
||||||
tgt_texts: Optional[List[str]] = None,
|
|
||||||
max_length: Optional[int] = None,
|
|
||||||
max_target_length: Optional[int] = None,
|
|
||||||
padding: str = "longest",
|
|
||||||
return_tensors: str = "None",
|
|
||||||
truncation=True,
|
|
||||||
**kwargs,
|
|
||||||
) -> BatchEncoding:
|
|
||||||
kwargs.pop("src_lang", None)
|
|
||||||
kwargs.pop("tgt_lang", None)
|
|
||||||
if max_length is None:
|
|
||||||
max_length = self.model_max_length
|
|
||||||
model_inputs: BatchEncoding = self(
|
|
||||||
src_texts,
|
|
||||||
add_special_tokens=True,
|
|
||||||
return_tensors=return_tensors,
|
|
||||||
max_length=max_length,
|
|
||||||
padding=padding,
|
|
||||||
truncation=truncation,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
if tgt_texts is None:
|
|
||||||
return model_inputs
|
|
||||||
# Process tgt_texts
|
|
||||||
if max_target_length is None:
|
|
||||||
max_target_length = max_length
|
|
||||||
labels = self(
|
|
||||||
tgt_texts,
|
|
||||||
add_special_tokens=True,
|
|
||||||
return_tensors=return_tensors,
|
|
||||||
padding=padding,
|
|
||||||
max_length=max_target_length,
|
|
||||||
truncation=truncation,
|
|
||||||
**kwargs,
|
|
||||||
)["input_ids"]
|
|
||||||
model_inputs["labels"] = labels
|
|
||||||
return model_inputs
|
|
||||||
|
|||||||
@@ -23,9 +23,7 @@ from typing import Dict, List, Optional, Tuple
|
|||||||
|
|
||||||
import sacremoses as sm
|
import sacremoses as sm
|
||||||
|
|
||||||
from ...file_utils import add_start_docstrings
|
from ...tokenization_utils import PreTrainedTokenizer
|
||||||
from ...tokenization_utils import BatchEncoding, PreTrainedTokenizer
|
|
||||||
from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING
|
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
|
|
||||||
|
|
||||||
@@ -484,40 +482,6 @@ class FSMTTokenizer(PreTrainedTokenizer):
|
|||||||
return len(token_ids_0 + sep) * [0]
|
return len(token_ids_0 + sep) * [0]
|
||||||
return len(token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
|
return len(token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
|
||||||
|
|
||||||
@add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
|
|
||||||
def prepare_seq2seq_batch(
|
|
||||||
self,
|
|
||||||
src_texts: List[str],
|
|
||||||
tgt_texts: Optional[List[str]] = None,
|
|
||||||
max_length: Optional[int] = None,
|
|
||||||
max_target_length: Optional[int] = None,
|
|
||||||
return_tensors: Optional[str] = None,
|
|
||||||
truncation=True,
|
|
||||||
padding="longest",
|
|
||||||
**unused,
|
|
||||||
) -> BatchEncoding:
|
|
||||||
if type(src_texts) is not list:
|
|
||||||
raise ValueError("src_texts is expected to be a list")
|
|
||||||
if "" in src_texts:
|
|
||||||
raise ValueError(f"found empty string in src_texts: {src_texts}")
|
|
||||||
|
|
||||||
tokenizer_kwargs = dict(
|
|
||||||
add_special_tokens=True,
|
|
||||||
return_tensors=return_tensors,
|
|
||||||
max_length=max_length,
|
|
||||||
truncation=truncation,
|
|
||||||
padding=padding,
|
|
||||||
)
|
|
||||||
model_inputs: BatchEncoding = self(src_texts, **tokenizer_kwargs)
|
|
||||||
|
|
||||||
if tgt_texts is None:
|
|
||||||
return model_inputs
|
|
||||||
if max_target_length is not None:
|
|
||||||
tokenizer_kwargs["max_length"] = max_target_length
|
|
||||||
|
|
||||||
model_inputs["labels"] = self(tgt_texts, **tokenizer_kwargs)["input_ids"]
|
|
||||||
return model_inputs
|
|
||||||
|
|
||||||
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
||||||
if not os.path.isdir(save_directory):
|
if not os.path.isdir(save_directory):
|
||||||
logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
|
logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
|
||||||
|
|||||||
@@ -15,15 +15,14 @@
|
|||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
import warnings
|
import warnings
|
||||||
|
from contextlib import contextmanager
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from shutil import copyfile
|
from shutil import copyfile
|
||||||
from typing import Dict, List, Optional, Tuple, Union
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import sentencepiece
|
import sentencepiece
|
||||||
|
|
||||||
from ...file_utils import add_start_docstrings
|
from ...tokenization_utils import PreTrainedTokenizer
|
||||||
from ...tokenization_utils import BatchEncoding, PreTrainedTokenizer
|
|
||||||
from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING
|
|
||||||
|
|
||||||
|
|
||||||
vocab_files_names = {
|
vocab_files_names = {
|
||||||
@@ -182,40 +181,15 @@ class MarianTokenizer(PreTrainedTokenizer):
|
|||||||
# We don't expect to process pairs, but leave the pair logic for API consistency
|
# We don't expect to process pairs, but leave the pair logic for API consistency
|
||||||
return token_ids_0 + token_ids_1 + [self.eos_token_id]
|
return token_ids_0 + token_ids_1 + [self.eos_token_id]
|
||||||
|
|
||||||
@add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
|
@contextmanager
|
||||||
def prepare_seq2seq_batch(
|
def as_target_tokenizer(self):
|
||||||
self,
|
"""
|
||||||
src_texts: List[str],
|
Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to
|
||||||
tgt_texts: Optional[List[str]] = None,
|
sequence-to-sequence models that need a slightly different processing for the labels.
|
||||||
max_length: Optional[int] = None,
|
"""
|
||||||
max_target_length: Optional[int] = None,
|
|
||||||
return_tensors: Optional[str] = None,
|
|
||||||
truncation=True,
|
|
||||||
padding="longest",
|
|
||||||
**unused,
|
|
||||||
) -> BatchEncoding:
|
|
||||||
if "" in src_texts:
|
|
||||||
raise ValueError(f"found empty string in src_texts: {src_texts}")
|
|
||||||
self.current_spm = self.spm_source
|
|
||||||
src_texts = [self.normalize(t) for t in src_texts] # this does not appear to do much
|
|
||||||
tokenizer_kwargs = dict(
|
|
||||||
add_special_tokens=True,
|
|
||||||
return_tensors=return_tensors,
|
|
||||||
max_length=max_length,
|
|
||||||
truncation=truncation,
|
|
||||||
padding=padding,
|
|
||||||
)
|
|
||||||
model_inputs: BatchEncoding = self(src_texts, **tokenizer_kwargs)
|
|
||||||
|
|
||||||
if tgt_texts is None:
|
|
||||||
return model_inputs
|
|
||||||
if max_target_length is not None:
|
|
||||||
tokenizer_kwargs["max_length"] = max_target_length
|
|
||||||
|
|
||||||
self.current_spm = self.spm_target
|
self.current_spm = self.spm_target
|
||||||
model_inputs["labels"] = self(tgt_texts, **tokenizer_kwargs)["input_ids"]
|
yield
|
||||||
self.current_spm = self.spm_source
|
self.current_spm = self.spm_source
|
||||||
return model_inputs
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def vocab_size(self) -> int:
|
def vocab_size(self) -> int:
|
||||||
|
|||||||
@@ -13,11 +13,10 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from contextlib import contextmanager
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from ...file_utils import add_start_docstrings
|
|
||||||
from ...tokenization_utils import BatchEncoding
|
from ...tokenization_utils import BatchEncoding
|
||||||
from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING
|
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
from ..xlm_roberta.tokenization_xlm_roberta import XLMRobertaTokenizer
|
from ..xlm_roberta.tokenization_xlm_roberta import XLMRobertaTokenizer
|
||||||
|
|
||||||
@@ -172,52 +171,28 @@ class MBartTokenizer(XLMRobertaTokenizer):
|
|||||||
# We don't expect to process pairs, but leave the pair logic for API consistency
|
# We don't expect to process pairs, but leave the pair logic for API consistency
|
||||||
return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens
|
return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens
|
||||||
|
|
||||||
@add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
|
|
||||||
def prepare_seq2seq_batch(
|
def prepare_seq2seq_batch(
|
||||||
self,
|
self,
|
||||||
src_texts: List[str],
|
src_texts: List[str],
|
||||||
src_lang: str = "en_XX",
|
src_lang: str = "en_XX",
|
||||||
tgt_texts: Optional[List[str]] = None,
|
tgt_texts: Optional[List[str]] = None,
|
||||||
tgt_lang: str = "ro_RO",
|
tgt_lang: str = "ro_RO",
|
||||||
max_length: Optional[int] = None,
|
|
||||||
max_target_length: Optional[int] = None,
|
|
||||||
truncation: bool = True,
|
|
||||||
padding: str = "longest",
|
|
||||||
return_tensors: Optional[str] = None,
|
|
||||||
add_prefix_space: bool = False, # ignored
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> BatchEncoding:
|
) -> BatchEncoding:
|
||||||
if max_length is None:
|
self.src_lang = src_lang
|
||||||
max_length = self.model_max_length
|
self.tgt_lang = tgt_lang
|
||||||
self.set_src_lang_special_tokens(src_lang)
|
self.set_src_lang_special_tokens(self.src_lang)
|
||||||
model_inputs: BatchEncoding = self(
|
return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs)
|
||||||
src_texts,
|
|
||||||
add_special_tokens=True,
|
|
||||||
return_tensors=return_tensors,
|
|
||||||
max_length=max_length,
|
|
||||||
padding=padding,
|
|
||||||
truncation=truncation,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
if tgt_texts is None:
|
|
||||||
return model_inputs
|
|
||||||
# Process tgt_texts
|
|
||||||
if max_target_length is None:
|
|
||||||
max_target_length = max_length
|
|
||||||
self.set_tgt_lang_special_tokens(tgt_lang)
|
|
||||||
|
|
||||||
labels = self(
|
@contextmanager
|
||||||
tgt_texts,
|
def as_target_tokenizer(self):
|
||||||
add_special_tokens=True,
|
"""
|
||||||
return_tensors=return_tensors,
|
Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to
|
||||||
padding=padding,
|
sequence-to-sequence models that need a slightly different processing for the labels.
|
||||||
max_length=max_target_length,
|
"""
|
||||||
truncation=True,
|
self.set_tgt_lang_special_tokens(self.tgt_lang)
|
||||||
**kwargs,
|
yield
|
||||||
)["input_ids"]
|
self.set_src_lang_special_tokens(self.src_lang)
|
||||||
model_inputs["labels"] = labels
|
|
||||||
self.set_src_lang_special_tokens(src_lang) # sets to src_lang
|
|
||||||
return model_inputs
|
|
||||||
|
|
||||||
def set_src_lang_special_tokens(self, src_lang) -> None:
|
def set_src_lang_special_tokens(self, src_lang) -> None:
|
||||||
"""Reset the special tokens to the source lang setting. No prefix and suffix=[eos, src_lang_code]."""
|
"""Reset the special tokens to the source lang setting. No prefix and suffix=[eos, src_lang_code]."""
|
||||||
|
|||||||
@@ -13,13 +13,13 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from contextlib import contextmanager
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from tokenizers import processors
|
from tokenizers import processors
|
||||||
|
|
||||||
from ...file_utils import add_start_docstrings, is_sentencepiece_available
|
from ...file_utils import is_sentencepiece_available
|
||||||
from ...tokenization_utils import BatchEncoding
|
from ...tokenization_utils import BatchEncoding
|
||||||
from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING
|
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
from ..xlm_roberta.tokenization_xlm_roberta_fast import XLMRobertaTokenizerFast
|
from ..xlm_roberta.tokenization_xlm_roberta_fast import XLMRobertaTokenizerFast
|
||||||
|
|
||||||
@@ -171,51 +171,28 @@ class MBartTokenizerFast(XLMRobertaTokenizerFast):
|
|||||||
# We don't expect to process pairs, but leave the pair logic for API consistency
|
# We don't expect to process pairs, but leave the pair logic for API consistency
|
||||||
return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens
|
return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens
|
||||||
|
|
||||||
@add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
|
|
||||||
def prepare_seq2seq_batch(
|
def prepare_seq2seq_batch(
|
||||||
self,
|
self,
|
||||||
src_texts: List[str],
|
src_texts: List[str],
|
||||||
src_lang: str = "en_XX",
|
src_lang: str = "en_XX",
|
||||||
tgt_texts: Optional[List[str]] = None,
|
tgt_texts: Optional[List[str]] = None,
|
||||||
tgt_lang: str = "ro_RO",
|
tgt_lang: str = "ro_RO",
|
||||||
max_length: Optional[int] = None,
|
|
||||||
max_target_length: Optional[int] = None,
|
|
||||||
truncation: bool = True,
|
|
||||||
padding: str = "longest",
|
|
||||||
return_tensors: str = None,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> BatchEncoding:
|
) -> BatchEncoding:
|
||||||
if max_length is None:
|
self.src_lang = src_lang
|
||||||
max_length = self.model_max_length
|
self.tgt_lang = tgt_lang
|
||||||
self.set_src_lang_special_tokens(src_lang)
|
self.set_src_lang_special_tokens(self.src_lang)
|
||||||
model_inputs: BatchEncoding = self(
|
return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs)
|
||||||
src_texts,
|
|
||||||
add_special_tokens=True,
|
|
||||||
return_tensors=return_tensors,
|
|
||||||
max_length=max_length,
|
|
||||||
padding=padding,
|
|
||||||
truncation=truncation,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
if tgt_texts is None:
|
|
||||||
return model_inputs
|
|
||||||
# Process tgt_texts
|
|
||||||
if max_target_length is None:
|
|
||||||
max_target_length = max_length
|
|
||||||
self.set_tgt_lang_special_tokens(tgt_lang)
|
|
||||||
|
|
||||||
labels = self(
|
@contextmanager
|
||||||
tgt_texts,
|
def as_target_tokenizer(self):
|
||||||
add_special_tokens=True,
|
"""
|
||||||
return_tensors=return_tensors,
|
Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to
|
||||||
padding=padding,
|
sequence-to-sequence models that need a slightly different processing for the labels.
|
||||||
max_length=max_target_length,
|
"""
|
||||||
truncation=True,
|
self.set_tgt_lang_special_tokens(self.tgt_lang)
|
||||||
**kwargs,
|
yield
|
||||||
)["input_ids"]
|
self.set_src_lang_special_tokens(self.src_lang)
|
||||||
model_inputs["labels"] = labels
|
|
||||||
self.set_src_lang_special_tokens(src_lang) # sets to src_lang
|
|
||||||
return model_inputs
|
|
||||||
|
|
||||||
def set_src_lang_special_tokens(self, src_lang) -> None:
|
def set_src_lang_special_tokens(self, src_lang) -> None:
|
||||||
"""Reset the special tokens to the source lang setting. No prefix and suffix=[eos, src_lang_code]."""
|
"""Reset the special tokens to the source lang setting. No prefix and suffix=[eos, src_lang_code]."""
|
||||||
|
|||||||
@@ -18,9 +18,7 @@ from typing import Dict, List, Optional, Tuple
|
|||||||
|
|
||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
|
|
||||||
from ...file_utils import add_start_docstrings
|
|
||||||
from ...tokenization_utils import PreTrainedTokenizer
|
from ...tokenization_utils import PreTrainedTokenizer
|
||||||
from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING, BatchEncoding
|
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
|
|
||||||
|
|
||||||
@@ -250,36 +248,6 @@ class PegasusTokenizer(PreTrainedTokenizer):
|
|||||||
# We don't expect to process pairs, but leave the pair logic for API consistency
|
# We don't expect to process pairs, but leave the pair logic for API consistency
|
||||||
return token_ids_0 + token_ids_1 + [self.eos_token_id]
|
return token_ids_0 + token_ids_1 + [self.eos_token_id]
|
||||||
|
|
||||||
@add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
|
|
||||||
def prepare_seq2seq_batch(
|
|
||||||
self,
|
|
||||||
src_texts: List[str],
|
|
||||||
tgt_texts: Optional[List[str]] = None,
|
|
||||||
max_length: Optional[int] = None,
|
|
||||||
max_target_length: Optional[int] = None,
|
|
||||||
return_tensors: str = None,
|
|
||||||
truncation=True,
|
|
||||||
padding="longest",
|
|
||||||
**unused,
|
|
||||||
) -> BatchEncoding:
|
|
||||||
if "" in src_texts:
|
|
||||||
raise ValueError(f"found empty string in src_texts: {src_texts}")
|
|
||||||
tokenizer_kwargs = dict(
|
|
||||||
add_special_tokens=True,
|
|
||||||
return_tensors=return_tensors,
|
|
||||||
max_length=max_length,
|
|
||||||
truncation=truncation,
|
|
||||||
padding=padding,
|
|
||||||
)
|
|
||||||
model_inputs: BatchEncoding = self(src_texts, **tokenizer_kwargs)
|
|
||||||
if tgt_texts is None:
|
|
||||||
return model_inputs
|
|
||||||
if max_target_length is not None:
|
|
||||||
tokenizer_kwargs["max_length"] = max_target_length
|
|
||||||
labels: BatchEncoding = self(tgt_texts, **tokenizer_kwargs)["input_ids"]
|
|
||||||
model_inputs["labels"] = labels
|
|
||||||
return model_inputs
|
|
||||||
|
|
||||||
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
||||||
if not os.path.isdir(save_directory):
|
if not os.path.isdir(save_directory):
|
||||||
logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
|
logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
|
||||||
|
|||||||
@@ -19,8 +19,7 @@ import os
|
|||||||
from shutil import copyfile
|
from shutil import copyfile
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
from ...file_utils import add_start_docstrings, is_sentencepiece_available
|
from ...file_utils import is_sentencepiece_available
|
||||||
from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING, BatchEncoding
|
|
||||||
from ...tokenization_utils_fast import PreTrainedTokenizerFast
|
from ...tokenization_utils_fast import PreTrainedTokenizerFast
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
|
|
||||||
@@ -188,36 +187,6 @@ class PegasusTokenizerFast(PreTrainedTokenizerFast):
|
|||||||
# We don't expect to process pairs, but leave the pair logic for API consistency
|
# We don't expect to process pairs, but leave the pair logic for API consistency
|
||||||
return token_ids_0 + token_ids_1 + [self.eos_token_id]
|
return token_ids_0 + token_ids_1 + [self.eos_token_id]
|
||||||
|
|
||||||
@add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
|
|
||||||
def prepare_seq2seq_batch(
|
|
||||||
self,
|
|
||||||
src_texts: List[str],
|
|
||||||
tgt_texts: Optional[List[str]] = None,
|
|
||||||
max_length: Optional[int] = None,
|
|
||||||
max_target_length: Optional[int] = None,
|
|
||||||
return_tensors: str = None,
|
|
||||||
truncation=True,
|
|
||||||
padding="longest",
|
|
||||||
**unused,
|
|
||||||
) -> BatchEncoding:
|
|
||||||
if "" in src_texts:
|
|
||||||
raise ValueError(f"found empty string in src_texts: {src_texts}")
|
|
||||||
tokenizer_kwargs = dict(
|
|
||||||
add_special_tokens=True,
|
|
||||||
return_tensors=return_tensors,
|
|
||||||
max_length=max_length,
|
|
||||||
truncation=truncation,
|
|
||||||
padding=padding,
|
|
||||||
)
|
|
||||||
model_inputs: BatchEncoding = self(src_texts, **tokenizer_kwargs)
|
|
||||||
if tgt_texts is None:
|
|
||||||
return model_inputs
|
|
||||||
if max_target_length is not None:
|
|
||||||
tokenizer_kwargs["max_length"] = max_target_length
|
|
||||||
labels: BatchEncoding = self(tgt_texts, **tokenizer_kwargs)["input_ids"]
|
|
||||||
model_inputs["labels"] = labels
|
|
||||||
return model_inputs
|
|
||||||
|
|
||||||
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
||||||
if not os.path.isdir(save_directory):
|
if not os.path.isdir(save_directory):
|
||||||
logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
|
logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
|
||||||
|
|||||||
@@ -17,9 +17,7 @@ import collections
|
|||||||
import os
|
import os
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
from ...file_utils import add_start_docstrings
|
from ...tokenization_utils import PreTrainedTokenizer
|
||||||
from ...tokenization_utils import BatchEncoding, PreTrainedTokenizer
|
|
||||||
from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING
|
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
from ..bert.tokenization_bert import BasicTokenizer, WordpieceTokenizer
|
from ..bert.tokenization_bert import BasicTokenizer, WordpieceTokenizer
|
||||||
|
|
||||||
@@ -288,43 +286,3 @@ class ProphetNetTokenizer(PreTrainedTokenizer):
|
|||||||
return token_ids_0 + [self.sep_token_id]
|
return token_ids_0 + [self.sep_token_id]
|
||||||
sep = [self.sep_token_id]
|
sep = [self.sep_token_id]
|
||||||
return token_ids_0 + sep + token_ids_1 + sep
|
return token_ids_0 + sep + token_ids_1 + sep
|
||||||
|
|
||||||
@add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
|
|
||||||
def prepare_seq2seq_batch(
|
|
||||||
self,
|
|
||||||
src_texts: List[str],
|
|
||||||
tgt_texts: Optional[List[str]] = None,
|
|
||||||
max_length: Optional[int] = None,
|
|
||||||
max_target_length: Optional[int] = None,
|
|
||||||
padding: str = "longest",
|
|
||||||
return_tensors: str = None,
|
|
||||||
truncation: bool = True,
|
|
||||||
**kwargs,
|
|
||||||
) -> BatchEncoding:
|
|
||||||
if max_length is None:
|
|
||||||
max_length = self.model_max_length
|
|
||||||
model_inputs = self(
|
|
||||||
src_texts,
|
|
||||||
add_special_tokens=True,
|
|
||||||
return_tensors=return_tensors,
|
|
||||||
max_length=max_length,
|
|
||||||
padding=padding,
|
|
||||||
truncation=truncation,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
if tgt_texts is None:
|
|
||||||
return model_inputs
|
|
||||||
# Process tgt_texts
|
|
||||||
if max_target_length is None:
|
|
||||||
max_target_length = max_length
|
|
||||||
labels_and_decoder_mask = self(
|
|
||||||
tgt_texts,
|
|
||||||
add_special_tokens=True,
|
|
||||||
return_tensors=return_tensors,
|
|
||||||
padding=padding,
|
|
||||||
max_length=max_target_length,
|
|
||||||
truncation=truncation,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
model_inputs["labels"] = labels_and_decoder_mask["input_ids"]
|
|
||||||
return model_inputs
|
|
||||||
|
|||||||
@@ -16,8 +16,7 @@
|
|||||||
import os
|
import os
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from ...file_utils import add_start_docstrings
|
from ...tokenization_utils_base import BatchEncoding
|
||||||
from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING, BatchEncoding
|
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
from .configuration_rag import RagConfig
|
from .configuration_rag import RagConfig
|
||||||
|
|
||||||
@@ -63,42 +62,18 @@ class RagTokenizer:
|
|||||||
def batch_decode(self, *args, **kwargs):
|
def batch_decode(self, *args, **kwargs):
|
||||||
return self.generator.batch_decode(*args, **kwargs)
|
return self.generator.batch_decode(*args, **kwargs)
|
||||||
|
|
||||||
@add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
|
|
||||||
def prepare_seq2seq_batch(
|
def prepare_seq2seq_batch(
|
||||||
self,
|
self,
|
||||||
src_texts: List[str],
|
src_texts: List[str],
|
||||||
tgt_texts: Optional[List[str]] = None,
|
tgt_texts: Optional[List[str]] = None,
|
||||||
max_length: Optional[int] = None,
|
max_length: Optional[int] = None,
|
||||||
max_target_length: Optional[int] = None,
|
max_target_length: Optional[int] = None,
|
||||||
padding: str = "longest",
|
|
||||||
return_tensors: str = None,
|
|
||||||
truncation=True,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> BatchEncoding:
|
) -> BatchEncoding:
|
||||||
if max_length is None:
|
if max_length is None:
|
||||||
max_length = self.question_encoder.model_max_length
|
max_length = self.question_encoder.model_max_length
|
||||||
model_inputs: BatchEncoding = self.question_encoder(
|
|
||||||
src_texts,
|
|
||||||
add_special_tokens=True,
|
|
||||||
return_tensors=return_tensors,
|
|
||||||
max_length=max_length,
|
|
||||||
padding=padding,
|
|
||||||
truncation=truncation,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
if tgt_texts is None:
|
|
||||||
return model_inputs
|
|
||||||
# Process tgt_texts
|
|
||||||
if max_target_length is None:
|
if max_target_length is None:
|
||||||
max_target_length = self.generator.model_max_length
|
max_target_length = self.generator.model_max_length
|
||||||
labels = self.generator(
|
return super().prepare_seq2seq_batch(
|
||||||
tgt_texts,
|
src_texts, tgt_texts, max_length=max_length, max_target_length=max_target_length, **kwargs
|
||||||
add_special_tokens=True,
|
)
|
||||||
return_tensors=return_tensors,
|
|
||||||
padding=padding,
|
|
||||||
max_length=max_target_length,
|
|
||||||
truncation=truncation,
|
|
||||||
**kwargs,
|
|
||||||
)["input_ids"]
|
|
||||||
model_inputs["labels"] = labels
|
|
||||||
return model_inputs
|
|
||||||
|
|||||||
@@ -23,9 +23,7 @@ from typing import List, Optional, Tuple
|
|||||||
|
|
||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
|
|
||||||
from ...file_utils import add_start_docstrings
|
from ...tokenization_utils import PreTrainedTokenizer
|
||||||
from ...tokenization_utils import BatchEncoding, PreTrainedTokenizer
|
|
||||||
from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING
|
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
|
|
||||||
|
|
||||||
@@ -295,43 +293,3 @@ class T5Tokenizer(PreTrainedTokenizer):
|
|||||||
copyfile(self.vocab_file, out_vocab_file)
|
copyfile(self.vocab_file, out_vocab_file)
|
||||||
|
|
||||||
return (out_vocab_file,)
|
return (out_vocab_file,)
|
||||||
|
|
||||||
@add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
|
|
||||||
def prepare_seq2seq_batch(
|
|
||||||
self,
|
|
||||||
src_texts: List[str],
|
|
||||||
tgt_texts: Optional[List[str]] = None,
|
|
||||||
max_length: Optional[int] = None,
|
|
||||||
max_target_length: Optional[int] = None,
|
|
||||||
padding: str = "longest",
|
|
||||||
return_tensors: str = None,
|
|
||||||
truncation: bool = True,
|
|
||||||
**kwargs,
|
|
||||||
) -> BatchEncoding:
|
|
||||||
if max_length is None:
|
|
||||||
max_length = self.model_max_length
|
|
||||||
model_inputs = self(
|
|
||||||
src_texts,
|
|
||||||
add_special_tokens=True,
|
|
||||||
return_tensors=return_tensors,
|
|
||||||
max_length=max_length,
|
|
||||||
padding=padding,
|
|
||||||
truncation=truncation,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
if tgt_texts is None:
|
|
||||||
return model_inputs
|
|
||||||
# Process tgt_texts
|
|
||||||
if max_target_length is None:
|
|
||||||
max_target_length = max_length
|
|
||||||
labels_and_decoder_mask = self(
|
|
||||||
tgt_texts,
|
|
||||||
add_special_tokens=True,
|
|
||||||
return_tensors=return_tensors,
|
|
||||||
padding=padding,
|
|
||||||
max_length=max_target_length,
|
|
||||||
truncation=truncation,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
model_inputs["labels"] = labels_and_decoder_mask["input_ids"]
|
|
||||||
return model_inputs
|
|
||||||
|
|||||||
@@ -19,9 +19,7 @@ import os
|
|||||||
from shutil import copyfile
|
from shutil import copyfile
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
from ...file_utils import add_start_docstrings, is_sentencepiece_available
|
from ...file_utils import is_sentencepiece_available
|
||||||
from ...tokenization_utils import BatchEncoding
|
|
||||||
from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING
|
|
||||||
from ...tokenization_utils_fast import PreTrainedTokenizerFast
|
from ...tokenization_utils_fast import PreTrainedTokenizerFast
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
|
|
||||||
@@ -212,47 +210,3 @@ class T5TokenizerFast(PreTrainedTokenizerFast):
|
|||||||
if token_ids_1 is None:
|
if token_ids_1 is None:
|
||||||
return len(token_ids_0 + eos) * [0]
|
return len(token_ids_0 + eos) * [0]
|
||||||
return len(token_ids_0 + eos + token_ids_1 + eos) * [0]
|
return len(token_ids_0 + eos + token_ids_1 + eos) * [0]
|
||||||
|
|
||||||
@add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
|
|
||||||
def prepare_seq2seq_batch(
|
|
||||||
self,
|
|
||||||
src_texts: List[str],
|
|
||||||
tgt_texts: Optional[List[str]] = None,
|
|
||||||
max_length: Optional[int] = None,
|
|
||||||
max_target_length: Optional[int] = None,
|
|
||||||
padding: str = "longest",
|
|
||||||
return_tensors: str = None,
|
|
||||||
truncation: bool = True,
|
|
||||||
**kwargs,
|
|
||||||
) -> BatchEncoding:
|
|
||||||
if max_length is None:
|
|
||||||
max_length = self.model_max_length
|
|
||||||
self.prefix_tokens = []
|
|
||||||
model_inputs = self(
|
|
||||||
src_texts,
|
|
||||||
add_special_tokens=True,
|
|
||||||
return_tensors=return_tensors,
|
|
||||||
max_length=max_length,
|
|
||||||
padding=padding,
|
|
||||||
truncation=truncation,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
if tgt_texts is None:
|
|
||||||
return model_inputs
|
|
||||||
# Process tgt_texts
|
|
||||||
if max_target_length is None:
|
|
||||||
max_target_length = max_length
|
|
||||||
# set prefix_tokens for target text
|
|
||||||
self.prefix_tokens = [self.pad_token_id]
|
|
||||||
labels_and_decoder_mask = self(
|
|
||||||
tgt_texts,
|
|
||||||
add_special_tokens=True,
|
|
||||||
return_tensors=return_tensors,
|
|
||||||
padding=padding,
|
|
||||||
max_length=max_target_length,
|
|
||||||
truncation=truncation,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
model_inputs["labels"] = labels_and_decoder_mask["input_ids"]
|
|
||||||
self.prefix_tokens = []
|
|
||||||
return model_inputs
|
|
||||||
|
|||||||
@@ -738,80 +738,3 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
|
|||||||
return clean_text
|
return clean_text
|
||||||
else:
|
else:
|
||||||
return text
|
return text
|
||||||
|
|
||||||
def prepare_seq2seq_batch(
|
|
||||||
self,
|
|
||||||
src_texts: List[str],
|
|
||||||
tgt_texts: Optional[List[str]] = None,
|
|
||||||
max_length: Optional[int] = None,
|
|
||||||
max_target_length: Optional[int] = None,
|
|
||||||
padding: str = "longest",
|
|
||||||
return_tensors: str = "None",
|
|
||||||
truncation=True,
|
|
||||||
**kwargs,
|
|
||||||
) -> BatchEncoding:
|
|
||||||
r"""
|
|
||||||
|
|
||||||
Prepare a batch that can be passed directly to an instance of :class:`~transformers.AutoModelForSeq2SeqLM`.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
src_texts: (:obj:`List[str]`):
|
|
||||||
List of documents to summarize or source language texts.
|
|
||||||
tgt_texts: (:obj:`List[str]`, `optional`):
|
|
||||||
List of summaries or target language texts.
|
|
||||||
max_length (:obj:`int`, `optional`):
|
|
||||||
Controls the maximum length for encoder inputs (documents to summarize or source language texts). If
|
|
||||||
left unset or set to :obj:`None`, this will use the predefined model maximum length if a maximum length
|
|
||||||
is required by one of the truncation/padding parameters. If the model has no specific maximum input
|
|
||||||
length (like XLNet) truncation/padding to a maximum length will be deactivated.
|
|
||||||
max_target_length (:obj:`int`, `optional`):
|
|
||||||
Controls the maximum length of decoder inputs (target language texts or summaries). If left unset or
|
|
||||||
set to :obj:`None`, this will use the max_length value.
|
|
||||||
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`False`):
|
|
||||||
Activates and controls padding. Accepts the following values:
|
|
||||||
|
|
||||||
* :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
|
|
||||||
single sequence if provided).
|
|
||||||
* :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
|
|
||||||
maximum acceptable input length for the model if that argument is not provided.
|
|
||||||
* :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
|
|
||||||
different lengths).
|
|
||||||
return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`):
|
|
||||||
If set, will return tensors instead of list of python integers. Acceptable values are:
|
|
||||||
|
|
||||||
* :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
|
|
||||||
* :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
|
|
||||||
* :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
|
|
||||||
truncation (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.TruncationStrategy`, `optional`, defaults to :obj:`True`):
|
|
||||||
Activates and controls truncation. Accepts the following values:
|
|
||||||
|
|
||||||
* :obj:`True` or :obj:`'longest_first'`: Truncate to a maximum length specified with the argument
|
|
||||||
:obj:`max_length` or to the maximum acceptable input length for the model if that argument is not
|
|
||||||
provided. This will truncate token by token, removing a token from the longest sequence in the pair
|
|
||||||
if a pair of sequences (or a batch of pairs) is provided.
|
|
||||||
* :obj:`'only_first'`: Truncate to a maximum length specified with the argument :obj:`max_length` or to
|
|
||||||
the maximum acceptable input length for the model if that argument is not provided. This will only
|
|
||||||
truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
|
|
||||||
* :obj:`'only_second'`: Truncate to a maximum length specified with the argument :obj:`max_length` or
|
|
||||||
to the maximum acceptable input length for the model if that argument is not provided. This will only
|
|
||||||
truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
|
|
||||||
* :obj:`False` or :obj:`'do_not_truncate'` (default): No truncation (i.e., can output batch with
|
|
||||||
sequence lengths greater than the model maximum admissible input size).
|
|
||||||
**kwargs:
|
|
||||||
Additional keyword arguments passed along to :obj:`self.__call__`.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
:class:`~transformers.BatchEncoding`: A :class:`~transformers.BatchEncoding` with the following fields:
|
|
||||||
|
|
||||||
- **input_ids** -- List of token ids to be fed to the encoder.
|
|
||||||
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model.
|
|
||||||
- **labels** -- List of token ids for tgt_texts
|
|
||||||
|
|
||||||
The full set of keys ``[input_ids, attention_mask, labels]``, will only be returned if tgt_texts is passed.
|
|
||||||
Otherwise, input_ids, attention_mask will be the only keys.
|
|
||||||
"""
|
|
||||||
raise NotImplementedError(
|
|
||||||
"If your model requires more than input_ids for a typical forward pass, you should implement this method. "
|
|
||||||
"Returned keys should be [input_ids, attention_mask, labels]. See MarianTokenizer or T5Tokenizer for a "
|
|
||||||
"reference implementation."
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ import json
|
|||||||
import os
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
from collections import OrderedDict, UserDict
|
from collections import OrderedDict, UserDict
|
||||||
|
from contextlib import contextmanager
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
|
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
|
||||||
@@ -1473,68 +1474,6 @@ INIT_TOKENIZER_DOCSTRING = r"""
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
PREPARE_SEQ2SEQ_BATCH_DOCSTRING = """
|
|
||||||
Prepare model inputs for translation. For best performance, translate one sentence at a time.
|
|
||||||
|
|
||||||
Arguments:
|
|
||||||
src_texts (:obj:`List[str]`):
|
|
||||||
List of documents to summarize or source language texts.
|
|
||||||
tgt_texts (:obj:`list`, `optional`):
|
|
||||||
List of summaries or target language texts.
|
|
||||||
max_length (:obj:`int`, `optional`):
|
|
||||||
Controls the maximum length for encoder inputs (documents to summarize or source language texts) If
|
|
||||||
left unset or set to :obj:`None`, this will use the predefined model maximum length if a maximum length
|
|
||||||
is required by one of the truncation/padding parameters. If the model has no specific maximum input
|
|
||||||
length (like XLNet) truncation/padding to a maximum length will be deactivated.
|
|
||||||
max_target_length (:obj:`int`, `optional`):
|
|
||||||
Controls the maximum length of decoder inputs (target language texts or summaries) If left unset or set
|
|
||||||
to :obj:`None`, this will use the max_length value.
|
|
||||||
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`False`):
|
|
||||||
Activates and controls padding. Accepts the following values:
|
|
||||||
|
|
||||||
* :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
|
|
||||||
single sequence if provided).
|
|
||||||
* :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
|
|
||||||
maximum acceptable input length for the model if that argument is not provided.
|
|
||||||
* :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
|
|
||||||
different lengths).
|
|
||||||
return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`):
|
|
||||||
If set, will return tensors instead of list of python integers. Acceptable values are:
|
|
||||||
|
|
||||||
* :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
|
|
||||||
* :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
|
|
||||||
* :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
|
|
||||||
truncation (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.TruncationStrategy`, `optional`, defaults to :obj:`True`):
|
|
||||||
Activates and controls truncation. Accepts the following values:
|
|
||||||
|
|
||||||
* :obj:`True` or :obj:`'longest_first'`: Truncate to a maximum length specified with the argument
|
|
||||||
:obj:`max_length` or to the maximum acceptable input length for the model if that argument is not
|
|
||||||
provided. This will truncate token by token, removing a token from the longest sequence in the pair
|
|
||||||
if a pair of sequences (or a batch of pairs) is provided.
|
|
||||||
* :obj:`'only_first'`: Truncate to a maximum length specified with the argument :obj:`max_length` or to
|
|
||||||
the maximum acceptable input length for the model if that argument is not provided. This will only
|
|
||||||
truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
|
|
||||||
* :obj:`'only_second'`: Truncate to a maximum length specified with the argument :obj:`max_length` or
|
|
||||||
to the maximum acceptable input length for the model if that argument is not provided. This will only
|
|
||||||
truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
|
|
||||||
* :obj:`False` or :obj:`'do_not_truncate'` (default): No truncation (i.e., can output batch with
|
|
||||||
sequence lengths greater than the model maximum admissible input size).
|
|
||||||
**kwargs:
|
|
||||||
Additional keyword arguments passed along to :obj:`self.__call__`.
|
|
||||||
|
|
||||||
Return:
|
|
||||||
:class:`~transformers.BatchEncoding`: A :class:`~transformers.BatchEncoding` with the following fields:
|
|
||||||
|
|
||||||
- **input_ids** -- List of token ids to be fed to the encoder.
|
|
||||||
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model.
|
|
||||||
- **labels** -- List of token ids for tgt_texts.
|
|
||||||
|
|
||||||
The full set of keys ``[input_ids, attention_mask, labels]``, will only be returned if tgt_texts is passed.
|
|
||||||
Otherwise, input_ids, attention_mask will be the only keys.
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
@add_end_docstrings(INIT_TOKENIZER_DOCSTRING)
|
@add_end_docstrings(INIT_TOKENIZER_DOCSTRING)
|
||||||
class PreTrainedTokenizerBase(SpecialTokensMixin):
|
class PreTrainedTokenizerBase(SpecialTokensMixin):
|
||||||
"""
|
"""
|
||||||
@@ -3252,3 +3191,113 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
|
|||||||
"indexing errors".format(len(ids), self.model_max_length)
|
"indexing errors".format(len(ids), self.model_max_length)
|
||||||
)
|
)
|
||||||
self.deprecation_warnings["sequence-length-is-longer-than-the-specified-maximum"] = True
|
self.deprecation_warnings["sequence-length-is-longer-than-the-specified-maximum"] = True
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def as_target_tokenizer(self):
|
||||||
|
"""
|
||||||
|
Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to
|
||||||
|
sequence-to-sequence models that need a slightly different processing for the labels.
|
||||||
|
"""
|
||||||
|
yield
|
||||||
|
|
||||||
|
def prepare_seq2seq_batch(
|
||||||
|
self,
|
||||||
|
src_texts: List[str],
|
||||||
|
tgt_texts: Optional[List[str]] = None,
|
||||||
|
max_length: Optional[int] = None,
|
||||||
|
max_target_length: Optional[int] = None,
|
||||||
|
padding: str = "longest",
|
||||||
|
return_tensors: str = None,
|
||||||
|
truncation: bool = True,
|
||||||
|
**kwargs,
|
||||||
|
) -> BatchEncoding:
|
||||||
|
"""
|
||||||
|
Prepare model inputs for translation. For best performance, translate one sentence at a time.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
src_texts (:obj:`List[str]`):
|
||||||
|
List of documents to summarize or source language texts.
|
||||||
|
tgt_texts (:obj:`list`, `optional`):
|
||||||
|
List of summaries or target language texts.
|
||||||
|
max_length (:obj:`int`, `optional`):
|
||||||
|
Controls the maximum length for encoder inputs (documents to summarize or source language texts) If
|
||||||
|
left unset or set to :obj:`None`, this will use the predefined model maximum length if a maximum length
|
||||||
|
is required by one of the truncation/padding parameters. If the model has no specific maximum input
|
||||||
|
length (like XLNet) truncation/padding to a maximum length will be deactivated.
|
||||||
|
max_target_length (:obj:`int`, `optional`):
|
||||||
|
Controls the maximum length of decoder inputs (target language texts or summaries) If left unset or set
|
||||||
|
to :obj:`None`, this will use the max_length value.
|
||||||
|
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`False`):
|
||||||
|
Activates and controls padding. Accepts the following values:
|
||||||
|
|
||||||
|
* :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
|
||||||
|
single sequence if provided).
|
||||||
|
* :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
|
||||||
|
maximum acceptable input length for the model if that argument is not provided.
|
||||||
|
* :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
|
||||||
|
different lengths).
|
||||||
|
return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`):
|
||||||
|
If set, will return tensors instead of list of python integers. Acceptable values are:
|
||||||
|
|
||||||
|
* :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
|
||||||
|
* :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
|
||||||
|
* :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
|
||||||
|
truncation (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.TruncationStrategy`, `optional`, defaults to :obj:`True`):
|
||||||
|
Activates and controls truncation. Accepts the following values:
|
||||||
|
|
||||||
|
* :obj:`True` or :obj:`'longest_first'`: Truncate to a maximum length specified with the argument
|
||||||
|
:obj:`max_length` or to the maximum acceptable input length for the model if that argument is not
|
||||||
|
provided. This will truncate token by token, removing a token from the longest sequence in the pair
|
||||||
|
if a pair of sequences (or a batch of pairs) is provided.
|
||||||
|
* :obj:`'only_first'`: Truncate to a maximum length specified with the argument :obj:`max_length` or to
|
||||||
|
the maximum acceptable input length for the model if that argument is not provided. This will only
|
||||||
|
truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
|
||||||
|
* :obj:`'only_second'`: Truncate to a maximum length specified with the argument :obj:`max_length` or
|
||||||
|
to the maximum acceptable input length for the model if that argument is not provided. This will only
|
||||||
|
truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
|
||||||
|
* :obj:`False` or :obj:`'do_not_truncate'` (default): No truncation (i.e., can output batch with
|
||||||
|
sequence lengths greater than the model maximum admissible input size).
|
||||||
|
**kwargs:
|
||||||
|
Additional keyword arguments passed along to :obj:`self.__call__`.
|
||||||
|
|
||||||
|
Return:
|
||||||
|
:class:`~transformers.BatchEncoding`: A :class:`~transformers.BatchEncoding` with the following fields:
|
||||||
|
|
||||||
|
- **input_ids** -- List of token ids to be fed to the encoder.
|
||||||
|
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model.
|
||||||
|
- **labels** -- List of token ids for tgt_texts.
|
||||||
|
|
||||||
|
The full set of keys ``[input_ids, attention_mask, labels]``, will only be returned if tgt_texts is passed.
|
||||||
|
Otherwise, input_ids, attention_mask will be the only keys.
|
||||||
|
"""
|
||||||
|
# mBART-specific kwargs that should be ignored by other models.
|
||||||
|
kwargs.pop("src_lang", None)
|
||||||
|
kwargs.pop("tgt_lang", None)
|
||||||
|
if max_length is None:
|
||||||
|
max_length = self.model_max_length
|
||||||
|
model_inputs = self(
|
||||||
|
src_texts,
|
||||||
|
add_special_tokens=True,
|
||||||
|
return_tensors=return_tensors,
|
||||||
|
max_length=max_length,
|
||||||
|
padding=padding,
|
||||||
|
truncation=truncation,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
if tgt_texts is None:
|
||||||
|
return model_inputs
|
||||||
|
# Process tgt_texts
|
||||||
|
if max_target_length is None:
|
||||||
|
max_target_length = max_length
|
||||||
|
with self.as_target_tokenizer():
|
||||||
|
labels = self(
|
||||||
|
tgt_texts,
|
||||||
|
add_special_tokens=True,
|
||||||
|
return_tensors=return_tensors,
|
||||||
|
padding=padding,
|
||||||
|
max_length=max_target_length,
|
||||||
|
truncation=truncation,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
model_inputs["labels"] = labels["input_ids"]
|
||||||
|
return model_inputs
|
||||||
|
|||||||
@@ -508,12 +508,6 @@ class TestMarian_en_ROMANCE(MarianIntegrationTest):
|
|||||||
def test_batch_generation_en_ROMANCE_multi(self):
|
def test_batch_generation_en_ROMANCE_multi(self):
|
||||||
self._assert_generated_batch_equal_expected()
|
self._assert_generated_batch_equal_expected()
|
||||||
|
|
||||||
def test_tokenizer_handles_empty(self):
|
|
||||||
normalized = self.tokenizer.normalize("")
|
|
||||||
self.assertIsInstance(normalized, str)
|
|
||||||
with self.assertRaises(ValueError):
|
|
||||||
self.tokenizer.prepare_seq2seq_batch([""], return_tensors="pt")
|
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_pipeline(self):
|
def test_pipeline(self):
|
||||||
device = 0 if torch_device == "cuda" else -1
|
device = 0 if torch_device == "cuda" else -1
|
||||||
|
|||||||
@@ -83,6 +83,7 @@ class TokenizerTesterMixin:
|
|||||||
from_pretrained_kwargs = None
|
from_pretrained_kwargs = None
|
||||||
from_pretrained_filter = None
|
from_pretrained_filter = None
|
||||||
from_pretrained_vocab_key = "vocab_file"
|
from_pretrained_vocab_key = "vocab_file"
|
||||||
|
test_seq2seq = True
|
||||||
|
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
# Tokenizer.filter makes it possible to filter which Tokenizer to case based on all the
|
# Tokenizer.filter makes it possible to filter which Tokenizer to case based on all the
|
||||||
@@ -1799,10 +1800,11 @@ class TokenizerTesterMixin:
|
|||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_prepare_seq2seq_batch(self):
|
def test_prepare_seq2seq_batch(self):
|
||||||
|
if not self.test_seq2seq:
|
||||||
|
return
|
||||||
|
|
||||||
tokenizer = self.get_tokenizer()
|
tokenizer = self.get_tokenizer()
|
||||||
|
|
||||||
if not hasattr(tokenizer, "prepare_seq2seq_batch"):
|
|
||||||
return
|
|
||||||
# Longer text that will definitely require truncation.
|
# Longer text that will definitely require truncation.
|
||||||
src_text = [
|
src_text = [
|
||||||
" UN Chief Says There Is No Military Solution in Syria",
|
" UN Chief Says There Is No Military Solution in Syria",
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ class CTRLTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
tokenizer_class = CTRLTokenizer
|
tokenizer_class = CTRLTokenizer
|
||||||
test_rust_tokenizer = False
|
test_rust_tokenizer = False
|
||||||
|
test_seq2seq = False
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
super().setUp()
|
super().setUp()
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ class GPT2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
rust_tokenizer_class = GPT2TokenizerFast
|
rust_tokenizer_class = GPT2TokenizerFast
|
||||||
test_rust_tokenizer = True
|
test_rust_tokenizer = True
|
||||||
from_pretrained_kwargs = {"add_prefix_space": True}
|
from_pretrained_kwargs = {"add_prefix_space": True}
|
||||||
|
test_seq2seq = False
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
super().setUp()
|
super().setUp()
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ class OpenAIGPTTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
tokenizer_class = OpenAIGPTTokenizer
|
tokenizer_class = OpenAIGPTTokenizer
|
||||||
rust_tokenizer_class = OpenAIGPTTokenizerFast
|
rust_tokenizer_class = OpenAIGPTTokenizerFast
|
||||||
test_rust_tokenizer = True
|
test_rust_tokenizer = True
|
||||||
|
test_seq2seq = False
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
super().setUp()
|
super().setUp()
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ class ReformerTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
tokenizer_class = ReformerTokenizer
|
tokenizer_class = ReformerTokenizer
|
||||||
rust_tokenizer_class = ReformerTokenizerFast
|
rust_tokenizer_class = ReformerTokenizerFast
|
||||||
test_rust_tokenizer = True
|
test_rust_tokenizer = True
|
||||||
|
test_seq2seq = False
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
super().setUp()
|
super().setUp()
|
||||||
|
|||||||
@@ -44,6 +44,7 @@ class TapasTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
test_rust_tokenizer = False
|
test_rust_tokenizer = False
|
||||||
space_between_special_tokens = True
|
space_between_special_tokens = True
|
||||||
from_pretrained_filter = filter_non_english
|
from_pretrained_filter = filter_non_english
|
||||||
|
test_seq2seq = False
|
||||||
|
|
||||||
def get_table(
|
def get_table(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ class TransfoXLTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
tokenizer_class = TransfoXLTokenizer
|
tokenizer_class = TransfoXLTokenizer
|
||||||
test_rust_tokenizer = False
|
test_rust_tokenizer = False
|
||||||
|
test_seq2seq = False
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
super().setUp()
|
super().setUp()
|
||||||
|
|||||||
Reference in New Issue
Block a user