Add preprocessing step for transfo-xl tokenization to avoid tokenizing words followed by punction to <unk> (#2987)

* add preprocessing to add space before punctuation for transfo_xl

* improve warning messages

* make style

* compile regex at instantination of tokenizer object
This commit is contained in:
Patrick von Platen
2020-02-24 21:11:10 +01:00
committed by GitHub
parent a143d9479e
commit 65d74c4965
2 changed files with 26 additions and 2 deletions

View File

@@ -214,7 +214,9 @@ def main():
if requires_preprocessing:
prepare_input = PREPROCESSING_FUNCTIONS.get(args.model_type)
preprocessed_prompt_text = prepare_input(args, model, tokenizer, prompt_text)
encoded_prompt = tokenizer.encode(preprocessed_prompt_text, add_special_tokens=False, return_tensors="pt")
encoded_prompt = tokenizer.encode(
preprocessed_prompt_text, add_special_tokens=False, return_tensors="pt", add_space_before_punct_symbol=True
)
else:
encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors="pt")
encoded_prompt = encoded_prompt.to(args.device)

View File

@@ -22,6 +22,7 @@ import glob
import logging
import os
import pickle
import re
from collections import Counter, OrderedDict
from typing import List, Optional, Tuple, Union
@@ -114,6 +115,9 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
self.delimiter = delimiter
self.vocab_file = vocab_file
self.never_split = never_split
self.punctuation_symbols = '!"#$%&()*+,-./\:;<=>?@[\\]^_`{|}~' # noqa: W605
self.punction_without_space_before_pattern = re.compile(r"[^\s][{}]".format(self.punctuation_symbols))
self.punctuation_with_space_around_pattern = self._compile_space_around_punctuation_pattern()
if pretrained_vocab_file is not None:
# Hack because, honestly this tokenizer was not made to be used
@@ -126,6 +130,11 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
if vocab_file is not None:
self.build_vocab()
def _compile_space_around_punctuation_pattern(self):
look_ahead_for_special_token = "(?=[{}])".format(self.punctuation_symbols)
look_ahead_to_match_all_except_space = "(?=[^\s])" # noqa: W605
return re.compile(r"" + look_ahead_for_special_token + look_ahead_to_match_all_except_space)
def count_file(self, path, verbose=False, add_eos=False):
if verbose:
logger.info("counting file {} ...".format(path))
@@ -295,6 +304,19 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
else:
return symbols
def prepare_for_tokenization(self, text, **kwargs):
# add spaces before punctuation symbols as should be done in transfo-xl
if "add_space_before_punct_symbol" in kwargs and kwargs["add_space_before_punct_symbol"]:
text = self.punctuation_with_space_around_pattern.sub(r" ", text)
elif self.punction_without_space_before_pattern.search(text):
# searches until the first occurence of a punctuation symbol without surrounding spaces
logger.warning(
"You might want to consider setting `add_space_before_punct_symbol=True` as an argument to the `tokenizer.encode()` to avoid tokenizing words with punctuation symbols to the `<unk>` token"
)
return text
class _TransfoXLDelimiterLookupTokenizer(BaseTokenizer):
def __init__(