Fixed inconsistency in several fast tokenizers (#26561)
This commit is contained in:
@@ -265,7 +265,7 @@ class BertTokenizerFast(PreTrainedTokenizerFast):
|
|||||||
"""
|
"""
|
||||||
output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
|
output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
|
||||||
|
|
||||||
if token_ids_1:
|
if token_ids_1 is not None:
|
||||||
output += token_ids_1 + [self.sep_token_id]
|
output += token_ids_1 + [self.sep_token_id]
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|||||||
@@ -159,7 +159,7 @@ class ConvBertTokenizerFast(PreTrainedTokenizerFast):
|
|||||||
"""
|
"""
|
||||||
output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
|
output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
|
||||||
|
|
||||||
if token_ids_1:
|
if token_ids_1 is not None:
|
||||||
output += token_ids_1 + [self.sep_token_id]
|
output += token_ids_1 + [self.sep_token_id]
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|||||||
@@ -164,7 +164,7 @@ class RetriBertTokenizerFast(PreTrainedTokenizerFast):
|
|||||||
"""
|
"""
|
||||||
output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
|
output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
|
||||||
|
|
||||||
if token_ids_1:
|
if token_ids_1 is not None:
|
||||||
output += token_ids_1 + [self.sep_token_id]
|
output += token_ids_1 + [self.sep_token_id]
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|||||||
@@ -190,7 +190,7 @@ class DistilBertTokenizerFast(PreTrainedTokenizerFast):
|
|||||||
"""
|
"""
|
||||||
output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
|
output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
|
||||||
|
|
||||||
if token_ids_1:
|
if token_ids_1 is not None:
|
||||||
output += token_ids_1 + [self.sep_token_id]
|
output += token_ids_1 + [self.sep_token_id]
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|||||||
@@ -192,7 +192,7 @@ class ElectraTokenizerFast(PreTrainedTokenizerFast):
|
|||||||
"""
|
"""
|
||||||
output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
|
output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
|
||||||
|
|
||||||
if token_ids_1:
|
if token_ids_1 is not None:
|
||||||
output += token_ids_1 + [self.sep_token_id]
|
output += token_ids_1 + [self.sep_token_id]
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|||||||
@@ -212,7 +212,7 @@ class FunnelTokenizerFast(PreTrainedTokenizerFast):
|
|||||||
"""
|
"""
|
||||||
output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
|
output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
|
||||||
|
|
||||||
if token_ids_1:
|
if token_ids_1 is not None:
|
||||||
output += token_ids_1 + [self.sep_token_id]
|
output += token_ids_1 + [self.sep_token_id]
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|||||||
@@ -166,7 +166,7 @@ class LayoutLMTokenizerFast(PreTrainedTokenizerFast):
|
|||||||
"""
|
"""
|
||||||
output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
|
output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
|
||||||
|
|
||||||
if token_ids_1:
|
if token_ids_1 is not None:
|
||||||
output += token_ids_1 + [self.sep_token_id]
|
output += token_ids_1 + [self.sep_token_id]
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|||||||
@@ -152,7 +152,7 @@ class LxmertTokenizerFast(PreTrainedTokenizerFast):
|
|||||||
"""
|
"""
|
||||||
output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
|
output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
|
||||||
|
|
||||||
if token_ids_1:
|
if token_ids_1 is not None:
|
||||||
output += token_ids_1 + [self.sep_token_id]
|
output += token_ids_1 + [self.sep_token_id]
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|||||||
@@ -150,7 +150,7 @@ class MobileBertTokenizerFast(PreTrainedTokenizerFast):
|
|||||||
"""
|
"""
|
||||||
output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
|
output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
|
||||||
|
|
||||||
if token_ids_1:
|
if token_ids_1 is not None:
|
||||||
output += token_ids_1 + [self.sep_token_id]
|
output += token_ids_1 + [self.sep_token_id]
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|||||||
@@ -282,7 +282,7 @@ class RealmTokenizerFast(PreTrainedTokenizerFast):
|
|||||||
"""
|
"""
|
||||||
output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
|
output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
|
||||||
|
|
||||||
if token_ids_1:
|
if token_ids_1 is not None:
|
||||||
output += token_ids_1 + [self.sep_token_id]
|
output += token_ids_1 + [self.sep_token_id]
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|||||||
@@ -163,7 +163,7 @@ class RoFormerTokenizerFast(PreTrainedTokenizerFast):
|
|||||||
"""
|
"""
|
||||||
output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
|
output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
|
||||||
|
|
||||||
if token_ids_1:
|
if token_ids_1 is not None:
|
||||||
output += token_ids_1 + [self.sep_token_id]
|
output += token_ids_1 + [self.sep_token_id]
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|||||||
@@ -173,7 +173,7 @@ class SqueezeBertTokenizerFast(PreTrainedTokenizerFast):
|
|||||||
"""
|
"""
|
||||||
output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
|
output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
|
||||||
|
|
||||||
if token_ids_1:
|
if token_ids_1 is not None:
|
||||||
output += token_ids_1 + [self.sep_token_id]
|
output += token_ids_1 + [self.sep_token_id]
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|||||||
@@ -3209,9 +3209,17 @@ class TokenizerTesterMixin:
|
|||||||
# output_p = tokenizer_p.build_inputs_with_special_tokens(input_simple, input_pair)
|
# output_p = tokenizer_p.build_inputs_with_special_tokens(input_simple, input_pair)
|
||||||
# self.assertEqual(output_p, output_r)
|
# self.assertEqual(output_p, output_r)
|
||||||
|
|
||||||
|
input_pairs = [
|
||||||
|
("", ""),
|
||||||
|
("", "This is a sample pair"),
|
||||||
|
("This is a sample input", ""),
|
||||||
|
("This is a sample input", "This is a sample pair"),
|
||||||
|
]
|
||||||
|
|
||||||
|
for sample_input, sample_pair in input_pairs:
|
||||||
# Input tokens id
|
# Input tokens id
|
||||||
input_simple = tokenizer_p.encode("This is a sample input", add_special_tokens=False)
|
input_simple = tokenizer_p.encode(sample_input, add_special_tokens=False)
|
||||||
input_pair = tokenizer_p.encode("This is a sample pair", add_special_tokens=False)
|
input_pair = tokenizer_p.encode(sample_pair, add_special_tokens=False)
|
||||||
|
|
||||||
# Generate output
|
# Generate output
|
||||||
output_r = tokenizer_r.build_inputs_with_special_tokens(input_simple)
|
output_r = tokenizer_r.build_inputs_with_special_tokens(input_simple)
|
||||||
|
|||||||
Reference in New Issue
Block a user