diff --git a/docs/source/index.rst b/docs/source/index.rst index b80fd8437b..b613596331 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -40,6 +40,7 @@ The library currently contains PyTorch implementations, pre-trained model weight :maxdepth: 2 :caption: Package Reference + model_doc/auto model_doc/bert model_doc/gpt model_doc/transformerxl diff --git a/docs/source/model_doc/auto.rst b/docs/source/model_doc/auto.rst index ad439fff03..43f6e103bd 100644 --- a/docs/source/model_doc/auto.rst +++ b/docs/source/model_doc/auto.rst @@ -1,9 +1,15 @@ -AutoModel, AutoConfig and AutoTokenizer - Standard derived classes ---------------------------------------------------------------- +AutoModels +----------- -In many case, the architecture you want to use can be guessed from the name or the path of the pretrained model you are supplying to the ``from_pretrained`` method. +In many cases, the architecture you want to use can be guessed from the name or the path of the pretrained model you are supplying to the ``from_pretrained`` method. + +AutoClasses are here to do this job for you so that you automatically retreive the relevant model given the name/path to the pretrained weights/config/vocabulary. + +There are two types of AutoClasses: + +- ``AutoModel``, ``AutoConfig`` and ``AutoTokenizer``: instantiating these ones will directly create a class of the relevant architecture (ex: ``model = AutoModel.from_pretrained('bert-base-cased')`` will create a instance of ``BertModel``) +- All the others (``AutoModelWithLMHead``, ``AutoModelForSequenceClassification``...) are standardized Auto classes for finetuning. Instantiating these will create instance of the same class (``AutoModelWithLMHead``, ``AutoModelForSequenceClassification``...) comprising (i) the relevant base model class (as mentioned just above) and (ii) a standard fine-tuning head on top, convenient for the task. -Auto classes are here to do this job for you so that you automatically retreive the relevant model given the name/path to the pretrained weights/config/vocabulary. ``AutoConfig`` ~~~~~~~~~~~~~~~~~~~~~ @@ -19,6 +25,20 @@ Auto classes are here to do this job for you so that you automatically retreive :members: +``AutoModelWithLMHead`` +~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: pytorch_transformers.AutoModelWithLMHead + :members: + + +``AutoModelForSequenceClassification`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: pytorch_transformers.AutoModelForSequenceClassification + :members: + + ``AutoTokenizer`` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/pytorch_transformers/__init__.py b/pytorch_transformers/__init__.py index d4ddda94fa..110c3dc3c7 100644 --- a/pytorch_transformers/__init__.py +++ b/pytorch_transformers/__init__.py @@ -1,4 +1,5 @@ __version__ = "1.0.0" +from .tokenization_auto import AutoTokenizer from .tokenization_bert import BertTokenizer, BasicTokenizer, WordpieceTokenizer from .tokenization_openai import OpenAIGPTTokenizer from .tokenization_transfo_xl import (TransfoXLTokenizer, TransfoXLCorpus) diff --git a/pytorch_transformers/tests/tokenization_auto_test.py b/pytorch_transformers/tests/tokenization_auto_test.py new file mode 100644 index 0000000000..f4f82083f2 --- /dev/null +++ b/pytorch_transformers/tests/tokenization_auto_test.py @@ -0,0 +1,46 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import unittest +import shutil +import pytest +import logging + +from pytorch_transformers import AutoTokenizer, BertTokenizer, AutoTokenizer, GPT2Tokenizer +from pytorch_transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP +from pytorch_transformers.modeling_gpt2 import GPT2_PRETRAINED_MODEL_ARCHIVE_MAP + + +class AutoTokenizerTest(unittest.TestCase): + def test_tokenizer_from_pretrained(self): + logging.basicConfig(level=logging.INFO) + for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: + tokenizer = AutoTokenizer.from_pretrained(model_name) + self.assertIsNotNone(tokenizer) + self.assertIsInstance(tokenizer, BertTokenizer) + self.assertGreater(len(tokenizer), 0) + + for model_name in list(GPT2_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: + tokenizer = AutoTokenizer.from_pretrained(model_name) + self.assertIsNotNone(tokenizer) + self.assertIsInstance(tokenizer, GPT2Tokenizer) + self.assertGreater(len(tokenizer), 0) + + +if __name__ == "__main__": + unittest.main()