FastPreTrainedTokenizer
This commit is contained in:
@@ -1410,3 +1410,130 @@ class PreTrainedTokenizer(object):
|
||||
.replace(" 're", "'re")
|
||||
)
|
||||
return out_string
|
||||
|
||||
class FastPreTrainedTokenizer(PreTrainedTokenizer):
|
||||
def __init__(self, **kwargs):
|
||||
super(FastPreTrainedTokenizer, self).__init__(**kwargs)
|
||||
|
||||
@property
|
||||
def tokenizer(self):
|
||||
if self._tokenizer is None:
|
||||
raise NotImplementedError
|
||||
return self._tokenizer
|
||||
|
||||
@property
|
||||
def decoder(self):
|
||||
if self._decoder is None:
|
||||
raise NotImplementedError
|
||||
return self._decoder
|
||||
|
||||
@property
|
||||
def vocab_size(self):
|
||||
return self.tokenizer.get_vocab_size(False)
|
||||
|
||||
def __len__(self):
|
||||
return self.tokenizer.get_vocab_size(True)
|
||||
|
||||
def _update_special_tokens(self):
|
||||
self.tokenizer.add_special_tokens(self.all_special_tokens)
|
||||
|
||||
@staticmethod
|
||||
def _convert_encoding(encoding,
|
||||
return_tensors=None,
|
||||
return_token_type_ids=True,
|
||||
return_attention_mask=True,
|
||||
return_overflowing_tokens=False,
|
||||
return_special_tokens_mask=False):
|
||||
encoding_dict = {
|
||||
"input_ids": encoding.ids,
|
||||
}
|
||||
if return_token_type_ids:
|
||||
encoding_dict["token_type_ids"] = encoding.type_ids
|
||||
if return_attention_mask:
|
||||
encoding_dict["attention_mask"] = encoding.attention_mask
|
||||
if return_overflowing_tokens:
|
||||
overflowing = encoding.overflowing
|
||||
encoding_dict["overflowing_tokens"] = overflowing.ids if overflowing is not None else []
|
||||
if return_special_tokens_mask:
|
||||
encoding_dict["special_tokens_mask"] = encoding.special_tokens_mask
|
||||
|
||||
# Prepare inputs as tensors if asked
|
||||
if return_tensors == 'tf' and is_tf_available():
|
||||
encoding_dict["input_ids"] = tf.constant([encoding_dict["input_ids"]])
|
||||
encoding_dict["token_type_ids"] = tf.constant([encoding_dict["token_type_ids"]])
|
||||
|
||||
if "attention_mask" in encoding_dict:
|
||||
encoding_dict["attention_mask"] = tf.constant([encoding_dict["attention_mask"]])
|
||||
|
||||
elif return_tensors == 'pt' and is_torch_available():
|
||||
encoding_dict["input_ids"] = torch.tensor([encoding_dict["input_ids"]])
|
||||
encoding_dict["token_type_ids"] = torch.tensor([encoding_dict["token_type_ids"]])
|
||||
|
||||
if "attention_mask" in encoding_dict:
|
||||
encoding_dict["attention_mask"] = torch.tensor([encoding_dict["attention_mask"]])
|
||||
elif return_tensors is not None:
|
||||
logger.warning(
|
||||
"Unable to convert output to tensors format {}, PyTorch or TensorFlow is not available.".format(
|
||||
return_tensors))
|
||||
|
||||
return encoding_dict
|
||||
|
||||
def encode_plus(self,
|
||||
text,
|
||||
text_pair=None,
|
||||
return_tensors=None,
|
||||
return_token_type_ids=True,
|
||||
return_attention_mask=True,
|
||||
return_overflowing_tokens=False,
|
||||
return_special_tokens_mask=False,
|
||||
**kwargs):
|
||||
encoding = self.tokenizer.encode(text, text_pair)
|
||||
return self._convert_encoding(encoding,
|
||||
return_tensors=return_tensors,
|
||||
return_token_type_ids=return_token_type_ids,
|
||||
return_attention_mask=return_attention_mask,
|
||||
return_overflowing_tokens=return_overflowing_tokens,
|
||||
return_special_tokens_mask=return_special_tokens_mask)
|
||||
|
||||
def tokenize(self, text):
|
||||
return self.tokenizer.encode(text).tokens
|
||||
|
||||
def _convert_token_to_id_with_added_voc(self, token):
|
||||
return self.tokenizer.token_to_id(token)
|
||||
|
||||
def _convert_id_to_token(self, index):
|
||||
return self.tokenizer.id_to_token(int(index))
|
||||
|
||||
def convert_tokens_to_string(self, tokens):
|
||||
return self.decoder.decode(tokens)
|
||||
|
||||
def add_tokens(self, new_tokens):
|
||||
self.tokenizer.add_tokens(new_tokens)
|
||||
|
||||
def encode_batch(self, texts,
|
||||
return_tensors=None,
|
||||
return_token_type_ids=True,
|
||||
return_attention_mask=True,
|
||||
return_overflowing_tokens=False,
|
||||
return_special_tokens_mask=False):
|
||||
return [self._convert_encoding(encoding,
|
||||
return_tensors=return_tensors,
|
||||
return_token_type_ids=return_token_type_ids,
|
||||
return_attention_mask=return_attention_mask,
|
||||
return_overflowing_tokens=return_overflowing_tokens,
|
||||
return_special_tokens_mask=return_special_tokens_mask)
|
||||
for encoding in self.tokenizer.encode_batch(texts)]
|
||||
|
||||
def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
|
||||
text = self.tokenizer.decode(token_ids, skip_special_tokens)
|
||||
|
||||
if clean_up_tokenization_spaces:
|
||||
clean_text = self.clean_up_tokenization(text)
|
||||
return clean_text
|
||||
else:
|
||||
return text
|
||||
|
||||
def decode_batch(self, ids_batch, skip_special_tokens=False, clear_up_tokenization_spaces=True):
|
||||
return [self.clean_up_tokenization(text)
|
||||
if clear_up_tokenization_spaces else text
|
||||
for text in self.tokenizer.decode_batch(ids_batch, skip_special_tokens)]
|
||||
Reference in New Issue
Block a user