Honor existing attention mask in tokenizer.pad (#13926)
* Honor existing attention mask in tokenizer.pad * Fix initialization of attention mask * Roll the implementation on all subclasses * Fix tests
This commit is contained in:
@@ -3110,11 +3110,17 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
||||
|
||||
needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
|
||||
|
||||
# Initialize attention mask if not present.
|
||||
if return_attention_mask and "attention_mask" not in encoded_inputs:
|
||||
encoded_inputs["attention_mask"] = [1] * len(required_input)
|
||||
|
||||
if needs_to_be_padded:
|
||||
difference = max_length - len(required_input)
|
||||
|
||||
if self.padding_side == "right":
|
||||
if return_attention_mask:
|
||||
encoded_inputs["attention_mask"] = [1] * len(required_input) + [0] * difference
|
||||
|
||||
encoded_inputs["attention_mask"] = encoded_inputs["attention_mask"] + [0] * difference
|
||||
if "token_type_ids" in encoded_inputs:
|
||||
encoded_inputs["token_type_ids"] = (
|
||||
encoded_inputs["token_type_ids"] + [self.pad_token_type_id] * difference
|
||||
@@ -3124,7 +3130,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
||||
encoded_inputs[self.model_input_names[0]] = required_input + [self.pad_token_id] * difference
|
||||
elif self.padding_side == "left":
|
||||
if return_attention_mask:
|
||||
encoded_inputs["attention_mask"] = [0] * difference + [1] * len(required_input)
|
||||
encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
|
||||
if "token_type_ids" in encoded_inputs:
|
||||
encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[
|
||||
"token_type_ids"
|
||||
@@ -3134,8 +3140,6 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
||||
encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
|
||||
else:
|
||||
raise ValueError("Invalid padding strategy:" + str(self.padding_side))
|
||||
elif return_attention_mask and "attention_mask" not in encoded_inputs:
|
||||
encoded_inputs["attention_mask"] = [1] * len(required_input)
|
||||
|
||||
return encoded_inputs
|
||||
|
||||
|
||||
Reference in New Issue
Block a user