Improve truncation_side (#14947)
* Enabling `truncation_side` for Slow and Fast tokenizer. Co-Authored-by: Niels Rogge <48327001+NielsRogge@users.noreply.github.com> * Disable failing tests. * Layout xlm. * assert -> assertEqual. Co-authored-by: Niels Rogge <48327001+NielsRogge@users.noreply.github.com>
This commit is contained in:
@@ -1437,6 +1437,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
||||
# to make sure `tokenizer.pad(...)` works correctly
|
||||
model_input_names: List[str] = ["input_ids", "token_type_ids", "attention_mask"]
|
||||
padding_side: str = "right"
|
||||
truncation_side: str = "right"
|
||||
slow_tokenizer_class = None
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
@@ -1514,7 +1515,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
||||
return (
|
||||
f"{'PreTrainedTokenizerFast' if self.is_fast else 'PreTrainedTokenizer'}(name_or_path='{self.name_or_path}', "
|
||||
f"vocab_size={self.vocab_size}, model_max_len={self.model_max_length}, is_fast={self.is_fast}, "
|
||||
f"padding_side='{self.padding_side}', special_tokens={self.special_tokens_map_extended})"
|
||||
f"padding_side='{self.padding_side}', truncation_side='{self.truncation_side}', special_tokens={self.special_tokens_map_extended})"
|
||||
)
|
||||
|
||||
def get_vocab(self) -> Dict[str, int]:
|
||||
@@ -3041,8 +3042,15 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
||||
):
|
||||
if len(ids) > num_tokens_to_remove:
|
||||
window_len = min(len(ids), stride + num_tokens_to_remove)
|
||||
overflowing_tokens = ids[-window_len:]
|
||||
ids = ids[:-num_tokens_to_remove]
|
||||
if self.truncation_side == "left":
|
||||
overflowing_tokens = ids[:window_len]
|
||||
ids = ids[num_tokens_to_remove:]
|
||||
elif self.truncation_side == "right":
|
||||
overflowing_tokens = ids[-window_len:]
|
||||
ids = ids[:-num_tokens_to_remove]
|
||||
else:
|
||||
raise ValueError(f"invalid truncation strategy: {self.truncation_side}, use 'left' or 'right'.")
|
||||
|
||||
else:
|
||||
error_msg = (
|
||||
f"We need to remove {num_tokens_to_remove} to truncate the input "
|
||||
@@ -3063,14 +3071,30 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
||||
)
|
||||
for _ in range(num_tokens_to_remove):
|
||||
if pair_ids is None or len(ids) > len(pair_ids):
|
||||
ids = ids[:-1]
|
||||
if self.truncation_side == "right":
|
||||
ids = ids[:-1]
|
||||
elif self.truncation_side == "left":
|
||||
ids = ids[1:]
|
||||
else:
|
||||
raise ValueError("invalid truncation strategy:" + str(self.truncation_side))
|
||||
else:
|
||||
pair_ids = pair_ids[:-1]
|
||||
if self.truncation_side == "right":
|
||||
pair_ids = pair_ids[:-1]
|
||||
elif self.truncation_side == "left":
|
||||
pair_ids = pair_ids[1:]
|
||||
else:
|
||||
raise ValueError("invalid truncation strategy:" + str(self.truncation_side))
|
||||
elif truncation_strategy == TruncationStrategy.ONLY_SECOND and pair_ids is not None:
|
||||
if len(pair_ids) > num_tokens_to_remove:
|
||||
window_len = min(len(pair_ids), stride + num_tokens_to_remove)
|
||||
overflowing_tokens = pair_ids[-window_len:]
|
||||
pair_ids = pair_ids[:-num_tokens_to_remove]
|
||||
if self.truncation_side == "right":
|
||||
overflowing_tokens = pair_ids[-window_len:]
|
||||
pair_ids = pair_ids[:-num_tokens_to_remove]
|
||||
elif self.truncation_side == "left":
|
||||
overflowing_tokens = pair_ids[:window_len]
|
||||
pair_ids = pair_ids[num_tokens_to_remove:]
|
||||
else:
|
||||
raise ValueError("invalid truncation strategy:" + str(self.truncation_side))
|
||||
else:
|
||||
logger.error(
|
||||
f"We need to remove {num_tokens_to_remove} to truncate the input "
|
||||
|
||||
Reference in New Issue
Block a user