Honor existing attention mask in tokenzier.pad (#13926)

* Honor existing attention mask in tokenzier.pad

* Fix initialization of attention mask

* Roll the implem on all subclasses

* Fix tests
This commit is contained in:
Sylvain Gugger
2021-10-11 09:12:09 -04:00
committed by GitHub
parent 3c0c699ffd
commit 4a18337bae
7 changed files with 68 additions and 38 deletions

View File

@@ -3110,11 +3110,17 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
# Initialize attention mask if not present.
if return_attention_mask and "attention_mask" not in encoded_inputs:
encoded_inputs["attention_mask"] = [1] * len(required_input)
if needs_to_be_padded:
difference = max_length - len(required_input)
if self.padding_side == "right":
if return_attention_mask:
encoded_inputs["attention_mask"] = [1] * len(required_input) + [0] * difference
encoded_inputs["attention_mask"] = encoded_inputs["attention_mask"] + [0] * difference
if "token_type_ids" in encoded_inputs:
encoded_inputs["token_type_ids"] = (
encoded_inputs["token_type_ids"] + [self.pad_token_type_id] * difference
@@ -3124,7 +3130,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
encoded_inputs[self.model_input_names[0]] = required_input + [self.pad_token_id] * difference
elif self.padding_side == "left":
if return_attention_mask:
encoded_inputs["attention_mask"] = [0] * difference + [1] * len(required_input)
encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
if "token_type_ids" in encoded_inputs:
encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[
"token_type_ids"
@@ -3134,8 +3140,6 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
else:
raise ValueError("Invalid padding strategy:" + str(self.padding_side))
elif return_attention_mask and "attention_mask" not in encoded_inputs:
encoded_inputs["attention_mask"] = [1] * len(required_input)
return encoded_inputs