[Tokenizer Utils Base] Make pad function more flexible (#9928)
* change tokenizer requirement * split line * Correct typo from list to str * improve style * make other function pretty as well * add comment * correct typo * add new test * pass tests for tok without padding token * Apply suggestions from code review
This commit is contained in:
committed by
GitHub
parent
d1b14c9b54
commit
538b3b4607
@@ -1492,7 +1492,10 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
|
||||
pretrained_vocab_files_map: Dict[str, Dict[str, str]] = {}
|
||||
pretrained_init_configuration: Dict[str, Dict[str, Any]] = {}
|
||||
max_model_input_sizes: Dict[str, Optional[int]] = {}
|
||||
model_input_names: List[str] = ["token_type_ids", "attention_mask"]
|
||||
|
||||
# first name has to correspond to main model input name
|
||||
# to make sure `tokenizer.pad(...)` works correctly
|
||||
model_input_names: List[str] = ["input_ids", "token_type_ids", "attention_mask"]
|
||||
padding_side: str = "right"
|
||||
slow_tokenizer_class = None
|
||||
|
||||
@@ -2633,13 +2636,16 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
|
||||
if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], (dict, BatchEncoding)):
|
||||
encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0].keys()}
|
||||
|
||||
assert "input_ids" in encoded_inputs, (
|
||||
"You should supply an encoding or a list of encodings to this method. "
|
||||
"An encoding is the output of one the encoding methods of the tokenizer, i.e. "
|
||||
"__call__/encode_plus/batch_encode_plus. "
|
||||
)
|
||||
# The model's main input name, usually `input_ids`, has be passed for padding
|
||||
if self.model_input_names[0] not in encoded_inputs:
|
||||
raise ValueError(
|
||||
"You should supply an encoding or a list of encodings to this method"
|
||||
f"that includes {self.model_input_names[0]}, but you provided {list(encoded_inputs.keys())}"
|
||||
)
|
||||
|
||||
if not encoded_inputs["input_ids"]:
|
||||
required_input = encoded_inputs[self.model_input_names[0]]
|
||||
|
||||
if not required_input:
|
||||
if return_attention_mask:
|
||||
encoded_inputs["attention_mask"] = []
|
||||
return encoded_inputs
|
||||
@@ -2648,14 +2654,14 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
|
||||
# and rebuild them afterwards if no return_tensors is specified
|
||||
# Note that we lose the specific device the tensor may be on for PyTorch
|
||||
|
||||
first_element = encoded_inputs["input_ids"][0]
|
||||
first_element = required_input[0]
|
||||
if isinstance(first_element, (list, tuple)):
|
||||
# first_element might be an empty list/tuple in some edge cases so we grab the first non empty element.
|
||||
index = 0
|
||||
while len(encoded_inputs["input_ids"][index]) == 0:
|
||||
while len(required_input[index]) == 0:
|
||||
index += 1
|
||||
if index < len(encoded_inputs["input_ids"]):
|
||||
first_element = encoded_inputs["input_ids"][index][0]
|
||||
if index < len(required_input):
|
||||
first_element = required_input[index][0]
|
||||
# At this state, if `first_element` is still a list/tuple, it's an empty one so there is nothing to do.
|
||||
if not isinstance(first_element, (int, list, tuple)):
|
||||
if is_tf_available() and _is_tensorflow(first_element):
|
||||
@@ -2678,7 +2684,8 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
|
||||
padding=padding, max_length=max_length, verbose=verbose
|
||||
)
|
||||
|
||||
if encoded_inputs["input_ids"] and not isinstance(encoded_inputs["input_ids"][0], (list, tuple)):
|
||||
required_input = encoded_inputs[self.model_input_names[0]]
|
||||
if required_input and not isinstance(required_input[0], (list, tuple)):
|
||||
encoded_inputs = self._pad(
|
||||
encoded_inputs,
|
||||
max_length=max_length,
|
||||
@@ -2688,13 +2695,13 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
|
||||
)
|
||||
return BatchEncoding(encoded_inputs, tensor_type=return_tensors)
|
||||
|
||||
batch_size = len(encoded_inputs["input_ids"])
|
||||
batch_size = len(required_input)
|
||||
assert all(
|
||||
len(v) == batch_size for v in encoded_inputs.values()
|
||||
), "Some items in the output dictionary have a different batch size than others."
|
||||
|
||||
if padding_strategy == PaddingStrategy.LONGEST:
|
||||
max_length = max(len(inputs) for inputs in encoded_inputs["input_ids"])
|
||||
max_length = max(len(inputs) for inputs in required_input)
|
||||
padding_strategy = PaddingStrategy.MAX_LENGTH
|
||||
|
||||
batch_outputs = {}
|
||||
@@ -3004,42 +3011,42 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
|
||||
if return_attention_mask is None:
|
||||
return_attention_mask = "attention_mask" in self.model_input_names
|
||||
|
||||
required_input = encoded_inputs[self.model_input_names[0]]
|
||||
|
||||
if padding_strategy == PaddingStrategy.LONGEST:
|
||||
max_length = len(encoded_inputs["input_ids"])
|
||||
max_length = len(required_input)
|
||||
|
||||
if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
|
||||
max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
|
||||
|
||||
needs_to_be_padded = (
|
||||
padding_strategy != PaddingStrategy.DO_NOT_PAD and len(encoded_inputs["input_ids"]) != max_length
|
||||
)
|
||||
needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
|
||||
|
||||
if needs_to_be_padded:
|
||||
difference = max_length - len(encoded_inputs["input_ids"])
|
||||
difference = max_length - len(required_input)
|
||||
if self.padding_side == "right":
|
||||
if return_attention_mask:
|
||||
encoded_inputs["attention_mask"] = [1] * len(encoded_inputs["input_ids"]) + [0] * difference
|
||||
encoded_inputs["attention_mask"] = [1] * len(required_input) + [0] * difference
|
||||
if "token_type_ids" in encoded_inputs:
|
||||
encoded_inputs["token_type_ids"] = (
|
||||
encoded_inputs["token_type_ids"] + [self.pad_token_type_id] * difference
|
||||
)
|
||||
if "special_tokens_mask" in encoded_inputs:
|
||||
encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference
|
||||
encoded_inputs["input_ids"] = encoded_inputs["input_ids"] + [self.pad_token_id] * difference
|
||||
encoded_inputs[self.model_input_names[0]] = required_input + [self.pad_token_id] * difference
|
||||
elif self.padding_side == "left":
|
||||
if return_attention_mask:
|
||||
encoded_inputs["attention_mask"] = [0] * difference + [1] * len(encoded_inputs["input_ids"])
|
||||
encoded_inputs["attention_mask"] = [0] * difference + [1] * len(required_input)
|
||||
if "token_type_ids" in encoded_inputs:
|
||||
encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[
|
||||
"token_type_ids"
|
||||
]
|
||||
if "special_tokens_mask" in encoded_inputs:
|
||||
encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"]
|
||||
encoded_inputs["input_ids"] = [self.pad_token_id] * difference + encoded_inputs["input_ids"]
|
||||
encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
|
||||
else:
|
||||
raise ValueError("Invalid padding strategy:" + str(self.padding_side))
|
||||
elif return_attention_mask and "attention_mask" not in encoded_inputs:
|
||||
encoded_inputs["attention_mask"] = [1] * len(encoded_inputs["input_ids"])
|
||||
encoded_inputs["attention_mask"] = [1] * len(required_input)
|
||||
|
||||
return encoded_inputs
|
||||
|
||||
|
||||
Reference in New Issue
Block a user