From b514a60c360194b9f78f7dbee9dd8fbdf54ff688 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Mon, 11 Feb 2019 10:17:16 +0100 Subject: [PATCH] added tests for OpenAI GPT and Transformer-XL tokenizers --- README.md | 8 +- .../tokenization_openai.py | 9 +- .../tokenization_transfo_xl.py | 147 +++++++++++++++--- tests/tokenization_openai_test.py | 57 +++++++ tests/tokenization_transfo_xl_test.py | 90 +++++++++++ 5 files changed, 286 insertions(+), 25 deletions(-) create mode 100644 tests/tokenization_openai_test.py create mode 100644 tests/tokenization_transfo_xl_test.py diff --git a/README.md b/README.md index df8fbccb2b..607ab3b689 100644 --- a/README.md +++ b/README.md @@ -529,10 +529,10 @@ This model *outputs*: `OpenAIGPTDoubleHeadsModel` includes the `OpenAIGPTModel` Transformer followed by two heads: - a language modeling head with weights tied to the input embeddings (no additional parameters) and: -- a multiple choice classifier (linear layer). +- a multiple choice classifier (linear layer that take as input a hidden state in a sequence to compute a score, see details in paper). *Inputs* are the same as the inputs of the [`OpenAIGPTModel`](#-9.-`OpenAIGPTModel`) class plus a classification mask and two optional labels: -- `multiple_choice_token_mask`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length] with a value of 1 were the last hidden state is (usually the [CLS] token) and 0 otherwise. +- `multiple_choice_token_ids`: a torch.LongTensor of shape [batch_size, num_choices] with the index of the token whose hidden state should be used as input for the multiple choice classifier (usually the [CLS] token for each choice). - `lm_labels`: optional language modeling labels: torch.LongTensor of shape [batch_size, sequence_length] with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss is only computed for the labels set in [0, ..., vocab_size]. - `multiple_choice_labels`: optional multiple choice labels: torch.LongTensor of shape [batch_size] with indices selected in [0, ..., num_choices]. @@ -613,9 +613,9 @@ Please refer to the doc strings and code in [`tokenization_openai.py`](./pytorch #### `TransfoXLTokenizer` -`TransfoXLTokenizer` perform word tokenization. +`TransfoXLTokenizer` perform word tokenization. This tokenizer can be used for adaptive softmax and has utilities for counting tokens in a corpus to create a vocabulary ordered by toekn frequency (for adaptive softmax). See the adaptive softmax paper ([Efficient softmax approximation for GPUs](http://arxiv.org/abs/1609.04309)) for more details. -Please refer to the doc strings and code in [`tokenization_transfo_xl.py`](./pytorch_pretrained_bert/tokenization_transfo_xl.py) for the details of the `TransfoXLTokenizer`. +Please refer to the doc strings and code in [`tokenization_transfo_xl.py`](./pytorch_pretrained_bert/tokenization_transfo_xl.py) for the details of these additional methods in `TransfoXLTokenizer`. ### Optimizers: diff --git a/pytorch_pretrained_bert/tokenization_openai.py b/pytorch_pretrained_bert/tokenization_openai.py index fcb8e13949..77ba922856 100644 --- a/pytorch_pretrained_bert/tokenization_openai.py +++ b/pytorch_pretrained_bert/tokenization_openai.py @@ -70,7 +70,10 @@ def text_standardize(text): class OpenAIGPTTokenizer(object): """ - mostly a wrapper for a public python bpe tokenizer + BPE tokenizer. Peculiarities: + - lower case all inputs + - uses SpaCy tokenizer + - special tokens: additional symbols (ex: "__classify__") to add to a vocabulary. """ @classmethod def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): @@ -150,7 +153,7 @@ class OpenAIGPTTokenizer(object): logger.info("Special tokens {}".format(self.special_tokens)) def bpe(self, token): - word = tuple(token[:-1]) + ( token[-1] + '',) + word = tuple(token[:-1]) + (token[-1] + '',) if token in self.cache: return self.cache[token] pairs = get_pairs(word) @@ -159,7 +162,7 @@ class OpenAIGPTTokenizer(object): return token+'' while True: - bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) + bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) if bigram not in self.bpe_ranks: break first, second = bigram diff --git a/pytorch_pretrained_bert/tokenization_transfo_xl.py b/pytorch_pretrained_bert/tokenization_transfo_xl.py index 860b274f19..698deae21c 100644 --- a/pytorch_pretrained_bert/tokenization_transfo_xl.py +++ b/pytorch_pretrained_bert/tokenization_transfo_xl.py @@ -25,6 +25,7 @@ import os import sys from collections import Counter, OrderedDict from io import open +import unicodedata import torch import numpy as np @@ -89,8 +90,8 @@ class TransfoXLTokenizer(object): tokenizer.__dict__[key] = value return tokenizer - def __init__(self, special=[], min_freq=0, max_size=None, lower_case=True, - delimiter=None, vocab_file=None): + def __init__(self, special=[], min_freq=0, max_size=None, lower_case=False, + delimiter=None, vocab_file=None, never_split=("", "", "")): self.counter = Counter() self.special = special self.min_freq = min_freq @@ -98,6 +99,7 @@ class TransfoXLTokenizer(object): self.lower_case = lower_case self.delimiter = delimiter self.vocab_file = vocab_file + self.never_split = never_split def count_file(self, path, verbose=False, add_eos=False): if verbose: print('counting file {} ...'.format(path)) @@ -132,7 +134,12 @@ class TransfoXLTokenizer(object): for line in f: symb = line.strip().split()[0] self.add_symbol(symb) - self.unk_idx = self.sym2idx[''] + if '' in self.sym2idx: + self.unk_idx = self.sym2idx[''] + elif '' in self.sym2idx: + self.unk_idx = self.sym2idx[''] + else: + raise ValueError('No token in vocabulary') def build_vocab(self): if self.vocab_file: @@ -198,7 +205,7 @@ class TransfoXLTokenizer(object): self.sym2idx[sym] = len(self.idx2sym) - 1 def get_sym(self, idx): - assert 0 <= idx < len(self), 'Index {} out of range'.format(idx) + assert 0 <= idx < len(self), 'Index {} out of vocabulary range'.format(idx) return self.idx2sym[idx] def get_idx(self, sym): @@ -206,9 +213,16 @@ class TransfoXLTokenizer(object): return self.sym2idx[sym] else: # print('encounter unk {}'.format(sym)) - assert '' not in sym - assert hasattr(self, 'unk_idx') - return self.sym2idx.get(sym, self.unk_idx) + # assert '' not in sym + if hasattr(self, 'unk_idx'): + return self.sym2idx.get(sym, self.unk_idx) + # Backward compatibility with pre-trained models + elif '' in self.sym2idx: + return self.sym2idx[''] + elif '' in self.sym2idx: + return self.sym2idx[''] + else: + raise ValueError('Token not in vocabulary and no token in vocabulary for replacement') def convert_ids_to_tokens(self, indices): """Converts a sequence of indices in symbols using the vocab.""" @@ -231,24 +245,82 @@ class TransfoXLTokenizer(object): def __len__(self): return len(self.idx2sym) - def tokenize(self, line, add_eos=False, add_double_eos=False): - line = line.strip() - # convert to lower case - if self.lower_case: - line = line.lower() + def _run_split_on_punc(self, text): + """Splits punctuation on a piece of text.""" + if text in self.never_split: + return [text] + chars = list(text) + i = 0 + start_new_word = True + output = [] + while i < len(chars): + char = chars[i] + if _is_punctuation(char): + output.append([char]) + start_new_word = True + else: + if start_new_word: + output.append([]) + start_new_word = False + output[-1].append(char) + i += 1 - # empty delimiter '' will evaluate False + return ["".join(x) for x in output] + + def _run_strip_accents(self, text): + """Strips accents from a piece of text.""" + text = unicodedata.normalize("NFD", text) + output = [] + for char in text: + cat = unicodedata.category(char) + if cat == "Mn": + continue + output.append(char) + return "".join(output) + + def _clean_text(self, text): + """Performs invalid character removal and whitespace cleanup on text.""" + output = [] + for char in text: + cp = ord(char) + if cp == 0 or cp == 0xfffd or _is_control(char): + continue + if _is_whitespace(char): + output.append(" ") + else: + output.append(char) + return "".join(output) + + def whitespace_tokenize(self, text): + """Runs basic whitespace cleaning and splitting on a peice of text.""" + text = text.strip() + if not text: + return [] if self.delimiter == '': - symbols = line + tokens = text else: - symbols = line.split(self.delimiter) + tokens = text.split(self.delimiter) + return tokens + + def tokenize(self, line, add_eos=False, add_double_eos=False): + line = self._clean_text(line) + line = line.strip() + + symbols = self.whitespace_tokenize(line) + + split_symbols = [] + for symbol in symbols: + if self.lower_case and symbol not in self.never_split: + symbol = symbol.lower() + symbol = self._run_strip_accents(symbol) + split_symbols.extend(self._run_split_on_punc(symbol)) if add_double_eos: # lm1b - return [''] + symbols + [''] + return [''] + split_symbols + [''] elif add_eos: - return symbols + [''] + return split_symbols + [''] else: - return symbols + return split_symbols class LMOrderedIterator(object): @@ -556,3 +628,42 @@ def get_lm_corpus(datadir, dataset): torch.save(corpus, fn) return corpus + +def _is_whitespace(char): + """Checks whether `chars` is a whitespace character.""" + # \t, \n, and \r are technically contorl characters but we treat them + # as whitespace since they are generally considered as such. + if char == " " or char == "\t" or char == "\n" or char == "\r": + return True + cat = unicodedata.category(char) + if cat == "Zs": + return True + return False + + +def _is_control(char): + """Checks whether `chars` is a control character.""" + # These are technically control characters but we count them as whitespace + # characters. + if char == "\t" or char == "\n" or char == "\r": + return False + cat = unicodedata.category(char) + if cat.startswith("C"): + return True + return False + + +def _is_punctuation(char): + """Checks whether `chars` is a punctuation character.""" + cp = ord(char) + # We treat all non-letter/number ASCII as punctuation. + # Characters such as "^", "$", and "`" are not in the Unicode + # Punctuation class but we treat them as punctuation anyways, for + # consistency. + if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or + (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): + return True + cat = unicodedata.category(char) + if cat.startswith("P"): + return True + return False diff --git a/tests/tokenization_openai_test.py b/tests/tokenization_openai_test.py new file mode 100644 index 0000000000..dadcd9699a --- /dev/null +++ b/tests/tokenization_openai_test.py @@ -0,0 +1,57 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import absolute_import, division, print_function, unicode_literals + +import os +import unittest +import json +from io import open + +from pytorch_pretrained_bert.tokenization_openai import OpenAIGPTTokenizer + + +class OpenAIGPTTokenizationTest(unittest.TestCase): + + def test_full_tokenizer(self): + """ Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt """ + vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n", + "w", "r", "t", + "lo", "low", "er", + "low", "lowest", "newer", "wider"] + vocab_tokens = dict(zip(vocab, range(len(vocab)))) + merges = ["#version: 0.2", "l o", "lo w", "e r", ""] + with open("/tmp/openai_tokenizer_vocab_test.json", "w", encoding='utf-8') as fp: + json.dump(vocab_tokens, fp) + vocab_file = fp.name + with open("/tmp/openai_tokenizer_merges_test.txt", "w", encoding='utf-8') as fp: + fp.write("\n".join(merges)) + merges_file = fp.name + + tokenizer = OpenAIGPTTokenizer(vocab_file, merges_file, special_tokens=[""]) + os.remove(vocab_file) + os.remove(merges_file) + + text = "lower" + bpe_tokens = ["low", "er"] + tokens = tokenizer.tokenize(text) + self.assertListEqual(tokens, bpe_tokens) + + input_tokens = tokens + [""] + input_bpe_tokens = [14, 15, 20] + self.assertListEqual( + tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens) + +if __name__ == '__main__': + unittest.main() diff --git a/tests/tokenization_transfo_xl_test.py b/tests/tokenization_transfo_xl_test.py new file mode 100644 index 0000000000..9ff04f5f34 --- /dev/null +++ b/tests/tokenization_transfo_xl_test.py @@ -0,0 +1,90 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import absolute_import, division, print_function, unicode_literals + +import os +import unittest +from io import open + +from pytorch_pretrained_bert.tokenization_transfo_xl import (TransfoXLTokenizer, + _is_control, _is_punctuation, + _is_whitespace) + + +class TransfoXLTokenizationTest(unittest.TestCase): + + def test_full_tokenizer(self): + vocab_tokens = [ + "", "[CLS]", "[SEP]", "want", "unwanted", "wa", "un", "running", "," + ] + with open("/tmp/transfo_xl_tokenizer_test.txt", "w", encoding='utf-8') as vocab_writer: + vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) + vocab_file = vocab_writer.name + + tokenizer = TransfoXLTokenizer(vocab_file=vocab_file, lower_case=True) + tokenizer.build_vocab() + os.remove(vocab_file) + + tokens = tokenizer.tokenize(u" UNwant\u00E9d,running") + self.assertListEqual(tokens, ["", "unwanted", ",", "running"]) + + self.assertListEqual( + tokenizer.convert_tokens_to_ids(tokens), [0, 4, 8, 7]) + + def test_full_tokenizer_lower(self): + tokenizer = TransfoXLTokenizer(lower_case=True) + + self.assertListEqual( + tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "), + ["hello", "!", "how", "are", "you", "?"]) + self.assertListEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"]) + + def test_full_tokenizer_no_lower(self): + tokenizer = TransfoXLTokenizer(lower_case=False) + + self.assertListEqual( + tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "), + ["HeLLo", "!", "how", "Are", "yoU", "?"]) + + def test_is_whitespace(self): + self.assertTrue(_is_whitespace(u" ")) + self.assertTrue(_is_whitespace(u"\t")) + self.assertTrue(_is_whitespace(u"\r")) + self.assertTrue(_is_whitespace(u"\n")) + self.assertTrue(_is_whitespace(u"\u00A0")) + + self.assertFalse(_is_whitespace(u"A")) + self.assertFalse(_is_whitespace(u"-")) + + def test_is_control(self): + self.assertTrue(_is_control(u"\u0005")) + + self.assertFalse(_is_control(u"A")) + self.assertFalse(_is_control(u" ")) + self.assertFalse(_is_control(u"\t")) + self.assertFalse(_is_control(u"\r")) + + def test_is_punctuation(self): + self.assertTrue(_is_punctuation(u"-")) + self.assertTrue(_is_punctuation(u"$")) + self.assertTrue(_is_punctuation(u"`")) + self.assertTrue(_is_punctuation(u".")) + + self.assertFalse(_is_punctuation(u"A")) + self.assertFalse(_is_punctuation(u" ")) + + +if __name__ == '__main__': + unittest.main()