Add preprocessing step for transfo-xl tokenization to avoid tokenizing words followed by punctuation to <unk> (#2987)

* add preprocessing to add space before punctuation for transfo_xl

* improve warning messages

* make style

* compile regex at instantiation of tokenizer object
This commit is contained in:
Patrick von Platen
2020-02-24 21:11:10 +01:00
committed by GitHub
parent a143d9479e
commit 65d74c4965
2 changed files with 26 additions and 2 deletions

View File

@@ -214,7 +214,9 @@ def main():
if requires_preprocessing: if requires_preprocessing:
prepare_input = PREPROCESSING_FUNCTIONS.get(args.model_type) prepare_input = PREPROCESSING_FUNCTIONS.get(args.model_type)
preprocessed_prompt_text = prepare_input(args, model, tokenizer, prompt_text) preprocessed_prompt_text = prepare_input(args, model, tokenizer, prompt_text)
encoded_prompt = tokenizer.encode(preprocessed_prompt_text, add_special_tokens=False, return_tensors="pt") encoded_prompt = tokenizer.encode(
preprocessed_prompt_text, add_special_tokens=False, return_tensors="pt", add_space_before_punct_symbol=True
)
else: else:
encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors="pt") encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors="pt")
encoded_prompt = encoded_prompt.to(args.device) encoded_prompt = encoded_prompt.to(args.device)

View File

@@ -22,6 +22,7 @@ import glob
import logging import logging
import os import os
import pickle import pickle
import re
from collections import Counter, OrderedDict from collections import Counter, OrderedDict
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
@@ -114,6 +115,9 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
self.delimiter = delimiter self.delimiter = delimiter
self.vocab_file = vocab_file self.vocab_file = vocab_file
self.never_split = never_split self.never_split = never_split
self.punctuation_symbols = '!"#$%&()*+,-./\:;<=>?@[\\]^_`{|}~' # noqa: W605
self.punction_without_space_before_pattern = re.compile(r"[^\s][{}]".format(self.punctuation_symbols))
self.punctuation_with_space_around_pattern = self._compile_space_around_punctuation_pattern()
if pretrained_vocab_file is not None: if pretrained_vocab_file is not None:
# Hack because, honestly this tokenizer was not made to be used # Hack because, honestly this tokenizer was not made to be used
@@ -126,6 +130,11 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
if vocab_file is not None: if vocab_file is not None:
self.build_vocab() self.build_vocab()
def _compile_space_around_punctuation_pattern(self):
    """Build the regex used to insert a space in front of punctuation symbols.

    The pattern consists only of zero-width look-aheads, so it consumes no
    characters: it matches the empty position immediately before any
    character belonging to the character class built from
    ``self.punctuation_symbols``. Substituting ``" "`` for every match
    therefore places a space in front of each such character.

    Returns:
        The compiled regular expression.
    """
    # `punctuation_symbols` is pasted verbatim between `[` and `]`, so any
    # backslash it contains acts as a regex escape inside the character class.
    look_ahead_for_special_token = "(?=[{}])".format(self.punctuation_symbols)
    # Raw string: in a normal literal "\s" is an invalid escape sequence.
    # This second look-ahead only matters if `punctuation_symbols` ever
    # contains a whitespace character.
    look_ahead_to_match_all_except_space = r"(?=[^\s])"
    return re.compile(look_ahead_for_special_token + look_ahead_to_match_all_except_space)
def count_file(self, path, verbose=False, add_eos=False): def count_file(self, path, verbose=False, add_eos=False):
if verbose: if verbose:
logger.info("counting file {} ...".format(path)) logger.info("counting file {} ...".format(path))
@@ -295,6 +304,19 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
else: else:
return symbols return symbols
def prepare_for_tokenization(self, text, **kwargs):
    """Optionally insert a space before every punctuation symbol in ``text``.

    Args:
        text: The string about to be tokenized.
        **kwargs: Only ``add_space_before_punct_symbol`` is read here. When it
            is truthy, a space is inserted in front of each punctuation
            symbol, as should be done for transfo-xl.

    Returns:
        The text with spaces inserted when the option is enabled, otherwise
        ``text`` unchanged.
    """
    if kwargs.get("add_space_before_punct_symbol"):
        return self.punctuation_with_space_around_pattern.sub(r" ", text)

    # The option is off: only warn. `search` stops at the first punctuation
    # symbol that is directly preceded by a non-whitespace character.
    if self.punction_without_space_before_pattern.search(text):
        logger.warning(
            "You might want to consider setting `add_space_before_punct_symbol=True` as an argument to the `tokenizer.encode()` to avoid tokenizing words with punctuation symbols to the `<unk>` token"
        )
    return text
class _TransfoXLDelimiterLookupTokenizer(BaseTokenizer): class _TransfoXLDelimiterLookupTokenizer(BaseTokenizer):
def __init__( def __init__(