From b7175a27018425d46a01b9e5a0595e6e9b1ab6a1 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Mon, 9 Sep 2019 11:04:03 +0200 Subject: [PATCH] fixed imports in tests and gpt2 config test --- pytorch_transformers/__init__.py | 2 +- pytorch_transformers/modeling_tf_gpt2.py | 1 + pytorch_transformers/modeling_tf_utils.py | 5 +++-- pytorch_transformers/modeling_xlnet.py | 6 ++++-- pytorch_transformers/tests/modeling_auto_test.py | 6 ++++-- pytorch_transformers/tests/modeling_bert_test.py | 4 ++-- pytorch_transformers/tests/modeling_common_test.py | 6 ++++-- .../tests/modeling_distilbert_test.py | 4 ++-- pytorch_transformers/tests/modeling_gpt2_test.py | 4 ++-- pytorch_transformers/tests/modeling_openai_test.py | 4 ++-- pytorch_transformers/tests/modeling_roberta_test.py | 4 ++-- pytorch_transformers/tests/modeling_tf_auto_test.py | 6 ++++-- pytorch_transformers/tests/modeling_tf_bert_test.py | 4 ++-- pytorch_transformers/tests/modeling_tf_common_test.py | 6 ++++-- pytorch_transformers/tests/modeling_tf_gpt2_test.py | 11 ++++++----- .../tests/modeling_transfo_xl_test.py | 4 ++-- pytorch_transformers/tests/modeling_xlm_test.py | 4 ++-- pytorch_transformers/tests/modeling_xlnet_test.py | 4 ++-- pytorch_transformers/tests/optimization_test.py | 4 ++-- .../tests/tokenization_transfo_xl_test.py | 4 ++-- 20 files changed, 53 insertions(+), 40 deletions(-) diff --git a/pytorch_transformers/__init__.py b/pytorch_transformers/__init__.py index f13457f073..7bc7ddf7e1 100644 --- a/pytorch_transformers/__init__.py +++ b/pytorch_transformers/__init__.py @@ -89,7 +89,7 @@ try: import tensorflow as tf assert int(tf.__version__[0]) >= 2 _tf_available = True # pylint: disable=invalid-name -except ImportError: +except (ImportError, AssertionError): _tf_available = False # pylint: disable=invalid-name if _tf_available: diff --git a/pytorch_transformers/modeling_tf_gpt2.py b/pytorch_transformers/modeling_tf_gpt2.py index 27e51fb752..bcb9f5309a 100644 --- a/pytorch_transformers/modeling_tf_gpt2.py +++ b/pytorch_transformers/modeling_tf_gpt2.py @@ -699,6 +699,7 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel): head_mask = inputs.get('head_mask', None) assert len(inputs) <= 5, "Too many inputs." + assert len(shape_list(input_ids)) == 3, "Inputs should have 3 dimensions: batch, choices, sequence length" num_choices = shape_list(input_ids)[1] seq_length = shape_list(input_ids)[2] diff --git a/pytorch_transformers/modeling_tf_utils.py b/pytorch_transformers/modeling_tf_utils.py index 9dfd42c36b..af67a8442e 100644 --- a/pytorch_transformers/modeling_tf_utils.py +++ b/pytorch_transformers/modeling_tf_utils.py @@ -313,7 +313,7 @@ class TFSequenceSummary(tf.keras.layers.Layer): # We can probably just use the multi-head attention module of PyTorch >=1.1.0 raise NotImplementedError - self.summary = tf.keras.layers.Identity(name='summary') + self.summary = None if hasattr(config, 'summary_use_proj') and config.summary_use_proj: if hasattr(config, 'summary_proj_to_labels') and config.summary_proj_to_labels and config.num_labels > 0: num_classes = config.num_labels @@ -372,7 +372,8 @@ class TFSequenceSummary(tf.keras.layers.Layer): if training and self.first_dropout is not None: output = self.first_dropout(output) - output = self.summary(output) + if self.summary is not None: + output = self.summary(output) if self.activation is not None: output = self.activation(output) diff --git a/pytorch_transformers/modeling_xlnet.py b/pytorch_transformers/modeling_xlnet.py index 00c15080a1..97feaad371 100644 --- a/pytorch_transformers/modeling_xlnet.py +++ b/pytorch_transformers/modeling_xlnet.py @@ -525,8 +525,10 @@ XLNET_INPUTS_DOCSTRING = r""" Only used during pretraining for partial prediction or for sequential decoding (generation). **token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``: A parallel sequence of tokens (can be used to indicate various portions of the inputs). - The embeddings from these tokens will be summed with the respective token embeddings. - Indices are selected in the vocabulary (unlike BERT which has a specific vocabulary for segment indices). + The type indices in XLNet are NOT selected in the vocabulary, they can be arbitrary numbers and + the important thing is that they should be different for tokens which belong to different segments. + The model will compute relative segment differences from the given type indices: + 0 if the segment id of two tokens are the same, 1 if not. **input_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``: Mask to avoid performing attention on padding token indices. Negative of `attention_mask`, i.e. with 0 for real tokens and 1 for padding. diff --git a/pytorch_transformers/tests/modeling_auto_test.py b/pytorch_transformers/tests/modeling_auto_test.py index 169f722ed7..4b00891c38 100644 --- a/pytorch_transformers/tests/modeling_auto_test.py +++ b/pytorch_transformers/tests/modeling_auto_test.py @@ -21,7 +21,9 @@ import shutil import pytest import logging -try: +from pytorch_transformers import is_torch_available + +if is_torch_available(): from pytorch_transformers import (AutoConfig, BertConfig, AutoModel, BertModel, AutoModelWithLMHead, BertForMaskedLM, @@ -31,7 +33,7 @@ try: from .modeling_common_test import (CommonTestCases, ids_tensor) from .configuration_common_test import ConfigTester -except ImportError: +else: pytestmark = pytest.mark.skip("Require Torch") diff --git a/pytorch_transformers/tests/modeling_bert_test.py b/pytorch_transformers/tests/modeling_bert_test.py index d63d1b407c..cac1f996e9 100644 --- a/pytorch_transformers/tests/modeling_bert_test.py +++ b/pytorch_transformers/tests/modeling_bert_test.py @@ -25,13 +25,13 @@ from pytorch_transformers import is_torch_available from .modeling_common_test import (CommonTestCases, ids_tensor) from .configuration_common_test import ConfigTester -try: +if is_torch_available(): from pytorch_transformers import (BertConfig, BertModel, BertForMaskedLM, BertForNextSentencePrediction, BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification, BertForTokenClassification, BertForMultipleChoice) from pytorch_transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP -except ImportError: +else: pytestmark = pytest.mark.skip("Require Torch") diff --git a/pytorch_transformers/tests/modeling_common_test.py b/pytorch_transformers/tests/modeling_common_test.py index 1f778d608f..a7d56041a3 100644 --- a/pytorch_transformers/tests/modeling_common_test.py +++ b/pytorch_transformers/tests/modeling_common_test.py @@ -27,13 +27,15 @@ import unittest import logging import pytest -try: +from pytorch_transformers import is_torch_available + +if is_torch_available(): import torch from pytorch_transformers import (PretrainedConfig, PreTrainedModel, BertModel, BertConfig, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, GPT2LMHeadModel, GPT2Config, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP) -except ImportError: +else: pytestmark = pytest.mark.skip("Require Torch") diff --git a/pytorch_transformers/tests/modeling_distilbert_test.py b/pytorch_transformers/tests/modeling_distilbert_test.py index 10bb4bb398..8fef9d5833 100644 --- a/pytorch_transformers/tests/modeling_distilbert_test.py +++ b/pytorch_transformers/tests/modeling_distilbert_test.py @@ -21,10 +21,10 @@ import pytest from pytorch_transformers import is_torch_available -try: +if is_torch_available(): from pytorch_transformers import (DistilBertConfig, DistilBertModel, DistilBertForMaskedLM, DistilBertForQuestionAnswering, DistilBertForSequenceClassification) -except ImportError: +else: pytestmark = pytest.mark.skip("Require Torch") from .modeling_common_test import (CommonTestCases, ids_tensor) diff --git a/pytorch_transformers/tests/modeling_gpt2_test.py b/pytorch_transformers/tests/modeling_gpt2_test.py index e8f2ff20d2..dc7c0d1816 100644 --- a/pytorch_transformers/tests/modeling_gpt2_test.py +++ b/pytorch_transformers/tests/modeling_gpt2_test.py @@ -22,10 +22,10 @@ import shutil from pytorch_transformers import is_torch_available -try: +if is_torch_available(): from pytorch_transformers import (GPT2Config, GPT2Model, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP, GPT2LMHeadModel, GPT2DoubleHeadsModel) -except ImportError: +else: pytestmark = pytest.mark.skip("Require Torch") from .modeling_common_test import (CommonTestCases, ids_tensor) diff --git a/pytorch_transformers/tests/modeling_openai_test.py b/pytorch_transformers/tests/modeling_openai_test.py index b89990a181..6df4406d03 100644 --- a/pytorch_transformers/tests/modeling_openai_test.py +++ b/pytorch_transformers/tests/modeling_openai_test.py @@ -22,10 +22,10 @@ import shutil from pytorch_transformers import is_torch_available -try: +if is_torch_available(): from pytorch_transformers import (OpenAIGPTConfig, OpenAIGPTModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP, OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel) -except ImportError: +else: pytestmark = pytest.mark.skip("Require Torch") from .modeling_common_test import (CommonTestCases, ids_tensor) diff --git a/pytorch_transformers/tests/modeling_roberta_test.py b/pytorch_transformers/tests/modeling_roberta_test.py index ed0f8b5cdc..11f2893671 100644 --- a/pytorch_transformers/tests/modeling_roberta_test.py +++ b/pytorch_transformers/tests/modeling_roberta_test.py @@ -22,11 +22,11 @@ import pytest from pytorch_transformers import is_torch_available -try: +if is_torch_available(): import torch from pytorch_transformers import (RobertaConfig, RobertaModel, RobertaForMaskedLM, RobertaForSequenceClassification) from pytorch_transformers.modeling_roberta import ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP -except ImportError: +else: pytestmark = pytest.mark.skip("Require Torch") from .modeling_common_test import (CommonTestCases, ids_tensor) diff --git a/pytorch_transformers/tests/modeling_tf_auto_test.py b/pytorch_transformers/tests/modeling_tf_auto_test.py index 816d6c1b1a..4617ab817b 100644 --- a/pytorch_transformers/tests/modeling_tf_auto_test.py +++ b/pytorch_transformers/tests/modeling_tf_auto_test.py @@ -21,7 +21,9 @@ import shutil import pytest import logging -try: +from pytorch_transformers import is_tf_available + +if is_tf_available(): from pytorch_transformers import (AutoConfig, BertConfig, TFAutoModel, TFBertModel, TFAutoModelWithLMHead, TFBertForMaskedLM, @@ -31,7 +33,7 @@ try: from .modeling_common_test import (CommonTestCases, ids_tensor) from .configuration_common_test import ConfigTester -except ImportError: +else: pytestmark = pytest.mark.skip("Require TensorFlow") diff --git a/pytorch_transformers/tests/modeling_tf_bert_test.py b/pytorch_transformers/tests/modeling_tf_bert_test.py index c95e33d780..55bbe36feb 100644 --- a/pytorch_transformers/tests/modeling_tf_bert_test.py +++ b/pytorch_transformers/tests/modeling_tf_bert_test.py @@ -26,7 +26,7 @@ from .configuration_common_test import ConfigTester from pytorch_transformers import BertConfig, is_tf_available -try: +if is_tf_available(): import tensorflow as tf from pytorch_transformers.modeling_tf_bert import (TFBertModel, TFBertForMaskedLM, TFBertForNextSentencePrediction, @@ -36,7 +36,7 @@ try: TFBertForTokenClassification, TFBertForQuestionAnswering, TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP) -except ImportError: +else: pytestmark = pytest.mark.skip("Require TensorFlow") diff --git a/pytorch_transformers/tests/modeling_tf_common_test.py b/pytorch_transformers/tests/modeling_tf_common_test.py index f9b87eed9a..da3263ffde 100644 --- a/pytorch_transformers/tests/modeling_tf_common_test.py +++ b/pytorch_transformers/tests/modeling_tf_common_test.py @@ -25,11 +25,13 @@ import uuid import pytest import sys -try: +from pytorch_transformers import is_tf_available + +if is_tf_available(): import tensorflow as tf from pytorch_transformers import TFPreTrainedModel # from pytorch_transformers.modeling_bert import BertModel, BertConfig, BERT_PRETRAINED_MODEL_ARCHIVE_MAP -except ImportError: +else: pytestmark = pytest.mark.skip("Require TensorFlow") diff --git a/pytorch_transformers/tests/modeling_tf_gpt2_test.py b/pytorch_transformers/tests/modeling_tf_gpt2_test.py index 2710488169..5fef1f6453 100644 --- a/pytorch_transformers/tests/modeling_tf_gpt2_test.py +++ b/pytorch_transformers/tests/modeling_tf_gpt2_test.py @@ -26,19 +26,20 @@ from .configuration_common_test import ConfigTester from pytorch_transformers import GPT2Config, is_tf_available -try: +if is_tf_available(): import tensorflow as tf from pytorch_transformers.modeling_tf_gpt2 import (TFGPT2Model, TFGPT2LMHeadModel, TFGPT2DoubleHeadsModel, TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP) -except ImportError: +else: pytestmark = pytest.mark.skip("Require TensorFlow") class TFGPT2ModelTest(TFCommonTestCases.TFCommonModelTester): - all_model_classes = (TFGPT2Model, TFGPT2LMHeadModel, - TFGPT2DoubleHeadsModel) if is_tf_available() else () + # all_model_classes = (TFGPT2Model, TFGPT2LMHeadModel, + # TFGPT2DoubleHeadsModel) if is_tf_available() else () + all_model_classes = (TFGPT2Model, TFGPT2LMHeadModel) if is_tf_available() else () class TFGPT2ModelTester(object): @@ -186,7 +187,7 @@ class TFGPT2ModelTest(TFCommonTestCases.TFCommonModelTester): def setUp(self): self.model_tester = TFGPT2ModelTest.TFGPT2ModelTester(self) - self.config_tester = ConfigTester(self, config_class=GPT2Config, hidden_size=37) + self.config_tester = ConfigTester(self, config_class=GPT2Config, n_embd=37) def test_config(self): self.config_tester.run_common_tests() diff --git a/pytorch_transformers/tests/modeling_transfo_xl_test.py b/pytorch_transformers/tests/modeling_transfo_xl_test.py index 9a72335157..035f20991f 100644 --- a/pytorch_transformers/tests/modeling_transfo_xl_test.py +++ b/pytorch_transformers/tests/modeling_transfo_xl_test.py @@ -23,11 +23,11 @@ import pytest from pytorch_transformers import is_torch_available -try: +if is_torch_available(): import torch from pytorch_transformers import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel) from pytorch_transformers.modeling_transfo_xl import TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP -except ImportError: +else: pytestmark = pytest.mark.skip("Require Torch") from .modeling_common_test import (CommonTestCases, ids_tensor) diff --git a/pytorch_transformers/tests/modeling_xlm_test.py b/pytorch_transformers/tests/modeling_xlm_test.py index 21cf624b9b..acc8a68066 100644 --- a/pytorch_transformers/tests/modeling_xlm_test.py +++ b/pytorch_transformers/tests/modeling_xlm_test.py @@ -22,11 +22,11 @@ import pytest from pytorch_transformers import is_torch_available -try: +if is_torch_available(): from pytorch_transformers import (XLMConfig, XLMModel, XLMWithLMHeadModel, XLMForQuestionAnswering, XLMForSequenceClassification) from pytorch_transformers.modeling_xlm import XLM_PRETRAINED_MODEL_ARCHIVE_MAP -except ImportError: +else: pytestmark = pytest.mark.skip("Require Torch") from .modeling_common_test import (CommonTestCases, ids_tensor) diff --git a/pytorch_transformers/tests/modeling_xlnet_test.py b/pytorch_transformers/tests/modeling_xlnet_test.py index b280ed4592..8c8b3f964f 100644 --- a/pytorch_transformers/tests/modeling_xlnet_test.py +++ b/pytorch_transformers/tests/modeling_xlnet_test.py @@ -25,12 +25,12 @@ import pytest from pytorch_transformers import is_torch_available -try: +if is_torch_available(): import torch from pytorch_transformers import (XLNetConfig, XLNetModel, XLNetLMHeadModel, XLNetForSequenceClassification, XLNetForQuestionAnswering) from pytorch_transformers.modeling_xlnet import XLNET_PRETRAINED_MODEL_ARCHIVE_MAP -except ImportError: +else: pytestmark = pytest.mark.skip("Require Torch") from .modeling_common_test import (CommonTestCases, ids_tensor) diff --git a/pytorch_transformers/tests/optimization_test.py b/pytorch_transformers/tests/optimization_test.py index 07dc22141d..c1c6270f32 100644 --- a/pytorch_transformers/tests/optimization_test.py +++ b/pytorch_transformers/tests/optimization_test.py @@ -22,12 +22,12 @@ import pytest from pytorch_transformers import is_torch_available -try: +if is_torch_available(): import torch from pytorch_transformers import (AdamW, ConstantLRSchedule, WarmupConstantSchedule, WarmupCosineSchedule, WarmupCosineWithHardRestartsSchedule, WarmupLinearSchedule) -except ImportError: +else: pytestmark = pytest.mark.skip("Require Torch") from .tokenization_tests_commons import TemporaryDirectory diff --git a/pytorch_transformers/tests/tokenization_transfo_xl_test.py b/pytorch_transformers/tests/tokenization_transfo_xl_test.py index 792033d82c..563871f690 100644 --- a/pytorch_transformers/tests/tokenization_transfo_xl_test.py +++ b/pytorch_transformers/tests/tokenization_transfo_xl_test.py @@ -21,10 +21,10 @@ from io import open from pytorch_transformers import is_torch_available -try: +if is_torch_available(): import torch from pytorch_transformers.tokenization_transfo_xl import TransfoXLTokenizer, VOCAB_FILES_NAMES -except ImportError: +else: pytestmark = pytest.mark.skip("Require Torch") # TODO: untangle Transfo-XL tokenizer from torch.load and torch.save from .tokenization_tests_commons import CommonTestCases