Sentence -> Sequence. Removed output_mask from the special token addition methods.
This commit is contained in:
@@ -75,7 +75,7 @@ class TextDataset(Dataset):
|
|||||||
tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
|
tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
|
||||||
|
|
||||||
while len(tokenized_text) >= block_size: # Truncate in block of block_size
|
while len(tokenized_text) >= block_size: # Truncate in block of block_size
|
||||||
self.examples.append(tokenizer.add_special_tokens_single_sentence(tokenized_text[:block_size]))
|
self.examples.append(tokenizer.add_special_tokens_single_sequence(tokenized_text[:block_size]))
|
||||||
tokenized_text = tokenized_text[block_size:]
|
tokenized_text = tokenized_text[block_size:]
|
||||||
# Note that we are losing the last truncated example here for the sake of simplicity (no padding)
|
# Note that we are losing the last truncated example here for the sake of simplicity (no padding)
|
||||||
# If your dataset is small, first you should look for a bigger one :-) and second you
|
# If your dataset is small, first you should look for a bigger one :-) and second you
|
||||||
|
|||||||
@@ -131,8 +131,8 @@ class BertTokenizationTest(CommonTestCases.CommonTokenizerTester):
|
|||||||
text = tokenizer.encode("sequence builders")
|
text = tokenizer.encode("sequence builders")
|
||||||
text_2 = tokenizer.encode("multi-sequence build")
|
text_2 = tokenizer.encode("multi-sequence build")
|
||||||
|
|
||||||
encoded_sentence = tokenizer.add_special_tokens_single_sentence(text)
|
encoded_sentence = tokenizer.add_special_tokens_single_sequence(text)
|
||||||
encoded_pair = tokenizer.add_special_tokens_sentences_pair(text, text_2)
|
encoded_pair = tokenizer.add_special_tokens_sequence_pair(text, text_2)
|
||||||
|
|
||||||
assert encoded_sentence == [101] + text + [102]
|
assert encoded_sentence == [101] + text + [102]
|
||||||
assert encoded_pair == [101] + text + [102] + text_2 + [102]
|
assert encoded_pair == [101] + text + [102] + text_2 + [102]
|
||||||
|
|||||||
@@ -36,8 +36,8 @@ class DistilBertTokenizationTest(BertTokenizationTest):
|
|||||||
text = tokenizer.encode("sequence builders")
|
text = tokenizer.encode("sequence builders")
|
||||||
text_2 = tokenizer.encode("multi-sequence build")
|
text_2 = tokenizer.encode("multi-sequence build")
|
||||||
|
|
||||||
encoded_sentence = tokenizer.add_special_tokens_single_sentence(text)
|
encoded_sentence = tokenizer.add_special_tokens_single_sequence(text)
|
||||||
encoded_pair = tokenizer.add_special_tokens_sentences_pair(text, text_2)
|
encoded_pair = tokenizer.add_special_tokens_sequence_pair(text, text_2)
|
||||||
|
|
||||||
assert encoded_sentence == text
|
assert encoded_sentence == text
|
||||||
assert encoded_pair == text + [102] + text_2
|
assert encoded_pair == text + [102] + text_2
|
||||||
|
|||||||
@@ -87,8 +87,8 @@ class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester):
|
|||||||
encoded_text_from_decode = tokenizer.encode("sequence builders", add_special_tokens=True)
|
encoded_text_from_decode = tokenizer.encode("sequence builders", add_special_tokens=True)
|
||||||
encoded_pair_from_decode = tokenizer.encode("sequence builders", "multi-sequence build", add_special_tokens=True)
|
encoded_pair_from_decode = tokenizer.encode("sequence builders", "multi-sequence build", add_special_tokens=True)
|
||||||
|
|
||||||
encoded_sentence = tokenizer.add_special_tokens_single_sentence(text)
|
encoded_sentence = tokenizer.add_special_tokens_single_sequence(text)
|
||||||
encoded_pair = tokenizer.add_special_tokens_sentences_pair(text, text_2)
|
encoded_pair = tokenizer.add_special_tokens_sequence_pair(text, text_2)
|
||||||
|
|
||||||
assert encoded_sentence == encoded_text_from_decode
|
assert encoded_sentence == encoded_text_from_decode
|
||||||
assert encoded_pair == encoded_pair_from_decode
|
assert encoded_pair == encoded_pair_from_decode
|
||||||
|
|||||||
@@ -187,18 +187,18 @@ class CommonTestCases:
|
|||||||
for weights_list_2 in weights_lists_2:
|
for weights_list_2 in weights_lists_2:
|
||||||
self.assertListEqual(weights_list, weights_list_2)
|
self.assertListEqual(weights_list, weights_list_2)
|
||||||
|
|
||||||
def test_mask_output(self):
|
# def test_mask_output(self):
|
||||||
if sys.version_info <= (3, 0):
|
# if sys.version_info <= (3, 0):
|
||||||
return
|
# return
|
||||||
|
#
|
||||||
tokenizer = self.get_tokenizer()
|
# tokenizer = self.get_tokenizer()
|
||||||
|
#
|
||||||
if tokenizer.add_special_tokens_sentences_pair.__qualname__.split('.')[0] != "PreTrainedTokenizer":
|
# if tokenizer.add_special_tokens_sequence_pair.__qualname__.split('.')[0] != "PreTrainedTokenizer":
|
||||||
seq_0 = "Test this method."
|
# seq_0 = "Test this method."
|
||||||
seq_1 = "With these inputs."
|
# seq_1 = "With these inputs."
|
||||||
information = tokenizer.encode_plus(seq_0, seq_1, add_special_tokens=True, output_mask=True)
|
# information = tokenizer.encode_plus(seq_0, seq_1, add_special_tokens=True, output_mask=True)
|
||||||
sequences, mask = information["sequence"], information["mask"]
|
# sequences, mask = information["sequence"], information["mask"]
|
||||||
assert len(sequences) == len(mask)
|
# assert len(sequences) == len(mask)
|
||||||
|
|
||||||
def test_number_of_added_tokens(self):
|
def test_number_of_added_tokens(self):
|
||||||
tokenizer = self.get_tokenizer()
|
tokenizer = self.get_tokenizer()
|
||||||
@@ -228,7 +228,7 @@ class CommonTestCases:
|
|||||||
|
|
||||||
assert len(overflowing_tokens) == 2
|
assert len(overflowing_tokens) == 2
|
||||||
assert len(truncated_sequence) == total_length - 2
|
assert len(truncated_sequence) == total_length - 2
|
||||||
assert truncated_sequence == tokenizer.add_special_tokens_single_sentence(sequence[:-2])
|
assert truncated_sequence == tokenizer.add_special_tokens_single_sequence(sequence[:-2])
|
||||||
|
|
||||||
def test_maximum_encoding_length_pair_input(self):
|
def test_maximum_encoding_length_pair_input(self):
|
||||||
tokenizer = self.get_tokenizer()
|
tokenizer = self.get_tokenizer()
|
||||||
@@ -237,7 +237,7 @@ class CommonTestCases:
|
|||||||
seq_1 = "This is another sentence to be encoded."
|
seq_1 = "This is another sentence to be encoded."
|
||||||
|
|
||||||
sequence = tokenizer.encode(seq_0, seq_1, add_special_tokens=True)
|
sequence = tokenizer.encode(seq_0, seq_1, add_special_tokens=True)
|
||||||
truncated_second_sequence = tokenizer.add_special_tokens_sentences_pair(
|
truncated_second_sequence = tokenizer.add_special_tokens_sequence_pair(
|
||||||
tokenizer.encode(seq_0),
|
tokenizer.encode(seq_0),
|
||||||
tokenizer.encode(seq_1)[:-2]
|
tokenizer.encode(seq_1)[:-2]
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -72,8 +72,8 @@ class XLMTokenizationTest(CommonTestCases.CommonTokenizerTester):
|
|||||||
text = tokenizer.encode("sequence builders")
|
text = tokenizer.encode("sequence builders")
|
||||||
text_2 = tokenizer.encode("multi-sequence build")
|
text_2 = tokenizer.encode("multi-sequence build")
|
||||||
|
|
||||||
encoded_sentence = tokenizer.add_special_tokens_single_sentence(text)
|
encoded_sentence = tokenizer.add_special_tokens_single_sequence(text)
|
||||||
encoded_pair = tokenizer.add_special_tokens_sentences_pair(text, text_2)
|
encoded_pair = tokenizer.add_special_tokens_sequence_pair(text, text_2)
|
||||||
|
|
||||||
assert encoded_sentence == [1] + text + [1]
|
assert encoded_sentence == [1] + text + [1]
|
||||||
assert encoded_pair == [1] + text + [1] + text_2 + [1]
|
assert encoded_pair == [1] + text + [1] + text_2 + [1]
|
||||||
|
|||||||
@@ -95,8 +95,8 @@ class XLNetTokenizationTest(CommonTestCases.CommonTokenizerTester):
|
|||||||
text = tokenizer.encode("sequence builders")
|
text = tokenizer.encode("sequence builders")
|
||||||
text_2 = tokenizer.encode("multi-sequence build")
|
text_2 = tokenizer.encode("multi-sequence build")
|
||||||
|
|
||||||
encoded_sentence = tokenizer.add_special_tokens_single_sentence(text)
|
encoded_sentence = tokenizer.add_special_tokens_single_sequence(text)
|
||||||
encoded_pair = tokenizer.add_special_tokens_sentences_pair(text, text_2)
|
encoded_pair = tokenizer.add_special_tokens_sequence_pair(text, text_2)
|
||||||
|
|
||||||
assert encoded_sentence == text + [4, 3]
|
assert encoded_sentence == text + [4, 3]
|
||||||
assert encoded_pair == text + [4] + text_2 + [4, 3]
|
assert encoded_pair == text + [4] + text_2 + [4, 3]
|
||||||
|
|||||||
@@ -187,26 +187,21 @@ class BertTokenizer(PreTrainedTokenizer):
|
|||||||
out_string = ' '.join(tokens).replace(' ##', '').strip()
|
out_string = ' '.join(tokens).replace(' ##', '').strip()
|
||||||
return out_string
|
return out_string
|
||||||
|
|
||||||
def add_special_tokens_single_sentence(self, token_ids):
|
def add_special_tokens_single_sequence(self, token_ids):
|
||||||
"""
|
"""
|
||||||
Adds special tokens to a sequence for sequence classification tasks.
|
Adds special tokens to a sequence for sequence classification tasks.
|
||||||
A BERT sequence has the following format: [CLS] X [SEP]
|
A BERT sequence has the following format: [CLS] X [SEP]
|
||||||
"""
|
"""
|
||||||
return [self.cls_token_id] + token_ids + [self.sep_token_id]
|
return [self.cls_token_id] + token_ids + [self.sep_token_id]
|
||||||
|
|
||||||
def add_special_tokens_sentences_pair(self, token_ids_0, token_ids_1, output_mask=False):
|
def add_special_tokens_sequence_pair(self, token_ids_0, token_ids_1):
|
||||||
"""
|
"""
|
||||||
Adds special tokens to a sequence pair for sequence classification tasks.
|
Adds special tokens to a sequence pair for sequence classification tasks.
|
||||||
A BERT sequence pair has the following format: [CLS] A [SEP] B [SEP]
|
A BERT sequence pair has the following format: [CLS] A [SEP] B [SEP]
|
||||||
"""
|
"""
|
||||||
sep = [self.sep_token_id]
|
sep = [self.sep_token_id]
|
||||||
cls = [self.cls_token_id]
|
cls = [self.cls_token_id]
|
||||||
if output_mask:
|
|
||||||
return (
|
|
||||||
cls + token_ids_0 + sep + token_ids_1 + sep,
|
|
||||||
[0] * len(cls + token_ids_0 + sep) + [1] * len(token_ids_1 + sep)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return cls + token_ids_0 + sep + token_ids_1 + sep
|
return cls + token_ids_0 + sep + token_ids_1 + sep
|
||||||
|
|
||||||
def save_vocabulary(self, vocab_path):
|
def save_vocabulary(self, vocab_path):
|
||||||
|
|||||||
@@ -61,10 +61,10 @@ class DistilBertTokenizer(BertTokenizer):
|
|||||||
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
||||||
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
||||||
|
|
||||||
def add_special_tokens_single_sentence(self, token_ids):
|
def add_special_tokens_single_sequence(self, token_ids):
|
||||||
return token_ids
|
return token_ids
|
||||||
|
|
||||||
def add_special_tokens_sentences_pair(self, token_ids_0, token_ids_1, output_mask=False):
|
def add_special_tokens_sequence_pair(self, token_ids_0, token_ids_1, output_mask=False):
|
||||||
sep = [self.sep_token_id]
|
sep = [self.sep_token_id]
|
||||||
if output_mask:
|
if output_mask:
|
||||||
return (
|
return (
|
||||||
|
|||||||
@@ -81,24 +81,18 @@ class RobertaTokenizer(GPT2Tokenizer):
|
|||||||
sep_token=sep_token, cls_token=cls_token, pad_token=pad_token,
|
sep_token=sep_token, cls_token=cls_token, pad_token=pad_token,
|
||||||
mask_token=mask_token, **kwargs)
|
mask_token=mask_token, **kwargs)
|
||||||
|
|
||||||
def add_special_tokens_single_sentence(self, token_ids):
|
def add_special_tokens_single_sequence(self, token_ids):
|
||||||
"""
|
"""
|
||||||
Adds special tokens to a sequence for sequence classification tasks.
|
Adds special tokens to a sequence for sequence classification tasks.
|
||||||
A RoBERTa sequence has the following format: <s> X </s>
|
A RoBERTa sequence has the following format: <s> X </s>
|
||||||
"""
|
"""
|
||||||
return [self.cls_token_id] + token_ids + [self.sep_token_id]
|
return [self.cls_token_id] + token_ids + [self.sep_token_id]
|
||||||
|
|
||||||
def add_special_tokens_sentences_pair(self, token_ids_0, token_ids_1, output_mask=False):
|
def add_special_tokens_sequence_pair(self, token_ids_0, token_ids_1):
|
||||||
"""
|
"""
|
||||||
Adds special tokens to a sequence pair for sequence classification tasks.
|
Adds special tokens to a sequence pair for sequence classification tasks.
|
||||||
A RoBERTa sequence pair has the following format: <s> A </s></s> B </s>
|
A RoBERTa sequence pair has the following format: <s> A </s></s> B </s>
|
||||||
"""
|
"""
|
||||||
sep = [self.sep_token_id]
|
sep = [self.sep_token_id]
|
||||||
cls = [self.cls_token_id]
|
cls = [self.cls_token_id]
|
||||||
if output_mask:
|
|
||||||
return (
|
|
||||||
cls + token_ids_0 + sep + sep + token_ids_1 + sep,
|
|
||||||
[0] * len(cls + token_ids_0 + sep + sep) + [1] * len(token_ids_1 + sep)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return cls + token_ids_0 + sep + sep + token_ids_1 + sep
|
return cls + token_ids_0 + sep + sep + token_ids_1 + sep
|
||||||
|
|||||||
@@ -708,7 +708,7 @@ class PreTrainedTokenizer(object):
|
|||||||
if text_pair is None:
|
if text_pair is None:
|
||||||
if add_special_tokens:
|
if add_special_tokens:
|
||||||
sequence_tokens = self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
|
sequence_tokens = self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
|
||||||
return self.add_special_tokens_single_sentence(sequence_tokens)
|
return self.add_special_tokens_single_sequence(sequence_tokens)
|
||||||
else:
|
else:
|
||||||
ids = self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
|
ids = self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
|
||||||
return ids
|
return ids
|
||||||
@@ -717,7 +717,7 @@ class PreTrainedTokenizer(object):
|
|||||||
second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text_pair, **kwargs)]
|
second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text_pair, **kwargs)]
|
||||||
|
|
||||||
if add_special_tokens:
|
if add_special_tokens:
|
||||||
return self.add_special_tokens_sentences_pair(first_sentence_tokens, second_sentence_tokens)
|
return self.add_special_tokens_sequence_pair(first_sentence_tokens, second_sentence_tokens)
|
||||||
else:
|
else:
|
||||||
logger.warning("No special tokens were added. The two sequences have been concatenated.")
|
logger.warning("No special tokens were added. The two sequences have been concatenated.")
|
||||||
return first_sentence_tokens + second_sentence_tokens
|
return first_sentence_tokens + second_sentence_tokens
|
||||||
@@ -747,7 +747,7 @@ class PreTrainedTokenizer(object):
|
|||||||
if max_length:
|
if max_length:
|
||||||
information["overflowing_tokens"] = sequence_tokens[max_length - n_added_tokens:]
|
information["overflowing_tokens"] = sequence_tokens[max_length - n_added_tokens:]
|
||||||
sequence_tokens = sequence_tokens[:max_length - n_added_tokens]
|
sequence_tokens = sequence_tokens[:max_length - n_added_tokens]
|
||||||
sequence = self.add_special_tokens_single_sentence(sequence_tokens)
|
sequence = self.add_special_tokens_single_sequence(sequence_tokens)
|
||||||
else:
|
else:
|
||||||
sequence_tokens = self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
|
sequence_tokens = self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
|
||||||
if max_length:
|
if max_length:
|
||||||
@@ -774,16 +774,13 @@ class PreTrainedTokenizer(object):
|
|||||||
information["overflowing_tokens"] = second_sentence_tokens[max_length - f_len - n_added_tokens:]
|
information["overflowing_tokens"] = second_sentence_tokens[max_length - f_len - n_added_tokens:]
|
||||||
second_sentence_tokens = second_sentence_tokens[:max_length - f_len - n_added_tokens]
|
second_sentence_tokens = second_sentence_tokens[:max_length - f_len - n_added_tokens]
|
||||||
|
|
||||||
encoded_sequence = self.add_special_tokens_sentences_pair(
|
sequence = self.add_special_tokens_sequence_pair(
|
||||||
first_sentence_tokens,
|
first_sentence_tokens,
|
||||||
second_sentence_tokens,
|
second_sentence_tokens
|
||||||
output_mask
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if output_mask:
|
# if output_mask:
|
||||||
sequence, information["mask"] = encoded_sequence
|
# sequence, information["mask"] = encoded_sequence
|
||||||
else:
|
|
||||||
sequence = encoded_sequence
|
|
||||||
|
|
||||||
information["sequence"] = sequence
|
information["sequence"] = sequence
|
||||||
else:
|
else:
|
||||||
@@ -800,11 +797,11 @@ class PreTrainedTokenizer(object):
|
|||||||
|
|
||||||
return information
|
return information
|
||||||
|
|
||||||
def add_special_tokens_single_sentence(self, token_ids):
|
def add_special_tokens_single_sequence(self, token_ids):
|
||||||
logger.warning("This tokenizer does not make use of special tokens. The sequence has been returned with no modification.")
|
logger.warning("This tokenizer does not make use of special tokens. The sequence has been returned with no modification.")
|
||||||
return token_ids
|
return token_ids
|
||||||
|
|
||||||
def add_special_tokens_sentences_pair(self, token_ids_0, token_ids_1, output_mask=False):
|
def add_special_tokens_sequence_pair(self, token_ids_0, token_ids_1):
|
||||||
logger.warning("This tokenizer does not make use of special tokens. The two sequences have been concatenated.")
|
logger.warning("This tokenizer does not make use of special tokens. The two sequences have been concatenated.")
|
||||||
return token_ids_0 + token_ids_1
|
return token_ids_0 + token_ids_1
|
||||||
|
|
||||||
|
|||||||
@@ -754,27 +754,20 @@ class XLMTokenizer(PreTrainedTokenizer):
|
|||||||
out_string = ''.join(tokens).replace('</w>', ' ').strip()
|
out_string = ''.join(tokens).replace('</w>', ' ').strip()
|
||||||
return out_string
|
return out_string
|
||||||
|
|
||||||
def add_special_tokens_single_sentence(self, token_ids):
|
def add_special_tokens_single_sequence(self, token_ids):
|
||||||
"""
|
"""
|
||||||
Adds special tokens to a sequence for sequence classification tasks.
|
Adds special tokens to a sequence for sequence classification tasks.
|
||||||
An XLM sequence has the following format: [CLS] X [SEP]
|
An XLM sequence has the following format: [CLS] X [SEP]
|
||||||
"""
|
"""
|
||||||
return [self.cls_token_id] + token_ids + [self.sep_token_id]
|
return [self.cls_token_id] + token_ids + [self.sep_token_id]
|
||||||
|
|
||||||
def add_special_tokens_sentences_pair(self, token_ids_0, token_ids_1, output_mask=False):
|
def add_special_tokens_sequence_pair(self, token_ids_0, token_ids_1):
|
||||||
"""
|
"""
|
||||||
Adds special tokens to a sequence pair for sequence classification tasks.
|
Adds special tokens to a sequence pair for sequence classification tasks.
|
||||||
An XLM sequence pair has the following format: [CLS] A [SEP] B [SEP]
|
An XLM sequence pair has the following format: [CLS] A [SEP] B [SEP]
|
||||||
"""
|
"""
|
||||||
sep = [self.sep_token_id]
|
sep = [self.sep_token_id]
|
||||||
cls = [self.cls_token_id]
|
cls = [self.cls_token_id]
|
||||||
|
|
||||||
if output_mask:
|
|
||||||
return (
|
|
||||||
cls + token_ids_0 + sep + token_ids_1 + sep,
|
|
||||||
[0] * len(cls + token_ids_0 + sep) + [1] * len(token_ids_1 + sep)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return cls + token_ids_0 + sep + token_ids_1 + sep
|
return cls + token_ids_0 + sep + token_ids_1 + sep
|
||||||
|
|
||||||
def save_vocabulary(self, save_directory):
|
def save_vocabulary(self, save_directory):
|
||||||
|
|||||||
@@ -181,7 +181,7 @@ class XLNetTokenizer(PreTrainedTokenizer):
|
|||||||
out_string = ''.join(tokens).replace(SPIECE_UNDERLINE, ' ').strip()
|
out_string = ''.join(tokens).replace(SPIECE_UNDERLINE, ' ').strip()
|
||||||
return out_string
|
return out_string
|
||||||
|
|
||||||
def add_special_tokens_single_sentence(self, token_ids):
|
def add_special_tokens_single_sequence(self, token_ids):
|
||||||
"""
|
"""
|
||||||
Adds special tokens to a sequence for sequence classification tasks.
|
Adds special tokens to a sequence for sequence classification tasks.
|
||||||
An XLNet sequence has the following format: X [SEP][CLS]
|
An XLNet sequence has the following format: X [SEP][CLS]
|
||||||
@@ -190,7 +190,7 @@ class XLNetTokenizer(PreTrainedTokenizer):
|
|||||||
cls = [self.cls_token_id]
|
cls = [self.cls_token_id]
|
||||||
return token_ids + sep + cls
|
return token_ids + sep + cls
|
||||||
|
|
||||||
def add_special_tokens_sentences_pair(self, token_ids_0, token_ids_1, output_mask=False):
|
def add_special_tokens_sequence_pair(self, token_ids_0, token_ids_1):
|
||||||
"""
|
"""
|
||||||
Adds special tokens to a sequence pair for sequence classification tasks.
|
Adds special tokens to a sequence pair for sequence classification tasks.
|
||||||
An XLNet sequence pair has the following format: A [SEP] B [SEP][CLS]
|
An XLNet sequence pair has the following format: A [SEP] B [SEP][CLS]
|
||||||
@@ -199,12 +199,6 @@ class XLNetTokenizer(PreTrainedTokenizer):
|
|||||||
sep = [self.sep_token_id]
|
sep = [self.sep_token_id]
|
||||||
cls = [self.cls_token_id]
|
cls = [self.cls_token_id]
|
||||||
cls_segment_ids = [2]
|
cls_segment_ids = [2]
|
||||||
if output_mask:
|
|
||||||
return (
|
|
||||||
token_ids_0 + sep + token_ids_1 + sep + cls,
|
|
||||||
[0] * len(token_ids_0 + sep) + [1] * len(token_ids_1 + sep) + cls_segment_ids
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return token_ids_0 + sep + token_ids_1 + sep + cls
|
return token_ids_0 + sep + token_ids_1 + sep + cls
|
||||||
|
|
||||||
def save_vocabulary(self, save_directory):
|
def save_vocabulary(self, save_directory):
|
||||||
|
|||||||
Reference in New Issue
Block a user