From 62df4ba59aac3a62a03f40b602f9c285ea282108 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Wed, 28 Aug 2019 12:22:56 +0200 Subject: [PATCH] add dilbert tokenizer and tests --- pytorch_transformers/__init__.py | 5 +- .../tests/tokenization_bert_test.py | 6 +- .../tests/tokenization_dilbert_test.py | 46 ++++++++++++++ pytorch_transformers/tokenization_dilbert.py | 62 +++++++++++++++++++ 4 files changed, 114 insertions(+), 5 deletions(-) create mode 100644 pytorch_transformers/tests/tokenization_dilbert_test.py create mode 100644 pytorch_transformers/tokenization_dilbert.py diff --git a/pytorch_transformers/__init__.py b/pytorch_transformers/__init__.py index e6774c96d8..22bc4d3c21 100644 --- a/pytorch_transformers/__init__.py +++ b/pytorch_transformers/__init__.py @@ -7,6 +7,7 @@ from .tokenization_gpt2 import GPT2Tokenizer from .tokenization_xlnet import XLNetTokenizer, SPIECE_UNDERLINE from .tokenization_xlm import XLMTokenizer from .tokenization_roberta import RobertaTokenizer +from .tokenization_dilbert import DilBertTokenizer from .tokenization_utils import (PreTrainedTokenizer) @@ -41,8 +42,8 @@ from .modeling_xlm import (XLMConfig, XLMPreTrainedModel , XLMModel, from .modeling_roberta import (RobertaConfig, RobertaForMaskedLM, RobertaModel, RobertaForSequenceClassification, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP) from .modeling_dilbert import (DilBertConfig, DilBertForMaskedLM, DilBertModel, - DilBertForSequenceClassification, DilBertForQuestionAnswering, - DILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, DILBERT_PRETRAINED_MODEL_ARCHIVE_MAP) + DilBertForSequenceClassification, DilBertForQuestionAnswering, + DILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, DILBERT_PRETRAINED_MODEL_ARCHIVE_MAP) from .modeling_utils import (WEIGHTS_NAME, CONFIG_NAME, TF_WEIGHTS_NAME, PretrainedConfig, PreTrainedModel, prune_layer, Conv1D) diff --git a/pytorch_transformers/tests/tokenization_bert_test.py b/pytorch_transformers/tests/tokenization_bert_test.py index db507317a8..aaca746d46 100644 --- a/pytorch_transformers/tests/tokenization_bert_test.py +++ b/pytorch_transformers/tests/tokenization_bert_test.py @@ -42,7 +42,7 @@ class BertTokenizationTest(CommonTestCases.CommonTokenizerTester): vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) def get_tokenizer(self): - return BertTokenizer.from_pretrained(self.tmpdirname) + return self.tokenizer_class.from_pretrained(self.tmpdirname) def get_input_output_texts(self): input_text = u"UNwant\u00E9d,running" @@ -50,7 +50,7 @@ class BertTokenizationTest(CommonTestCases.CommonTokenizerTester): return input_text, output_text def test_full_tokenizer(self): - tokenizer = BertTokenizer(self.vocab_file) + tokenizer = self.tokenizer_class(self.vocab_file) tokens = tokenizer.tokenize(u"UNwant\u00E9d,running") self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"]) @@ -126,7 +126,7 @@ class BertTokenizationTest(CommonTestCases.CommonTokenizerTester): self.assertFalse(_is_punctuation(u" ")) def test_sequence_builders(self): - tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") + tokenizer = self.tokenizer_class.from_pretrained("bert-base-uncased") text = tokenizer.encode("sequence builders") text_2 = tokenizer.encode("multi-sequence build") diff --git a/pytorch_transformers/tests/tokenization_dilbert_test.py b/pytorch_transformers/tests/tokenization_dilbert_test.py new file mode 100644 index 0000000000..4cc7aa6c88 --- /dev/null +++ b/pytorch_transformers/tests/tokenization_dilbert_test.py @@ -0,0 +1,46 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import absolute_import, division, print_function, unicode_literals + +import os +import unittest +from io import open + +from pytorch_transformers.tokenization_dilbert import (DilBertTokenizer) + +from .tokenization_tests_commons import CommonTestCases +from .tokenization_bert_test import BertTokenizationTest + +class DilBertTokenizationTest(BertTokenizationTest): + + tokenizer_class = DilBertTokenizer + + def get_tokenizer(self): + return DilBertTokenizer.from_pretrained(self.tmpdirname) + + def test_sequence_builders(self): + tokenizer = DilBertTokenizer.from_pretrained("dilbert-base-uncased") + + text = tokenizer.encode("sequence builders") + text_2 = tokenizer.encode("multi-sequence build") + + encoded_sentence = tokenizer.add_special_tokens_single_sentence(text) + encoded_pair = tokenizer.add_special_tokens_sentences_pair(text, text_2) + + assert encoded_sentence == [101] + text + [102] + assert encoded_pair == [101] + text + [102] + text_2 + [102] + +if __name__ == '__main__': + unittest.main() diff --git a/pytorch_transformers/tokenization_dilbert.py b/pytorch_transformers/tokenization_dilbert.py new file mode 100644 index 0000000000..8d71e1b486 --- /dev/null +++ b/pytorch_transformers/tokenization_dilbert.py @@ -0,0 +1,62 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes for DilBERT.""" + +from __future__ import absolute_import, division, print_function, unicode_literals + +import collections +import logging +import os +import unicodedata +from io import open + +from .tokenization_bert import BertTokenizer + +logger = logging.getLogger(__name__) + +VOCAB_FILES_NAMES = {'vocab_file': 'vocab.txt'} + +PRETRAINED_VOCAB_FILES_MAP = { + 'vocab_file': + { + 'dilbert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt", + 'dilbert-base-uncased-distilled-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt", + } +} + +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { + 'dilbert-base-uncased': 512, + 'dilbert-base-uncased-distilled-squad': 512, +} + + +class DilBertTokenizer(BertTokenizer): + r""" + Constructs a DilBertTokenizer. + :class:`~pytorch_transformers.DilBertTokenizer` is identical to BertTokenizer and runs end-to-end tokenization: punctuation splitting + wordpiece + + Args: + vocab_file: Path to a one-wordpiece-per-line vocabulary file + do_lower_case: Whether to lower case the input. Only has an effect when do_wordpiece_only=False + do_basic_tokenize: Whether to do basic tokenization before wordpiece. + max_len: An artificial maximum length to truncate tokenized sequences to; Effective maximum length is always the + minimum of this value (if specified) and the underlying BERT model's sequence length. + never_split: List of tokens which will never be split during tokenization. Only has an effect when + do_wordpiece_only=False + """ + + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES