Add preprocessing step for transfo-xl tokenization to avoid tokenizing words followed by punction to <unk> (#2987)

* add preprocessing to add space before punctuation for transfo_xl

* improve warning messages

* make style

* compile regex at instantination of tokenizer object
This commit is contained in:
Patrick von Platen
2020-02-24 21:11:10 +01:00
committed by GitHub
parent a143d9479e
commit 65d74c4965
2 changed files with 26 additions and 2 deletions

View File

@@ -59,7 +59,7 @@ MODEL_CLASSES = {
# Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
# in https://github.com/rusiaaman/XLNet-gen#methodology
# and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e
PADDING_TEXT = """ In 1991, the remains of Russian Tsar Nicholas II and his family
PADDING_TEXT = """In 1991, the remains of Russian Tsar Nicholas II and his family
(except for Alexei and Maria) are discovered.
The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the
remainder of the story. 1883 Western Siberia,
@@ -214,7 +214,9 @@ def main():
if requires_preprocessing:
prepare_input = PREPROCESSING_FUNCTIONS.get(args.model_type)
preprocessed_prompt_text = prepare_input(args, model, tokenizer, prompt_text)
encoded_prompt = tokenizer.encode(preprocessed_prompt_text, add_special_tokens=False, return_tensors="pt")
encoded_prompt = tokenizer.encode(
preprocessed_prompt_text, add_special_tokens=False, return_tensors="pt", add_space_before_punct_symbol=True
)
else:
encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors="pt")
encoded_prompt = encoded_prompt.to(args.device)