adding tests
This commit is contained in:
@@ -7,6 +7,8 @@ from .tokenization_xlnet import XLNetTokenizer, SPIECE_UNDERLINE
|
||||
from .tokenization_xlm import XLMTokenizer
|
||||
from .tokenization_utils import (PreTrainedTokenizer)
|
||||
|
||||
from .modeling_auto import (AutoConfig, AutoModel, AutoModelForSequenceClassification, AutoModelWithLMHead)
|
||||
|
||||
from .modeling_bert import (BertConfig, BertModel, BertForPreTraining,
|
||||
BertForMaskedLM, BertForNextSentencePrediction,
|
||||
BertForSequenceClassification, BertForMultipleChoice,
|
||||
|
||||
@@ -393,6 +393,8 @@ class AutoModelWithLMHead(DerivedAutoModel):
|
||||
|
||||
def __init__(self, base_model):
|
||||
super(AutoModelWithLMHead, self).__init__(base_model)
|
||||
config = base_model.config
|
||||
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
self.apply(self.init_weights)
|
||||
@@ -426,6 +428,17 @@ class AutoModelWithLMHead(DerivedAutoModel):
|
||||
return outputs # (loss), lm_logits, presents, (all hidden_states), (attentions)
|
||||
|
||||
|
||||
# Fallback configuration values for the sequence-summary head.
# AutoModelForSequenceClassification copies an entry onto the base model's
# config only when the config does not already define that attribute.
AUTO_MODEL_SEQUENCE_SUMMARY_DEFAULTS = {
    "num_labels": 2,
    "summary_type": "first",
    "summary_use_proj": True,
    "summary_activation": None,
    "summary_proj_to_labels": True,
    "summary_first_dropout": 0.1,
}
|
||||
|
||||
|
||||
|
||||
class AutoModelForSequenceClassification(DerivedAutoModel):
|
||||
r"""
|
||||
:class:`~pytorch_transformers.AutoModelForSequenceClassification` is a class for sequence classification
|
||||
@@ -451,8 +464,18 @@ class AutoModelForSequenceClassification(DerivedAutoModel):
|
||||
|
||||
def __init__(self, base_model):
    """Wrap ``base_model`` with a sequence-summary classification head.

    Every attribute listed in ``AUTO_MODEL_SEQUENCE_SUMMARY_DEFAULTS`` that
    the base model's config does not already define is set on that config
    object (it is modified in place) before the head is built.

    Args:
        base_model: model whose ``config`` attribute parameterizes the head.
    """
    super(AutoModelForSequenceClassification, self).__init__(base_model)

    # Complete configuration with defaults if necessary.
    # This must run before `num_labels` or any `summary_*` setting is read:
    # reading them first fails on a config that does not define them, which
    # is exactly the case the defaults exist for.
    config = base_model.config
    for key, value in AUTO_MODEL_SEQUENCE_SUMMARY_DEFAULTS.items():
        if not hasattr(config, key):
            setattr(config, key, value)

    # Update base model and derived model config
    self.transformer.config = config
    self.config = config

    # Build the head only from the completed config.
    self.num_labels = config.num_labels
    self.sequence_summary = SequenceSummary(config)

    self.apply(self.init_weights)
|
||||
|
||||
|
||||
@@ -777,7 +777,7 @@ class SequenceSummary(nn.Module):
|
||||
super(SequenceSummary, self).__init__()
|
||||
|
||||
self.summary_type = config.summary_type if hasattr(config, 'summary_use_proj') else 'last'
|
||||
if config.summary_type == 'attn':
|
||||
if self.summary_type == 'attn':
|
||||
# We should use a standard multi-head attention module with absolute positional embedding for that.
|
||||
# Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
|
||||
# We can probably just use the multi-head attention module of PyTorch >=1.1.0
|
||||
|
||||
55
pytorch_transformers/tests/modeling_auto_test.py
Normal file
55
pytorch_transformers/tests/modeling_auto_test.py
Normal file
@@ -0,0 +1,55 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The Google AI Language Team Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import shutil
|
||||
import pytest
|
||||
import logging
|
||||
|
||||
from pytorch_transformers import AutoConfig, BertConfig, AutoModel, BertModel, AutoModelForSequenceClassification, AutoModelWithLMHead
|
||||
from pytorch_transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
|
||||
from .modeling_common_test import (CommonTestCases, ConfigTester, ids_tensor)
|
||||
|
||||
|
||||
class AutoModelTest(unittest.TestCase):
    """Checks that the ``Auto*`` classes resolve a BERT checkpoint name to BERT classes."""

    def test_model_from_pretrained(self):
        """Load the config, the bare model and both derived models by name and check their types."""
        logging.basicConfig(level=logging.INFO)
        # Only the first entry of the archive map is exercised.
        for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
            config = AutoConfig.from_pretrained(model_name)
            self.assertIsNotNone(config)
            self.assertIsInstance(config, BertConfig)

            # Default call (model only): check the result instead of
            # overwriting it unchecked with the next load.
            model = AutoModel.from_pretrained(model_name)
            self.assertIsNotNone(model)
            self.assertIsInstance(model, BertModel)

            # Same load, this time also returning the loading info.
            model, loading_info = AutoModel.from_pretrained(model_name, output_loading_info=True)
            self.assertIsNotNone(model)
            self.assertIsInstance(model, BertModel)
            # Every value in the loading info is expected to be empty.
            for value in loading_info.values():
                self.assertEqual(len(value), 0)

            # The derived models expose the wrapped model under `base_model_prefix`.
            model = AutoModelForSequenceClassification.from_pretrained(model_name)
            self.assertIsNotNone(model)
            self.assertIsInstance(getattr(model, model.base_model_prefix), BertModel)

            model = AutoModelWithLMHead.from_pretrained(model_name)
            self.assertIsNotNone(model)
            self.assertIsInstance(getattr(model, model.base_model_prefix), BertModel)
|
||||
|
||||
|
||||
# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|
||||
Reference in New Issue
Block a user