update tokenizer - update squad example for xlnet
This commit is contained in:
@@ -528,9 +528,9 @@ class PoolerEndLogits(nn.Module):
|
||||
Mask of invalid position such as query and special symbols (PAD, SEP, CLS)
|
||||
1.0 means token should be masked.
|
||||
"""
|
||||
slen, hsz = hidden_states.shape[-2:]
|
||||
assert start_states is not None or start_positions is not None, "One of start_states, start_positions should be not None"
|
||||
if start_positions is not None:
|
||||
slen, hsz = hidden_states.shape[-2:]
|
||||
start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
|
||||
start_states = hidden_states.gather(-2, start_positions) # shape (bsz, 1, hsz)
|
||||
start_states = start_states.expand(-1, slen, -1) # shape (bsz, slen, hsz)
|
||||
@@ -571,7 +571,7 @@ class PoolerAnswerClass(nn.Module):
|
||||
no dependency on end_feature so that we can obtain one single `cls_logits`
|
||||
for each sample
|
||||
"""
|
||||
slen, hsz = hidden_states.shape[-2:]
|
||||
hsz = hidden_states.shape[-1]
|
||||
assert start_states is not None or start_positions is not None, "One of start_states, start_positions should be not None"
|
||||
if start_positions is not None:
|
||||
start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
|
||||
@@ -614,12 +614,21 @@ class SQuADHead(nn.Module):
|
||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||
**loss**: (`optional`, returned if both ``start_positions`` and ``end_positions`` are provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
||||
Classification loss as the sum of start token, end token (and is_impossible if provided) classification losses.
|
||||
**last_hidden_state**: `(`optional`, returned if ``start_positions`` or ``end_positions`` is not provided) `torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
|
||||
Sequence of hidden-states at the last layer of the model.
|
||||
**mems**:
|
||||
list of ``torch.FloatTensor`` (one for each layer):
|
||||
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
|
||||
(see `mems` input above). Can be used to speed up sequential decoding and attend to longer context.
|
||||
**start_top_log_probs**: `(`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
|
||||
``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top)``
|
||||
Log probabilities for the top config.start_n_top start token possibilities (beam-search).
|
||||
**start_top_index**: `(`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
|
||||
``torch.LongTensor`` of shape ``(batch_size, config.start_n_top)``
|
||||
Indices for the top config.start_n_top start token possibilities (beam-search).
|
||||
**end_top_log_probs**: `(`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
|
||||
``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``
|
||||
Log probabilities for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
|
||||
**end_top_index**: `(`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
|
||||
``torch.LongTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``
|
||||
Indices for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
|
||||
**cls_logits**: `(`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
|
||||
``torch.FloatTensor`` of shape ``(batch_size,)``
|
||||
Log probabilities for the ``is_impossible`` label of the answers.
|
||||
"""
|
||||
def __init__(self, config):
|
||||
super(SQuADHead, self).__init__()
|
||||
@@ -667,8 +676,8 @@ class SQuADHead(nn.Module):
|
||||
start_log_probs = F.softmax(start_logits, dim=-1) # shape (bsz, slen)
|
||||
|
||||
start_top_log_probs, start_top_index = torch.topk(start_log_probs, self.start_n_top, dim=-1) # shape (bsz, start_n_top)
|
||||
start_top_index = start_top_index.unsqueeze(-1).expand(-1, -1, hsz) # shape (bsz, start_n_top, hsz)
|
||||
start_states = torch.gather(hidden_states, -2, start_top_index) # shape (bsz, start_n_top, hsz)
|
||||
start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz) # shape (bsz, start_n_top, hsz)
|
||||
start_states = torch.gather(hidden_states, -2, start_top_index_exp) # shape (bsz, start_n_top, hsz)
|
||||
start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1) # shape (bsz, slen, start_n_top, hsz)
|
||||
|
||||
hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(start_states) # shape (bsz, slen, start_n_top, hsz)
|
||||
|
||||
@@ -1167,12 +1167,23 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
|
||||
1.0 means token should be masked. 0.0 mean token is not masked.
|
||||
|
||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||
**loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
||||
Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
|
||||
**start_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
|
||||
Span-start scores (before SoftMax).
|
||||
**end_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
|
||||
Span-end scores (before SoftMax).
|
||||
**loss**: (`optional`, returned if both ``start_positions`` and ``end_positions`` are provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
||||
Classification loss as the sum of start token, end token (and is_impossible if provided) classification losses.
|
||||
**start_top_log_probs**: `(`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
|
||||
``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top)``
|
||||
Log probabilities for the top config.start_n_top start token possibilities (beam-search).
|
||||
**start_top_index**: `(`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
|
||||
``torch.LongTensor`` of shape ``(batch_size, config.start_n_top)``
|
||||
Indices for the top config.start_n_top start token possibilities (beam-search).
|
||||
**end_top_log_probs**: `(`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
|
||||
``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``
|
||||
Log probabilities for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
|
||||
**end_top_index**: `(`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
|
||||
``torch.LongTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``
|
||||
Indices for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
|
||||
**cls_logits**: `(`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
|
||||
``torch.FloatTensor`` of shape ``(batch_size,)``
|
||||
Log probabilities for the ``is_impossible`` label of the answers.
|
||||
**mems**:
|
||||
list of ``torch.FloatTensor`` (one for each layer):
|
||||
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
|
||||
@@ -1243,12 +1254,10 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
|
||||
loss_fct_cls = nn.BCEWithLogitsLoss()
|
||||
cls_loss = loss_fct_cls(cls_logits, is_impossible)
|
||||
|
||||
# note(zhiliny): by default multiply the loss by 0.5 so that the scale is
|
||||
# comparable to start_loss and end_loss
|
||||
# note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss
|
||||
total_loss += cls_loss * 0.5
|
||||
outputs = (total_loss, start_logits, end_logits, cls_logits) + outputs
|
||||
else:
|
||||
outputs = (total_loss, start_logits, end_logits) + outputs
|
||||
|
||||
outputs = (total_loss,) + outputs
|
||||
|
||||
else:
|
||||
# during inference, compute the end logits based on beam search
|
||||
@@ -1256,8 +1265,8 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
|
||||
start_log_probs = F.softmax(start_logits, dim=-1) # shape (bsz, slen)
|
||||
|
||||
start_top_log_probs, start_top_index = torch.topk(start_log_probs, self.start_n_top, dim=-1) # shape (bsz, start_n_top)
|
||||
start_top_index = start_top_index.unsqueeze(-1).expand(-1, -1, hsz) # shape (bsz, start_n_top, hsz)
|
||||
start_states = torch.gather(hidden_states, -2, start_top_index) # shape (bsz, start_n_top, hsz)
|
||||
start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz) # shape (bsz, start_n_top, hsz)
|
||||
start_states = torch.gather(hidden_states, -2, start_top_index_exp) # shape (bsz, start_n_top, hsz)
|
||||
start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1) # shape (bsz, slen, start_n_top, hsz)
|
||||
|
||||
hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(start_states) # shape (bsz, slen, start_n_top, hsz)
|
||||
@@ -1269,11 +1278,11 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
|
||||
end_top_log_probs = end_top_log_probs.view(-1, self.start_n_top * self.end_n_top)
|
||||
end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top)
|
||||
|
||||
start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs)
|
||||
cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index)
|
||||
start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs) # get the representation of START as weighted sum of hidden states
|
||||
cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index) # Shape (batch size,): one single `cls_logits` for each sample
|
||||
|
||||
outputs = (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits) + outputs
|
||||
|
||||
# return start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits, mems, (hidden states), (attentions)
|
||||
# or (if labels are provided) total_loss, start_logits, end_logits, (cls_logits), mems, (hidden states), (attentions)
|
||||
# return start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits
|
||||
# or (if labels are provided) (total_loss,)
|
||||
return outputs
|
||||
|
||||
@@ -38,7 +38,10 @@ class TokenizationTest(unittest.TestCase):
|
||||
with open(vocab_file, "w", encoding='utf-8') as vocab_writer:
|
||||
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
|
||||
|
||||
create_and_check_tokenizer_commons(self, BertTokenizer, tmpdirname)
|
||||
input_text = u"UNwant\u00E9d,running"
|
||||
output_text = u"unwanted, running"
|
||||
|
||||
create_and_check_tokenizer_commons(self, input_text, output_text, BertTokenizer, tmpdirname)
|
||||
|
||||
tokenizer = BertTokenizer(vocab_file)
|
||||
|
||||
|
||||
@@ -41,7 +41,10 @@ class GPT2TokenizationTest(unittest.TestCase):
|
||||
with open(merges_file, "w") as fp:
|
||||
fp.write("\n".join(merges))
|
||||
|
||||
create_and_check_tokenizer_commons(self, GPT2Tokenizer, tmpdirname, **special_tokens_map)
|
||||
input_text = u"lower newer"
|
||||
output_text = u"lower<unk>newer"
|
||||
|
||||
create_and_check_tokenizer_commons(self, input_text, output_text, GPT2Tokenizer, tmpdirname, **special_tokens_map)
|
||||
|
||||
tokenizer = GPT2Tokenizer(vocab_file, merges_file, **special_tokens_map)
|
||||
text = "lower"
|
||||
|
||||
@@ -42,7 +42,10 @@ class OpenAIGPTTokenizationTest(unittest.TestCase):
|
||||
with open(merges_file, "w") as fp:
|
||||
fp.write("\n".join(merges))
|
||||
|
||||
create_and_check_tokenizer_commons(self, OpenAIGPTTokenizer, tmpdirname)
|
||||
input_text = u"lower newer"
|
||||
output_text = u"lower newer"
|
||||
|
||||
create_and_check_tokenizer_commons(self, input_text, output_text, OpenAIGPTTokenizer, tmpdirname)
|
||||
|
||||
tokenizer = OpenAIGPTTokenizer(vocab_file, merges_file)
|
||||
|
||||
|
||||
@@ -113,23 +113,24 @@ def create_and_check_add_tokens_tokenizer(tester, tokenizer_class, *inputs, **kw
|
||||
tester.assertEqual(tokens[-2], tokenizer.convert_tokens_to_ids(tokenizer.pad_token))
|
||||
|
||||
|
||||
def create_and_check_required_methods_tokenizer(tester, tokenizer_class, *inputs, **kwargs):
|
||||
def create_and_check_required_methods_tokenizer(tester, input_text, output_text, tokenizer_class, *inputs, **kwargs):
|
||||
tokenizer = tokenizer_class.from_pretrained(*inputs, **kwargs)
|
||||
|
||||
text = u"He is very happy, UNwant\u00E9d,running"
|
||||
tokens = tokenizer.tokenize(text)
|
||||
tokens = tokenizer.tokenize(input_text)
|
||||
ids = tokenizer.convert_tokens_to_ids(tokens)
|
||||
ids_2 = tokenizer.encode(text)
|
||||
ids_2 = tokenizer.encode(input_text)
|
||||
tester.assertListEqual(ids, ids_2)
|
||||
|
||||
tokens_2 = tokenizer.convert_ids_to_tokens(ids)
|
||||
text_2 = tokenizer.decode(ids)
|
||||
|
||||
tester.assertEqual(text_2, output_text)
|
||||
|
||||
tester.assertNotEqual(len(tokens_2), 0)
|
||||
tester.assertIsInstance(text_2, (str, unicode))
|
||||
|
||||
def create_and_check_tokenizer_commons(tester, tokenizer_class, *inputs, **kwargs):
|
||||
create_and_check_required_methods_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
|
||||
def create_and_check_tokenizer_commons(tester, input_text, output_text, tokenizer_class, *inputs, **kwargs):
|
||||
create_and_check_required_methods_tokenizer(tester, input_text, output_text, tokenizer_class, *inputs, **kwargs)
|
||||
create_and_check_add_tokens_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
|
||||
create_and_check_save_and_load_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
|
||||
create_and_check_pickle_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
|
||||
|
||||
@@ -34,7 +34,10 @@ class TransfoXLTokenizationTest(unittest.TestCase):
|
||||
with open(vocab_file, "w", encoding='utf-8') as vocab_writer:
|
||||
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
|
||||
|
||||
create_and_check_tokenizer_commons(self, TransfoXLTokenizer, tmpdirname, lower_case=True)
|
||||
input_text = u"<unk> UNwanted , running"
|
||||
output_text = u"<unk> unwanted, running"
|
||||
|
||||
create_and_check_tokenizer_commons(self, input_text, output_text, TransfoXLTokenizer, tmpdirname, lower_case=True)
|
||||
|
||||
tokenizer = TransfoXLTokenizer(vocab_file=vocab_file, lower_case=True)
|
||||
|
||||
|
||||
@@ -41,7 +41,10 @@ class XLMTokenizationTest(unittest.TestCase):
|
||||
with open(merges_file, "w") as fp:
|
||||
fp.write("\n".join(merges))
|
||||
|
||||
create_and_check_tokenizer_commons(self, XLMTokenizer, tmpdirname)
|
||||
input_text = u"lower newer"
|
||||
output_text = u"lower newer"
|
||||
|
||||
create_and_check_tokenizer_commons(self, input_text, output_text, XLMTokenizer, tmpdirname)
|
||||
|
||||
tokenizer = XLMTokenizer(vocab_file, merges_file)
|
||||
|
||||
|
||||
@@ -32,7 +32,10 @@ class XLNetTokenizationTest(unittest.TestCase):
|
||||
with TemporaryDirectory() as tmpdirname:
|
||||
tokenizer.save_pretrained(tmpdirname)
|
||||
|
||||
create_and_check_tokenizer_commons(self, XLNetTokenizer, tmpdirname)
|
||||
input_text = u"This is a test"
|
||||
output_text = u"This is a test"
|
||||
|
||||
create_and_check_tokenizer_commons(self, input_text, output_text, XLNetTokenizer, tmpdirname)
|
||||
|
||||
tokens = tokenizer.tokenize(u'This is a test')
|
||||
self.assertListEqual(tokens, [u'▁This', u'▁is', u'▁a', u'▁t', u'est'])
|
||||
|
||||
@@ -161,10 +161,9 @@ class BertTokenizer(PreTrainedTokenizer):
|
||||
"""Converts an index (integer) in a token (string/unicode) using the vocab."""
|
||||
return self.ids_to_tokens.get(index, self.unk_token)
|
||||
|
||||
def _convert_ids_to_string(self, tokens_ids):
|
||||
"""Converts a sequence of ids in a string."""
|
||||
tokens = self.convert_ids_to_tokens(tokens_ids)
|
||||
out_string = ''.join(tokens).replace(' ##', '').strip()
|
||||
def convert_tokens_to_string(self, tokens):
|
||||
""" Converts a sequence of tokens (string) in a single string. """
|
||||
out_string = ' '.join(tokens).replace(' ##', '').strip()
|
||||
return out_string
|
||||
|
||||
def save_vocabulary(self, vocab_path):
|
||||
|
||||
@@ -185,9 +185,9 @@ class GPT2Tokenizer(PreTrainedTokenizer):
|
||||
"""Converts an index (integer) in a token (string/unicode) using the vocab."""
|
||||
return self.decoder.get(index)
|
||||
|
||||
def _convert_ids_to_string(self, tokens_ids):
|
||||
"""Converts a sequence of ids in a string."""
|
||||
text = ''.join(tokens_ids)
|
||||
def convert_tokens_to_string(self, tokens):
|
||||
""" Converts a sequence of tokens (string) in a single string. """
|
||||
text = ''.join(tokens)
|
||||
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
|
||||
return text
|
||||
|
||||
|
||||
@@ -174,9 +174,9 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer):
|
||||
"""Converts an id in a token (BPE) using the vocab."""
|
||||
return self.decoder.get(index, self.unk_token)
|
||||
|
||||
def _convert_ids_to_string(self, tokens_ids):
|
||||
"""Converts a sequence of ids in a string."""
|
||||
out_string = ''.join(tokens_ids).replace('</w>', ' ').strip()
|
||||
def convert_tokens_to_string(self, tokens):
|
||||
""" Converts a sequence of tokens (string) in a single string. """
|
||||
out_string = ''.join(tokens).replace('</w>', ' ').strip()
|
||||
return out_string
|
||||
|
||||
def save_vocabulary(self, save_directory):
|
||||
|
||||
@@ -229,9 +229,9 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
|
||||
else:
|
||||
raise ValueError('Token not in vocabulary and no <unk> token in vocabulary for replacement')
|
||||
|
||||
def _convert_ids_to_string(self, tokens_ids):
|
||||
"""Converts a sequence of ids in a string."""
|
||||
out_string = ' '.join(tokens_ids).strip()
|
||||
def convert_tokens_to_string(self, tokens):
|
||||
""" Converts a sequence of tokens (string) in a single string. """
|
||||
out_string = ' '.join(tokens).strip()
|
||||
return out_string
|
||||
|
||||
def convert_to_tensor(self, symbols):
|
||||
|
||||
@@ -361,52 +361,26 @@ class PreTrainedTokenizer(object):
|
||||
(resp.) a sequence of ids, using the vocabulary.
|
||||
"""
|
||||
if isinstance(tokens, str) or (six.PY2 and isinstance(tokens, unicode)):
|
||||
return self.convert_token_to_id_with_added_voc(tokens)
|
||||
return self._convert_token_to_id_with_added_voc(tokens)
|
||||
|
||||
ids = []
|
||||
for token in tokens:
|
||||
ids.append(self.convert_token_to_id_with_added_voc(token))
|
||||
ids.append(self._convert_token_to_id_with_added_voc(token))
|
||||
if len(ids) > self.max_len:
|
||||
logger.warning("Token indices sequence length is longer than the specified maximum sequence length "
|
||||
"for this model ({} > {}). Running this sequence through the model will result in "
|
||||
"indexing errors".format(len(ids), self.max_len))
|
||||
return ids
|
||||
|
||||
|
||||
def convert_token_to_id_with_added_voc(self, token):
|
||||
def _convert_token_to_id_with_added_voc(self, token):
|
||||
if token in self.added_tokens_encoder:
|
||||
return self.added_tokens_encoder[token]
|
||||
return self._convert_token_to_id(token)
|
||||
|
||||
|
||||
def _convert_token_to_id(self, token):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
|
||||
""" Converts a single index or a sequence of indices (integers) in a token "
|
||||
(resp.) a sequence of tokens (str/unicode), using the vocabulary and added tokens.
|
||||
|
||||
Args:
|
||||
skip_special_tokens: Don't decode special tokens (self.all_special_tokens). Default: False
|
||||
"""
|
||||
if isinstance(ids, int):
|
||||
return self.convert_id_to_token(ids)
|
||||
tokens = []
|
||||
for index in ids:
|
||||
if index in self.all_special_ids and skip_special_tokens:
|
||||
continue
|
||||
if index in self.added_tokens_decoder:
|
||||
tokens.append(self.added_tokens_decoder[index])
|
||||
else:
|
||||
tokens.append(self._convert_id_to_token(index))
|
||||
return tokens
|
||||
|
||||
|
||||
def _convert_id_to_token(self, index):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def encode(self, text):
|
||||
""" Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary.
|
||||
same as self.convert_tokens_to_ids(self.tokenize(text)).
|
||||
@@ -414,22 +388,48 @@ class PreTrainedTokenizer(object):
|
||||
return self.convert_tokens_to_ids(self.tokenize(text))
|
||||
|
||||
|
||||
def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
|
||||
""" Converts a single index or a sequence of indices (integers) in a token "
|
||||
(resp.) a sequence of tokens (str/unicode), using the vocabulary and added tokens.
|
||||
|
||||
Args:
|
||||
skip_special_tokens: Don't decode special tokens (self.all_special_tokens). Default: False
|
||||
"""
|
||||
if isinstance(ids, int):
|
||||
if ids in self.added_tokens_decoder:
|
||||
return self.added_tokens_decoder[ids]
|
||||
else:
|
||||
return self._convert_id_to_token(ids)
|
||||
tokens = []
|
||||
for index in ids:
|
||||
if index in self.all_special_ids and skip_special_tokens:
|
||||
continue
|
||||
if index in self.added_tokens_decoder:
|
||||
tokens.append(self.added_tokens_decoder[index])
|
||||
else:
|
||||
tokens.append(self._convert_id_to_token(index))
|
||||
return tokens
|
||||
|
||||
def _convert_id_to_token(self, index):
|
||||
raise NotImplementedError
|
||||
|
||||
def convert_tokens_to_string(self, tokens):
|
||||
""" Converts a sequence of tokens (string) in a single string.
|
||||
The most simple way to do it is ' '.join(self.convert_ids_to_tokens(token_ids))
|
||||
but we often want to remove sub-word tokenization artifacts at the same time.
|
||||
"""
|
||||
return ' '.join(self.convert_ids_to_tokens(tokens))
|
||||
|
||||
def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
|
||||
""" Converts a sequence of ids (integer) in a string, using the tokenizer and vocabulary
|
||||
with options to remove special tokens and clean up tokenization spaces.
|
||||
"""
|
||||
filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
|
||||
text = self._convert_ids_to_string(filtered_tokens)
|
||||
text = self.convert_tokens_to_string(filtered_tokens)
|
||||
if clean_up_tokenization_spaces:
|
||||
text = clean_up_tokenization(text)
|
||||
return text
|
||||
|
||||
def _convert_ids_to_string(self, tokens_ids):
|
||||
""" Converts a sequence of ids (integer) in a string, using the tokenizer and vocabulary.
|
||||
roughtly same as ' '.join(self.convert_ids_to_tokens(token_ids)).
|
||||
"""
|
||||
return ' '.join(self.convert_ids_to_tokens(tokens_ids))
|
||||
|
||||
@property
|
||||
def special_tokens_map(self):
|
||||
""" A dictionary mapping special token class attribute (cls_token, unk_token...) to their
|
||||
|
||||
@@ -202,9 +202,9 @@ class XLMTokenizer(PreTrainedTokenizer):
|
||||
"""Converts an index (integer) in a token (string/unicode) using the vocab."""
|
||||
return self.decoder.get(index, self.unk_token)
|
||||
|
||||
def _convert_ids_to_string(self, tokens_ids):
|
||||
"""Converts a sequence of ids in a string."""
|
||||
out_string = ''.join(tokens_ids).replace('</w>', ' ').strip()
|
||||
def convert_tokens_to_string(self, tokens):
|
||||
""" Converts a sequence of tokens (string) in a single string. """
|
||||
out_string = ''.join(tokens).replace('</w>', ' ').strip()
|
||||
return out_string
|
||||
|
||||
def save_vocabulary(self, save_directory):
|
||||
|
||||
@@ -170,9 +170,9 @@ class XLNetTokenizer(PreTrainedTokenizer):
|
||||
token = token.decode('utf-8')
|
||||
return token
|
||||
|
||||
def _convert_ids_to_string(self, tokens_ids):
|
||||
"""Converts a sequence of ids in a string."""
|
||||
out_string = ''.join(tokens_ids).replace(SPIECE_UNDERLINE, ' ')
|
||||
def convert_tokens_to_string(self, tokens):
|
||||
"""Converts a sequence of tokens (strings for sub-words) in a single string."""
|
||||
out_string = ''.join(tokens).replace(SPIECE_UNDERLINE, ' ').strip()
|
||||
return out_string
|
||||
|
||||
def save_vocabulary(self, save_directory):
|
||||
@@ -184,6 +184,7 @@ class XLNetTokenizer(PreTrainedTokenizer):
|
||||
return
|
||||
out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file'])
|
||||
|
||||
copyfile(self.vocab_file, out_vocab_file)
|
||||
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
|
||||
copyfile(self.vocab_file, out_vocab_file)
|
||||
|
||||
return (out_vocab_file,)
|
||||
|
||||
Reference in New Issue
Block a user