Add preprocessing step for transfo-xl tokenization to avoid tokenizing words followed by punction to <unk> (#2987)
* add preprocessing to add space before punctuation for transfo_xl * improve warning messages * make style * compile regex at instantination of tokenizer object
This commit is contained in:
committed by
GitHub
parent
a143d9479e
commit
65d74c4965
@@ -59,7 +59,7 @@ MODEL_CLASSES = {
|
|||||||
# Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
|
# Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
|
||||||
# in https://github.com/rusiaaman/XLNet-gen#methodology
|
# in https://github.com/rusiaaman/XLNet-gen#methodology
|
||||||
# and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e
|
# and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e
|
||||||
PADDING_TEXT = """ In 1991, the remains of Russian Tsar Nicholas II and his family
|
PADDING_TEXT = """In 1991, the remains of Russian Tsar Nicholas II and his family
|
||||||
(except for Alexei and Maria) are discovered.
|
(except for Alexei and Maria) are discovered.
|
||||||
The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the
|
The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the
|
||||||
remainder of the story. 1883 Western Siberia,
|
remainder of the story. 1883 Western Siberia,
|
||||||
@@ -214,7 +214,9 @@ def main():
|
|||||||
if requires_preprocessing:
|
if requires_preprocessing:
|
||||||
prepare_input = PREPROCESSING_FUNCTIONS.get(args.model_type)
|
prepare_input = PREPROCESSING_FUNCTIONS.get(args.model_type)
|
||||||
preprocessed_prompt_text = prepare_input(args, model, tokenizer, prompt_text)
|
preprocessed_prompt_text = prepare_input(args, model, tokenizer, prompt_text)
|
||||||
encoded_prompt = tokenizer.encode(preprocessed_prompt_text, add_special_tokens=False, return_tensors="pt")
|
encoded_prompt = tokenizer.encode(
|
||||||
|
preprocessed_prompt_text, add_special_tokens=False, return_tensors="pt", add_space_before_punct_symbol=True
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors="pt")
|
encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors="pt")
|
||||||
encoded_prompt = encoded_prompt.to(args.device)
|
encoded_prompt = encoded_prompt.to(args.device)
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ import glob
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
|
import re
|
||||||
from collections import Counter, OrderedDict
|
from collections import Counter, OrderedDict
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
@@ -114,6 +115,9 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
|
|||||||
self.delimiter = delimiter
|
self.delimiter = delimiter
|
||||||
self.vocab_file = vocab_file
|
self.vocab_file = vocab_file
|
||||||
self.never_split = never_split
|
self.never_split = never_split
|
||||||
|
self.punctuation_symbols = '!"#$%&()*+,-./\:;<=>?@[\\]^_`{|}~' # noqa: W605
|
||||||
|
self.punction_without_space_before_pattern = re.compile(r"[^\s][{}]".format(self.punctuation_symbols))
|
||||||
|
self.punctuation_with_space_around_pattern = self._compile_space_around_punctuation_pattern()
|
||||||
|
|
||||||
if pretrained_vocab_file is not None:
|
if pretrained_vocab_file is not None:
|
||||||
# Hack because, honestly this tokenizer was not made to be used
|
# Hack because, honestly this tokenizer was not made to be used
|
||||||
@@ -126,6 +130,11 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
|
|||||||
if vocab_file is not None:
|
if vocab_file is not None:
|
||||||
self.build_vocab()
|
self.build_vocab()
|
||||||
|
|
||||||
|
def _compile_space_around_punctuation_pattern(self):
|
||||||
|
look_ahead_for_special_token = "(?=[{}])".format(self.punctuation_symbols)
|
||||||
|
look_ahead_to_match_all_except_space = "(?=[^\s])" # noqa: W605
|
||||||
|
return re.compile(r"" + look_ahead_for_special_token + look_ahead_to_match_all_except_space)
|
||||||
|
|
||||||
def count_file(self, path, verbose=False, add_eos=False):
|
def count_file(self, path, verbose=False, add_eos=False):
|
||||||
if verbose:
|
if verbose:
|
||||||
logger.info("counting file {} ...".format(path))
|
logger.info("counting file {} ...".format(path))
|
||||||
@@ -295,6 +304,19 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
|
|||||||
else:
|
else:
|
||||||
return symbols
|
return symbols
|
||||||
|
|
||||||
|
def prepare_for_tokenization(self, text, **kwargs):
|
||||||
|
# add spaces before punctuation symbols as should be done in transfo-xl
|
||||||
|
|
||||||
|
if "add_space_before_punct_symbol" in kwargs and kwargs["add_space_before_punct_symbol"]:
|
||||||
|
text = self.punctuation_with_space_around_pattern.sub(r" ", text)
|
||||||
|
elif self.punction_without_space_before_pattern.search(text):
|
||||||
|
# searches until the first occurence of a punctuation symbol without surrounding spaces
|
||||||
|
logger.warning(
|
||||||
|
"You might want to consider setting `add_space_before_punct_symbol=True` as an argument to the `tokenizer.encode()` to avoid tokenizing words with punctuation symbols to the `<unk>` token"
|
||||||
|
)
|
||||||
|
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
class _TransfoXLDelimiterLookupTokenizer(BaseTokenizer):
|
class _TransfoXLDelimiterLookupTokenizer(BaseTokenizer):
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|||||||
Reference in New Issue
Block a user