From ed4e5422604b04df823eb2011e9ed4d766cf9980 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Mon, 5 Aug 2019 18:14:07 +0200 Subject: [PATCH] adding tests --- pytorch_transformers/__init__.py | 2 + pytorch_transformers/modeling_auto.py | 27 ++++++++- pytorch_transformers/modeling_utils.py | 2 +- .../tests/modeling_auto_test.py | 55 +++++++++++++++++++ 4 files changed, 83 insertions(+), 3 deletions(-) create mode 100644 pytorch_transformers/tests/modeling_auto_test.py diff --git a/pytorch_transformers/__init__.py b/pytorch_transformers/__init__.py index 72d666448e..d4ddda94fa 100644 --- a/pytorch_transformers/__init__.py +++ b/pytorch_transformers/__init__.py @@ -7,6 +7,8 @@ from .tokenization_xlnet import XLNetTokenizer, SPIECE_UNDERLINE from .tokenization_xlm import XLMTokenizer from .tokenization_utils import (PreTrainedTokenizer) +from .modeling_auto import (AutoConfig, AutoModel, AutoModelForSequenceClassification, AutoModelWithLMHead) + from .modeling_bert import (BertConfig, BertModel, BertForPreTraining, BertForMaskedLM, BertForNextSentencePrediction, BertForSequenceClassification, BertForMultipleChoice, diff --git a/pytorch_transformers/modeling_auto.py b/pytorch_transformers/modeling_auto.py index 7d3ea7ec60..22a35090aa 100644 --- a/pytorch_transformers/modeling_auto.py +++ b/pytorch_transformers/modeling_auto.py @@ -393,6 +393,8 @@ class AutoModelWithLMHead(DerivedAutoModel): def __init__(self, base_model): super(AutoModelWithLMHead, self).__init__(base_model) + config = base_model.config + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) self.apply(self.init_weights) @@ -426,6 +428,17 @@ class AutoModelWithLMHead(DerivedAutoModel): return outputs # (loss), lm_logits, presents, (all hidden_states), (attentions) +AUTO_MODEL_SEQUENCE_SUMMARY_DEFAULTS = { + 'num_labels': 2, + 'summary_type': 'first', + 'summary_use_proj': True, + 'summary_activation': None, + 'summary_proj_to_labels': True, + 'summary_first_dropout': 0.1 +} + + + class AutoModelForSequenceClassification(DerivedAutoModel): r""" :class:`~pytorch_transformers.AutoModelForSequenceClassification` is a class for sequence classification @@ -451,8 +464,18 @@ class AutoModelForSequenceClassification(DerivedAutoModel): def __init__(self, base_model): super(AutoModelForSequenceClassification, self).__init__(base_model) - self.num_labels = base_model.config.num_labels - self.sequence_summary = SequenceSummary(base_model.config) + # Complete configuration with defaults if necessary + config = base_model.config + for key, value in AUTO_MODEL_SEQUENCE_SUMMARY_DEFAULTS.items(): + if not hasattr(config, key): + setattr(config, key, value) + + # Update base model and derived model config + self.transformer.config = config + self.config = config + + self.num_labels = config.num_labels + self.sequence_summary = SequenceSummary(config) self.apply(self.init_weights) diff --git a/pytorch_transformers/modeling_utils.py b/pytorch_transformers/modeling_utils.py index 2664c542e0..f832b482af 100644 --- a/pytorch_transformers/modeling_utils.py +++ b/pytorch_transformers/modeling_utils.py @@ -777,7 +777,7 @@ class SequenceSummary(nn.Module): super(SequenceSummary, self).__init__() self.summary_type = config.summary_type if hasattr(config, 'summary_use_proj') else 'last' - if config.summary_type == 'attn': + if self.summary_type == 'attn': # We should use a standard multi-head attention module with absolute positional embedding for that. # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276 # We can probably just use the multi-head attention module of PyTorch >=1.1.0 diff --git a/pytorch_transformers/tests/modeling_auto_test.py b/pytorch_transformers/tests/modeling_auto_test.py new file mode 100644 index 0000000000..07042a255c --- /dev/null +++ b/pytorch_transformers/tests/modeling_auto_test.py @@ -0,0 +1,55 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import unittest +import shutil +import pytest +import logging + +from pytorch_transformers import AutoConfig, BertConfig, AutoModel, BertModel, AutoModelForSequenceClassification, AutoModelWithLMHead +from pytorch_transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP + +from .modeling_common_test import (CommonTestCases, ConfigTester, ids_tensor) + + +class AutoModelTest(unittest.TestCase): + def test_model_from_pretrained(self): + logging.basicConfig(level=logging.INFO) + for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: + config = AutoConfig.from_pretrained(model_name) + self.assertIsNotNone(config) + self.assertIsInstance(config, BertConfig) + + model = AutoModel.from_pretrained(model_name) + model, loading_info = AutoModel.from_pretrained(model_name, output_loading_info=True) + self.assertIsNotNone(model) + self.assertIsInstance(model, BertModel) + for value in loading_info.values(): + self.assertEqual(len(value), 0) + + model = AutoModelForSequenceClassification.from_pretrained(model_name) + self.assertIsNotNone(model) + self.assertIsInstance(getattr(model, model.base_model_prefix), BertModel) + + model = AutoModelWithLMHead.from_pretrained(model_name) + self.assertIsNotNone(model) + self.assertIsInstance(getattr(model, model.base_model_prefix), BertModel) + + +if __name__ == "__main__": + unittest.main()