added tests for OpenAI GPT and Transformer-XL tokenizers

This commit is contained in:
thomwolf
2019-02-11 10:17:16 +01:00
parent 9bdcba53fd
commit b514a60c36
5 changed files with 286 additions and 25 deletions

View File

@@ -529,10 +529,10 @@ This model *outputs*:
`OpenAIGPTDoubleHeadsModel` includes the `OpenAIGPTModel` Transformer followed by two heads: `OpenAIGPTDoubleHeadsModel` includes the `OpenAIGPTModel` Transformer followed by two heads:
- a language modeling head with weights tied to the input embeddings (no additional parameters) and: - a language modeling head with weights tied to the input embeddings (no additional parameters) and:
- a multiple choice classifier (linear layer). - a multiple choice classifier (linear layer that takes as input a hidden state in a sequence to compute a score, see details in paper).
*Inputs* are the same as the inputs of the [`OpenAIGPTModel`](#-9.-`OpenAIGPTModel`) class plus a classification mask and two optional labels: *Inputs* are the same as the inputs of the [`OpenAIGPTModel`](#-9.-`OpenAIGPTModel`) class plus a classification mask and two optional labels:
- `multiple_choice_token_mask`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length] with a value of 1 where the last hidden state is (usually the [CLS] token) and 0 otherwise. - `multiple_choice_token_ids`: a torch.LongTensor of shape [batch_size, num_choices] with the index of the token whose hidden state should be used as input for the multiple choice classifier (usually the [CLS] token for each choice).
- `lm_labels`: optional language modeling labels: torch.LongTensor of shape [batch_size, sequence_length] with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss is only computed for the labels set in [0, ..., vocab_size]. - `lm_labels`: optional language modeling labels: torch.LongTensor of shape [batch_size, sequence_length] with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss is only computed for the labels set in [0, ..., vocab_size].
- `multiple_choice_labels`: optional multiple choice labels: torch.LongTensor of shape [batch_size] with indices selected in [0, ..., num_choices]. - `multiple_choice_labels`: optional multiple choice labels: torch.LongTensor of shape [batch_size] with indices selected in [0, ..., num_choices].
@@ -613,9 +613,9 @@ Please refer to the doc strings and code in [`tokenization_openai.py`](./pytorch
#### `TransfoXLTokenizer` #### `TransfoXLTokenizer`
`TransfoXLTokenizer` performs word tokenization. `TransfoXLTokenizer` performs word tokenization. This tokenizer can be used for adaptive softmax and has utilities for counting tokens in a corpus to create a vocabulary ordered by token frequency (for adaptive softmax). See the adaptive softmax paper ([Efficient softmax approximation for GPUs](http://arxiv.org/abs/1609.04309)) for more details.
Please refer to the doc strings and code in [`tokenization_transfo_xl.py`](./pytorch_pretrained_bert/tokenization_transfo_xl.py) for the details of the `TransfoXLTokenizer`. Please refer to the doc strings and code in [`tokenization_transfo_xl.py`](./pytorch_pretrained_bert/tokenization_transfo_xl.py) for the details of these additional methods in `TransfoXLTokenizer`.
### Optimizers: ### Optimizers:

View File

@@ -70,7 +70,10 @@ def text_standardize(text):
class OpenAIGPTTokenizer(object): class OpenAIGPTTokenizer(object):
""" """
mostly a wrapper for a public python bpe tokenizer BPE tokenizer. Peculiarities:
- lower case all inputs
- uses SpaCy tokenizer
- special tokens: additional symbols (ex: "__classify__") to add to a vocabulary.
""" """
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):

View File

@@ -25,6 +25,7 @@ import os
import sys import sys
from collections import Counter, OrderedDict from collections import Counter, OrderedDict
from io import open from io import open
import unicodedata
import torch import torch
import numpy as np import numpy as np
@@ -89,8 +90,8 @@ class TransfoXLTokenizer(object):
tokenizer.__dict__[key] = value tokenizer.__dict__[key] = value
return tokenizer return tokenizer
def __init__(self, special=[], min_freq=0, max_size=None, lower_case=True, def __init__(self, special=[], min_freq=0, max_size=None, lower_case=False,
delimiter=None, vocab_file=None): delimiter=None, vocab_file=None, never_split=("<unk>", "<eos>", "<formula>")):
self.counter = Counter() self.counter = Counter()
self.special = special self.special = special
self.min_freq = min_freq self.min_freq = min_freq
@@ -98,6 +99,7 @@ class TransfoXLTokenizer(object):
self.lower_case = lower_case self.lower_case = lower_case
self.delimiter = delimiter self.delimiter = delimiter
self.vocab_file = vocab_file self.vocab_file = vocab_file
self.never_split = never_split
def count_file(self, path, verbose=False, add_eos=False): def count_file(self, path, verbose=False, add_eos=False):
if verbose: print('counting file {} ...'.format(path)) if verbose: print('counting file {} ...'.format(path))
@@ -132,7 +134,12 @@ class TransfoXLTokenizer(object):
for line in f: for line in f:
symb = line.strip().split()[0] symb = line.strip().split()[0]
self.add_symbol(symb) self.add_symbol(symb)
if '<UNK>' in self.sym2idx:
self.unk_idx = self.sym2idx['<UNK>'] self.unk_idx = self.sym2idx['<UNK>']
elif '<unk>' in self.sym2idx:
self.unk_idx = self.sym2idx['<unk>']
else:
raise ValueError('No <unkown> token in vocabulary')
def build_vocab(self): def build_vocab(self):
if self.vocab_file: if self.vocab_file:
@@ -198,7 +205,7 @@ class TransfoXLTokenizer(object):
self.sym2idx[sym] = len(self.idx2sym) - 1 self.sym2idx[sym] = len(self.idx2sym) - 1
def get_sym(self, idx): def get_sym(self, idx):
assert 0 <= idx < len(self), 'Index {} out of range'.format(idx) assert 0 <= idx < len(self), 'Index {} out of vocabulary range'.format(idx)
return self.idx2sym[idx] return self.idx2sym[idx]
def get_idx(self, sym): def get_idx(self, sym):
@@ -206,9 +213,16 @@ class TransfoXLTokenizer(object):
return self.sym2idx[sym] return self.sym2idx[sym]
else: else:
# print('encounter unk {}'.format(sym)) # print('encounter unk {}'.format(sym))
assert '<eos>' not in sym # assert '<eos>' not in sym
assert hasattr(self, 'unk_idx') if hasattr(self, 'unk_idx'):
return self.sym2idx.get(sym, self.unk_idx) return self.sym2idx.get(sym, self.unk_idx)
# Backward compatibility with pre-trained models
elif '<unk>' in self.sym2idx:
return self.sym2idx['<unk>']
elif '<UNK>' in self.sym2idx:
return self.sym2idx['<UNK>']
else:
raise ValueError('Token not in vocabulary and no <unk> token in vocabulary for replacement')
def convert_ids_to_tokens(self, indices): def convert_ids_to_tokens(self, indices):
"""Converts a sequence of indices in symbols using the vocab.""" """Converts a sequence of indices in symbols using the vocab."""
@@ -231,24 +245,82 @@ class TransfoXLTokenizer(object):
def __len__(self): def __len__(self):
return len(self.idx2sym) return len(self.idx2sym)
def tokenize(self, line, add_eos=False, add_double_eos=False): def _run_split_on_punc(self, text):
line = line.strip() """Splits punctuation on a piece of text."""
# convert to lower case if text in self.never_split:
if self.lower_case: return [text]
line = line.lower() chars = list(text)
i = 0
# empty delimiter '' will evaluate False start_new_word = True
if self.delimiter == '': output = []
symbols = line while i < len(chars):
char = chars[i]
if _is_punctuation(char):
output.append([char])
start_new_word = True
else: else:
symbols = line.split(self.delimiter) if start_new_word:
output.append([])
start_new_word = False
output[-1].append(char)
i += 1
return ["".join(x) for x in output]
def _run_strip_accents(self, text):
    """Return `text` with its combining accent marks removed.

    The text is first decomposed with Unicode NFD normalization, then every
    character whose Unicode category is "Mn" is dropped. The remaining
    characters are returned joined together, still in decomposed form.
    """
    decomposed = unicodedata.normalize("NFD", text)
    return "".join(
        ch for ch in decomposed if unicodedata.category(ch) != "Mn"
    )
def _clean_text(self, text):
    """Remove invalid characters from `text` and normalize its whitespace.

    Characters with code point 0 or 0xFFFD, and any character flagged by
    `_is_control`, are dropped. Every character accepted by
    `_is_whitespace` is replaced by a plain space; all other characters
    are kept unchanged.
    """
    cleaned = []
    for ch in text:
        if ord(ch) in (0, 0xfffd) or _is_control(ch):
            continue
        cleaned.append(" " if _is_whitespace(ch) else ch)
    return "".join(cleaned)
def whitespace_tokenize(self, text):
    """Strip `text` and split it on `self.delimiter`.

    Returns an empty list when nothing is left after stripping. When the
    delimiter is the empty string, the stripped text itself is returned
    unsplit (a string, not a list). Otherwise the result of
    `str.split(self.delimiter)` is returned.
    """
    stripped = text.strip()
    if not stripped:
        return []
    # str.split('') raises ValueError, so the empty delimiter is special-cased.
    return stripped if self.delimiter == '' else stripped.split(self.delimiter)
def tokenize(self, line, add_eos=False, add_double_eos=False):
line = self._clean_text(line)
line = line.strip()
symbols = self.whitespace_tokenize(line)
split_symbols = []
for symbol in symbols:
if self.lower_case and symbol not in self.never_split:
symbol = symbol.lower()
symbol = self._run_strip_accents(symbol)
split_symbols.extend(self._run_split_on_punc(symbol))
if add_double_eos: # lm1b if add_double_eos: # lm1b
return ['<S>'] + symbols + ['<S>'] return ['<S>'] + split_symbols + ['<S>']
elif add_eos: elif add_eos:
return symbols + ['<eos>'] return split_symbols + ['<eos>']
else: else:
return symbols return split_symbols
class LMOrderedIterator(object): class LMOrderedIterator(object):
@@ -556,3 +628,42 @@ def get_lm_corpus(datadir, dataset):
torch.save(corpus, fn) torch.save(corpus, fn)
return corpus return corpus
def _is_whitespace(char):
    """Checks whether `char` is a whitespace character."""
    # \t, \n, and \r are technically control characters but we treat them
    # as whitespace since they are generally considered as such.
    if char in (" ", "\t", "\n", "\r"):
        return True
    # Anything else counts only if its Unicode category is "Zs".
    return unicodedata.category(char) == "Zs"
def _is_control(char):
    """Checks whether `char` is a control character."""
    # Tab, newline and carriage return are technically control characters,
    # but they are counted as whitespace, so they are excluded here.
    if char in ("\t", "\n", "\r"):
        return False
    # Every Unicode category whose name starts with "C" counts as control.
    return unicodedata.category(char).startswith("C")
def _is_punctuation(char):
    """Checks whether `char` is a punctuation character."""
    codepoint = ord(char)
    # All non-letter/number ASCII is treated as punctuation. Characters such
    # as "^", "$", and "`" are not in the Unicode Punctuation class, but they
    # are counted as punctuation anyway, for consistency.
    in_ascii_symbol_range = (
        33 <= codepoint <= 47
        or 58 <= codepoint <= 64
        or 91 <= codepoint <= 96
        or 123 <= codepoint <= 126
    )
    if in_ascii_symbol_range:
        return True
    # Otherwise defer to the Unicode categories whose name starts with "P".
    return unicodedata.category(char).startswith("P")

View File

@@ -0,0 +1,57 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function, unicode_literals
import os
import unittest
import json
from io import open
from pytorch_pretrained_bert.tokenization_openai import OpenAIGPTTokenizer
class OpenAIGPTTokenizationTest(unittest.TestCase):
    """Checks BPE tokenization and token-to-id conversion of `OpenAIGPTTokenizer`."""

    def test_full_tokenizer(self):
        """ Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt """
        # Imported locally so this test carries its own fixture helpers.
        import shutil
        import tempfile

        vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
                 "w</w>", "r</w>", "t</w>",
                 "lo", "low", "er</w>",
                 "low</w>", "lowest</w>", "newer</w>", "wider</w>"]
        vocab_tokens = dict(zip(vocab, range(len(vocab))))
        merges = ["#version: 0.2", "l o", "lo w", "e r</w>", ""]

        # Use a private temporary directory rather than fixed /tmp paths:
        # it works on platforms without /tmp and cannot collide with another
        # test run writing the same file names.
        tmp_dir = tempfile.mkdtemp()
        try:
            vocab_file = os.path.join(tmp_dir, "openai_tokenizer_vocab_test.json")
            merges_file = os.path.join(tmp_dir, "openai_tokenizer_merges_test.txt")
            with open(vocab_file, "w", encoding='utf-8') as fp:
                json.dump(vocab_tokens, fp)
            with open(merges_file, "w", encoding='utf-8') as fp:
                fp.write("\n".join(merges))

            tokenizer = OpenAIGPTTokenizer(vocab_file, merges_file, special_tokens=["<unk>"])
        finally:
            # The fixtures are removed even when the constructor raises.
            shutil.rmtree(tmp_dir)

        text = "lower"
        bpe_tokens = ["low", "er</w>"]
        tokens = tokenizer.tokenize(text)
        self.assertListEqual(tokens, bpe_tokens)

        # The special token gets the first id after the 20 vocabulary entries.
        input_tokens = tokens + ["<unk>"]
        input_bpe_tokens = [14, 15, 20]
        self.assertListEqual(
            tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
if __name__ == '__main__':
unittest.main()

View File

@@ -0,0 +1,90 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function, unicode_literals
import os
import unittest
from io import open
from pytorch_pretrained_bert.tokenization_transfo_xl import (TransfoXLTokenizer,
_is_control, _is_punctuation,
_is_whitespace)
class TransfoXLTokenizationTest(unittest.TestCase):
    """Tests for `TransfoXLTokenizer` and its character-class helper functions."""

    def test_full_tokenizer(self):
        """Tokenize with a vocabulary loaded from a file and map tokens to ids."""
        # Imported locally so this test carries its own fixture helpers.
        import shutil
        import tempfile

        vocab_tokens = [
            "<unk>", "[CLS]", "[SEP]", "want", "unwanted", "wa", "un", "running", ","
        ]

        # Use a private temporary directory rather than a fixed /tmp path:
        # it works on platforms without /tmp and cannot collide with another
        # test run writing the same file name.
        tmp_dir = tempfile.mkdtemp()
        try:
            vocab_file = os.path.join(tmp_dir, "transfo_xl_tokenizer_test.txt")
            with open(vocab_file, "w", encoding='utf-8') as vocab_writer:
                vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))

            tokenizer = TransfoXLTokenizer(vocab_file=vocab_file, lower_case=True)
            tokenizer.build_vocab()
        finally:
            # The fixture is removed even when building the vocabulary raises.
            shutil.rmtree(tmp_dir)

        tokens = tokenizer.tokenize(u"<unk> UNwant\u00E9d,running")
        self.assertListEqual(tokens, ["<unk>", "unwanted", ",", "running"])

        self.assertListEqual(
            tokenizer.convert_tokens_to_ids(tokens), [0, 4, 8, 7])

    def test_full_tokenizer_lower(self):
        """With lower_case=True tokens are lower-cased and accents are stripped."""
        tokenizer = TransfoXLTokenizer(lower_case=True)

        self.assertListEqual(
            tokenizer.tokenize(u" \tHeLLo!how  \n Are yoU?  "),
            ["hello", "!", "how", "are", "you", "?"])
        self.assertListEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"])

    def test_full_tokenizer_no_lower(self):
        """With lower_case=False the original casing is preserved."""
        tokenizer = TransfoXLTokenizer(lower_case=False)

        self.assertListEqual(
            tokenizer.tokenize(u" \tHeLLo!how  \n Are yoU?  "),
            ["HeLLo", "!", "how", "Are", "yoU", "?"])

    def test_is_whitespace(self):
        """`_is_whitespace` accepts spaces, tab, CR, LF and U+00A0 only."""
        self.assertTrue(_is_whitespace(u" "))
        self.assertTrue(_is_whitespace(u"\t"))
        self.assertTrue(_is_whitespace(u"\r"))
        self.assertTrue(_is_whitespace(u"\n"))
        self.assertTrue(_is_whitespace(u"\u00A0"))

        self.assertFalse(_is_whitespace(u"A"))
        self.assertFalse(_is_whitespace(u"-"))

    def test_is_control(self):
        """`_is_control` flags U+0005 but not letters or whitespace characters."""
        self.assertTrue(_is_control(u"\u0005"))

        self.assertFalse(_is_control(u"A"))
        self.assertFalse(_is_control(u" "))
        self.assertFalse(_is_control(u"\t"))
        self.assertFalse(_is_control(u"\r"))
        self.assertFalse(_is_control(u"\n"))

    def test_is_punctuation(self):
        """`_is_punctuation` accepts ASCII symbols such as "$" and "`"."""
        self.assertTrue(_is_punctuation(u"-"))
        self.assertTrue(_is_punctuation(u"$"))
        self.assertTrue(_is_punctuation(u"`"))
        self.assertTrue(_is_punctuation(u"."))

        self.assertFalse(_is_punctuation(u"A"))
        self.assertFalse(_is_punctuation(u" "))
if __name__ == '__main__':
unittest.main()