Replace (TF)CommonTestCases for modeling with a mixin.
I suspect the wrapper classes were created in order to prevent the abstract base class (TF)CommonModelTester from being included in test discovery and running, because that would fail. I solved this by replacing the abstract base class with a mixin. Code changes are just de-indenting and automatic reformattings performed by black to use the extra line space.
This commit is contained in:
@@ -14,10 +14,12 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from __future__ import absolute_import, division, print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
from transformers import XxxConfig, is_tf_available
|
from transformers import XxxConfig, is_tf_available
|
||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
from .test_modeling_tf_common import TFCommonTestCases, ids_tensor
|
from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
|
||||||
from .utils import CACHE_DIR, require_tf, slow
|
from .utils import CACHE_DIR, require_tf, slow
|
||||||
|
|
||||||
|
|
||||||
@@ -32,7 +34,7 @@ if is_tf_available():
|
|||||||
|
|
||||||
|
|
||||||
@require_tf
|
@require_tf
|
||||||
class TFXxxModelTest(TFCommonTestCases.TFCommonModelTester):
|
class TFXxxModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||||
|
|
||||||
all_model_classes = (
|
all_model_classes = (
|
||||||
(
|
(
|
||||||
|
|||||||
@@ -14,10 +14,12 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from __future__ import absolute_import, division, print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
from transformers import is_torch_available
|
from transformers import is_torch_available
|
||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
from .test_modeling_common import CommonTestCases, ids_tensor
|
from .test_modeling_common import ModelTesterMixin, ids_tensor
|
||||||
from .utils import CACHE_DIR, require_torch, slow, torch_device
|
from .utils import CACHE_DIR, require_torch, slow, torch_device
|
||||||
|
|
||||||
|
|
||||||
@@ -34,7 +36,7 @@ if is_torch_available():
|
|||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class XxxModelTest(CommonTestCases.CommonModelTester):
|
class XxxModelTest(ModelTesterMixin, unittest.TestCase):
|
||||||
|
|
||||||
all_model_classes = (
|
all_model_classes = (
|
||||||
(XxxModel, XxxForMaskedLM, XxxForQuestionAnswering, XxxForSequenceClassification, XxxForTokenClassification)
|
(XxxModel, XxxForMaskedLM, XxxForQuestionAnswering, XxxForSequenceClassification, XxxForTokenClassification)
|
||||||
|
|||||||
@@ -14,10 +14,12 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from __future__ import absolute_import, division, print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
from transformers import is_torch_available
|
from transformers import is_torch_available
|
||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
from .test_modeling_common import CommonTestCases, ids_tensor
|
from .test_modeling_common import ModelTesterMixin, ids_tensor
|
||||||
from .utils import CACHE_DIR, require_torch, slow, torch_device
|
from .utils import CACHE_DIR, require_torch, slow, torch_device
|
||||||
|
|
||||||
|
|
||||||
@@ -33,7 +35,7 @@ if is_torch_available():
|
|||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class AlbertModelTest(CommonTestCases.CommonModelTester):
|
class AlbertModelTest(ModelTesterMixin, unittest.TestCase):
|
||||||
|
|
||||||
all_model_classes = (AlbertModel, AlbertForMaskedLM) if is_torch_available() else ()
|
all_model_classes = (AlbertModel, AlbertForMaskedLM) if is_torch_available() else ()
|
||||||
|
|
||||||
|
|||||||
@@ -14,10 +14,12 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from __future__ import absolute_import, division, print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
from transformers import is_torch_available
|
from transformers import is_torch_available
|
||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
from .test_modeling_common import CommonTestCases, floats_tensor, ids_tensor
|
from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
||||||
from .utils import CACHE_DIR, require_torch, slow, torch_device
|
from .utils import CACHE_DIR, require_torch, slow, torch_device
|
||||||
|
|
||||||
|
|
||||||
@@ -37,7 +39,7 @@ if is_torch_available():
|
|||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class BertModelTest(CommonTestCases.CommonModelTester):
|
class BertModelTest(ModelTesterMixin, unittest.TestCase):
|
||||||
|
|
||||||
all_model_classes = (
|
all_model_classes = (
|
||||||
(
|
(
|
||||||
|
|||||||
@@ -69,9 +69,8 @@ def _config_zero_init(config):
|
|||||||
return configs_no_init
|
return configs_no_init
|
||||||
|
|
||||||
|
|
||||||
class CommonTestCases:
|
@require_torch
|
||||||
@require_torch
|
class ModelTesterMixin:
|
||||||
class CommonModelTester(unittest.TestCase):
|
|
||||||
|
|
||||||
model_tester = None
|
model_tester = None
|
||||||
all_model_classes = ()
|
all_model_classes = ()
|
||||||
@@ -612,7 +611,8 @@ class CommonTestCases:
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
outputs = model(**inputs_dict)
|
outputs = model(**inputs_dict)
|
||||||
|
|
||||||
class GPTModelTester(CommonModelTester):
|
|
||||||
|
class GPTModelTester(ModelTesterMixin):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
parent,
|
parent,
|
||||||
|
|||||||
@@ -13,10 +13,12 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from __future__ import absolute_import, division, print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
from transformers import is_torch_available
|
from transformers import is_torch_available
|
||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
from .test_modeling_common import CommonTestCases, ids_tensor
|
from .test_modeling_common import ModelTesterMixin, ids_tensor
|
||||||
from .utils import CACHE_DIR, require_torch, slow, torch_device
|
from .utils import CACHE_DIR, require_torch, slow, torch_device
|
||||||
|
|
||||||
|
|
||||||
@@ -25,7 +27,7 @@ if is_torch_available():
|
|||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class CTRLModelTest(CommonTestCases.CommonModelTester):
|
class CTRLModelTest(ModelTesterMixin, unittest.TestCase):
|
||||||
|
|
||||||
all_model_classes = (CTRLModel, CTRLLMHeadModel) if is_torch_available() else ()
|
all_model_classes = (CTRLModel, CTRLLMHeadModel) if is_torch_available() else ()
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
|
|||||||
@@ -14,10 +14,12 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from __future__ import absolute_import, division, print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
from transformers import is_torch_available
|
from transformers import is_torch_available
|
||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
from .test_modeling_common import CommonTestCases, ids_tensor
|
from .test_modeling_common import ModelTesterMixin, ids_tensor
|
||||||
from .utils import require_torch, torch_device
|
from .utils import require_torch, torch_device
|
||||||
|
|
||||||
|
|
||||||
@@ -33,7 +35,7 @@ if is_torch_available():
|
|||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class DistilBertModelTest(CommonTestCases.CommonModelTester):
|
class DistilBertModelTest(ModelTesterMixin, unittest.TestCase):
|
||||||
|
|
||||||
all_model_classes = (
|
all_model_classes = (
|
||||||
(DistilBertModel, DistilBertForMaskedLM, DistilBertForQuestionAnswering, DistilBertForSequenceClassification)
|
(DistilBertModel, DistilBertForMaskedLM, DistilBertForQuestionAnswering, DistilBertForSequenceClassification)
|
||||||
|
|||||||
@@ -14,10 +14,12 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from __future__ import absolute_import, division, print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
from transformers import is_torch_available
|
from transformers import is_torch_available
|
||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
from .test_modeling_common import CommonTestCases, ids_tensor
|
from .test_modeling_common import ModelTesterMixin, ids_tensor
|
||||||
from .utils import CACHE_DIR, require_torch, slow, torch_device
|
from .utils import CACHE_DIR, require_torch, slow, torch_device
|
||||||
|
|
||||||
|
|
||||||
@@ -32,7 +34,7 @@ if is_torch_available():
|
|||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class GPT2ModelTest(CommonTestCases.CommonModelTester):
|
class GPT2ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||||
|
|
||||||
all_model_classes = (GPT2Model, GPT2LMHeadModel, GPT2DoubleHeadsModel) if is_torch_available() else ()
|
all_model_classes = (GPT2Model, GPT2LMHeadModel, GPT2DoubleHeadsModel) if is_torch_available() else ()
|
||||||
|
|
||||||
|
|||||||
@@ -14,10 +14,12 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from __future__ import absolute_import, division, print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
from transformers import is_torch_available
|
from transformers import is_torch_available
|
||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
from .test_modeling_common import CommonTestCases, ids_tensor
|
from .test_modeling_common import ModelTesterMixin, ids_tensor
|
||||||
from .utils import CACHE_DIR, require_torch, slow, torch_device
|
from .utils import CACHE_DIR, require_torch, slow, torch_device
|
||||||
|
|
||||||
|
|
||||||
@@ -32,7 +34,7 @@ if is_torch_available():
|
|||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class OpenAIGPTModelTest(CommonTestCases.CommonModelTester):
|
class OpenAIGPTModelTest(ModelTesterMixin, unittest.TestCase):
|
||||||
|
|
||||||
all_model_classes = (
|
all_model_classes = (
|
||||||
(OpenAIGPTModel, OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel) if is_torch_available() else ()
|
(OpenAIGPTModel, OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel) if is_torch_available() else ()
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ import unittest
|
|||||||
from transformers import is_torch_available
|
from transformers import is_torch_available
|
||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
from .test_modeling_common import CommonTestCases, ids_tensor
|
from .test_modeling_common import ModelTesterMixin, ids_tensor
|
||||||
from .utils import CACHE_DIR, require_torch, slow, torch_device
|
from .utils import CACHE_DIR, require_torch, slow, torch_device
|
||||||
|
|
||||||
|
|
||||||
@@ -37,7 +37,7 @@ if is_torch_available():
|
|||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class RobertaModelTest(CommonTestCases.CommonModelTester):
|
class RobertaModelTest(ModelTesterMixin, unittest.TestCase):
|
||||||
|
|
||||||
all_model_classes = (RobertaForMaskedLM, RobertaModel) if is_torch_available() else ()
|
all_model_classes = (RobertaForMaskedLM, RobertaModel) if is_torch_available() else ()
|
||||||
|
|
||||||
|
|||||||
@@ -14,10 +14,12 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from __future__ import absolute_import, division, print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
from transformers import is_torch_available
|
from transformers import is_torch_available
|
||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
from .test_modeling_common import CommonTestCases, ids_tensor
|
from .test_modeling_common import ModelTesterMixin, ids_tensor
|
||||||
from .utils import CACHE_DIR, require_torch, slow
|
from .utils import CACHE_DIR, require_torch, slow
|
||||||
|
|
||||||
|
|
||||||
@@ -27,7 +29,7 @@ if is_torch_available():
|
|||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class T5ModelTest(CommonTestCases.CommonModelTester):
|
class T5ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||||
|
|
||||||
all_model_classes = (T5Model, T5WithLMHeadModel) if is_torch_available() else ()
|
all_model_classes = (T5Model, T5WithLMHeadModel) if is_torch_available() else ()
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
|
|||||||
@@ -14,10 +14,12 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from __future__ import absolute_import, division, print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
from transformers import AlbertConfig, is_tf_available
|
from transformers import AlbertConfig, is_tf_available
|
||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
from .test_modeling_tf_common import TFCommonTestCases, ids_tensor
|
from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
|
||||||
from .utils import CACHE_DIR, require_tf, slow
|
from .utils import CACHE_DIR, require_tf, slow
|
||||||
|
|
||||||
|
|
||||||
@@ -31,7 +33,7 @@ if is_tf_available():
|
|||||||
|
|
||||||
|
|
||||||
@require_tf
|
@require_tf
|
||||||
class TFAlbertModelTest(TFCommonTestCases.TFCommonModelTester):
|
class TFAlbertModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||||
|
|
||||||
all_model_classes = (
|
all_model_classes = (
|
||||||
(TFAlbertModel, TFAlbertForMaskedLM, TFAlbertForSequenceClassification) if is_tf_available() else ()
|
(TFAlbertModel, TFAlbertForMaskedLM, TFAlbertForSequenceClassification) if is_tf_available() else ()
|
||||||
|
|||||||
@@ -14,10 +14,12 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from __future__ import absolute_import, division, print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
from transformers import BertConfig, is_tf_available
|
from transformers import BertConfig, is_tf_available
|
||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
from .test_modeling_tf_common import TFCommonTestCases, ids_tensor
|
from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
|
||||||
from .utils import CACHE_DIR, require_tf, slow
|
from .utils import CACHE_DIR, require_tf, slow
|
||||||
|
|
||||||
|
|
||||||
@@ -36,7 +38,7 @@ if is_tf_available():
|
|||||||
|
|
||||||
|
|
||||||
@require_tf
|
@require_tf
|
||||||
class TFBertModelTest(TFCommonTestCases.TFCommonModelTester):
|
class TFBertModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||||
|
|
||||||
all_model_classes = (
|
all_model_classes = (
|
||||||
(
|
(
|
||||||
|
|||||||
@@ -20,7 +20,6 @@ import random
|
|||||||
import shutil
|
import shutil
|
||||||
import sys
|
import sys
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
|
||||||
|
|
||||||
from transformers import is_tf_available, is_torch_available
|
from transformers import is_tf_available, is_torch_available
|
||||||
|
|
||||||
@@ -59,9 +58,8 @@ def _config_zero_init(config):
|
|||||||
return configs_no_init
|
return configs_no_init
|
||||||
|
|
||||||
|
|
||||||
class TFCommonTestCases:
|
@require_tf
|
||||||
@require_tf
|
class TFModelTesterMixin:
|
||||||
class TFCommonModelTester(unittest.TestCase):
|
|
||||||
|
|
||||||
model_tester = None
|
model_tester = None
|
||||||
all_model_classes = ()
|
all_model_classes = ()
|
||||||
@@ -168,12 +166,8 @@ class TFCommonTestCases:
|
|||||||
|
|
||||||
if self.is_encoder_decoder:
|
if self.is_encoder_decoder:
|
||||||
input_ids = {
|
input_ids = {
|
||||||
"decoder_input_ids": tf.keras.Input(
|
"decoder_input_ids": tf.keras.Input(batch_shape=(2, 2000), name="decoder_input_ids", dtype="int32"),
|
||||||
batch_shape=(2, 2000), name="decoder_input_ids", dtype="int32"
|
"encoder_input_ids": tf.keras.Input(batch_shape=(2, 2000), name="encoder_input_ids", dtype="int32"),
|
||||||
),
|
|
||||||
"encoder_input_ids": tf.keras.Input(
|
|
||||||
batch_shape=(2, 2000), name="encoder_input_ids", dtype="int32"
|
|
||||||
),
|
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
input_ids = tf.keras.Input(batch_shape=(2, 2000), name="input_ids", dtype="int32")
|
input_ids = tf.keras.Input(batch_shape=(2, 2000), name="input_ids", dtype="int32")
|
||||||
@@ -209,9 +203,7 @@ class TFCommonTestCases:
|
|||||||
outputs_dict = model(inputs_dict)
|
outputs_dict = model(inputs_dict)
|
||||||
|
|
||||||
inputs_keywords = copy.deepcopy(inputs_dict)
|
inputs_keywords = copy.deepcopy(inputs_dict)
|
||||||
input_ids = inputs_keywords.pop(
|
input_ids = inputs_keywords.pop("input_ids" if not self.is_encoder_decoder else "decoder_input_ids", None)
|
||||||
"input_ids" if not self.is_encoder_decoder else "decoder_input_ids", None
|
|
||||||
)
|
|
||||||
outputs_keywords = model(input_ids, **inputs_keywords)
|
outputs_keywords = model(input_ids, **inputs_keywords)
|
||||||
|
|
||||||
output_dict = outputs_dict[0].numpy()
|
output_dict = outputs_dict[0].numpy()
|
||||||
|
|||||||
@@ -14,10 +14,12 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from __future__ import absolute_import, division, print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
from transformers import CTRLConfig, is_tf_available
|
from transformers import CTRLConfig, is_tf_available
|
||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
from .test_modeling_tf_common import TFCommonTestCases, ids_tensor
|
from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
|
||||||
from .utils import CACHE_DIR, require_tf, slow
|
from .utils import CACHE_DIR, require_tf, slow
|
||||||
|
|
||||||
|
|
||||||
@@ -26,7 +28,7 @@ if is_tf_available():
|
|||||||
|
|
||||||
|
|
||||||
@require_tf
|
@require_tf
|
||||||
class TFCTRLModelTest(TFCommonTestCases.TFCommonModelTester):
|
class TFCTRLModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||||
|
|
||||||
all_model_classes = (TFCTRLModel, TFCTRLLMHeadModel) if is_tf_available() else ()
|
all_model_classes = (TFCTRLModel, TFCTRLLMHeadModel) if is_tf_available() else ()
|
||||||
|
|
||||||
|
|||||||
@@ -14,10 +14,12 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from __future__ import absolute_import, division, print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
from transformers import DistilBertConfig, is_tf_available
|
from transformers import DistilBertConfig, is_tf_available
|
||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
from .test_modeling_tf_common import TFCommonTestCases, ids_tensor
|
from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
|
||||||
from .utils import require_tf
|
from .utils import require_tf
|
||||||
|
|
||||||
|
|
||||||
@@ -31,7 +33,7 @@ if is_tf_available():
|
|||||||
|
|
||||||
|
|
||||||
@require_tf
|
@require_tf
|
||||||
class TFDistilBertModelTest(TFCommonTestCases.TFCommonModelTester):
|
class TFDistilBertModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||||
|
|
||||||
all_model_classes = (
|
all_model_classes = (
|
||||||
(
|
(
|
||||||
|
|||||||
@@ -14,10 +14,12 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from __future__ import absolute_import, division, print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
from transformers import GPT2Config, is_tf_available
|
from transformers import GPT2Config, is_tf_available
|
||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
from .test_modeling_tf_common import TFCommonTestCases, ids_tensor
|
from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
|
||||||
from .utils import CACHE_DIR, require_tf, slow
|
from .utils import CACHE_DIR, require_tf, slow
|
||||||
|
|
||||||
|
|
||||||
@@ -32,7 +34,7 @@ if is_tf_available():
|
|||||||
|
|
||||||
|
|
||||||
@require_tf
|
@require_tf
|
||||||
class TFGPT2ModelTest(TFCommonTestCases.TFCommonModelTester):
|
class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||||
|
|
||||||
all_model_classes = (TFGPT2Model, TFGPT2LMHeadModel, TFGPT2DoubleHeadsModel) if is_tf_available() else ()
|
all_model_classes = (TFGPT2Model, TFGPT2LMHeadModel, TFGPT2DoubleHeadsModel) if is_tf_available() else ()
|
||||||
# all_model_classes = (TFGPT2Model, TFGPT2LMHeadModel) if is_tf_available() else ()
|
# all_model_classes = (TFGPT2Model, TFGPT2LMHeadModel) if is_tf_available() else ()
|
||||||
|
|||||||
@@ -14,10 +14,12 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from __future__ import absolute_import, division, print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
from transformers import OpenAIGPTConfig, is_tf_available
|
from transformers import OpenAIGPTConfig, is_tf_available
|
||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
from .test_modeling_tf_common import TFCommonTestCases, ids_tensor
|
from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
|
||||||
from .utils import CACHE_DIR, require_tf, slow
|
from .utils import CACHE_DIR, require_tf, slow
|
||||||
|
|
||||||
|
|
||||||
@@ -32,7 +34,7 @@ if is_tf_available():
|
|||||||
|
|
||||||
|
|
||||||
@require_tf
|
@require_tf
|
||||||
class TFOpenAIGPTModelTest(TFCommonTestCases.TFCommonModelTester):
|
class TFOpenAIGPTModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||||
|
|
||||||
all_model_classes = (
|
all_model_classes = (
|
||||||
(TFOpenAIGPTModel, TFOpenAIGPTLMHeadModel, TFOpenAIGPTDoubleHeadsModel) if is_tf_available() else ()
|
(TFOpenAIGPTModel, TFOpenAIGPTLMHeadModel, TFOpenAIGPTDoubleHeadsModel) if is_tf_available() else ()
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ import unittest
|
|||||||
from transformers import RobertaConfig, is_tf_available
|
from transformers import RobertaConfig, is_tf_available
|
||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
from .test_modeling_tf_common import TFCommonTestCases, ids_tensor
|
from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
|
||||||
from .utils import CACHE_DIR, require_tf, slow
|
from .utils import CACHE_DIR, require_tf, slow
|
||||||
|
|
||||||
|
|
||||||
@@ -36,7 +36,7 @@ if is_tf_available():
|
|||||||
|
|
||||||
|
|
||||||
@require_tf
|
@require_tf
|
||||||
class TFRobertaModelTest(TFCommonTestCases.TFCommonModelTester):
|
class TFRobertaModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||||
|
|
||||||
all_model_classes = (
|
all_model_classes = (
|
||||||
(TFRobertaModel, TFRobertaForMaskedLM, TFRobertaForSequenceClassification) if is_tf_available() else ()
|
(TFRobertaModel, TFRobertaForMaskedLM, TFRobertaForSequenceClassification) if is_tf_available() else ()
|
||||||
|
|||||||
@@ -14,10 +14,12 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from __future__ import absolute_import, division, print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
from transformers import T5Config, is_tf_available
|
from transformers import T5Config, is_tf_available
|
||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
from .test_modeling_tf_common import TFCommonTestCases, ids_tensor
|
from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
|
||||||
from .utils import CACHE_DIR, require_tf, slow
|
from .utils import CACHE_DIR, require_tf, slow
|
||||||
|
|
||||||
|
|
||||||
@@ -26,7 +28,7 @@ if is_tf_available():
|
|||||||
|
|
||||||
|
|
||||||
@require_tf
|
@require_tf
|
||||||
class TFT5ModelTest(TFCommonTestCases.TFCommonModelTester):
|
class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||||
|
|
||||||
is_encoder_decoder = True
|
is_encoder_decoder = True
|
||||||
all_model_classes = (TFT5Model, TFT5WithLMHeadModel) if is_tf_available() else ()
|
all_model_classes = (TFT5Model, TFT5WithLMHeadModel) if is_tf_available() else ()
|
||||||
|
|||||||
@@ -15,11 +15,12 @@
|
|||||||
from __future__ import absolute_import, division, print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
import random
|
import random
|
||||||
|
import unittest
|
||||||
|
|
||||||
from transformers import TransfoXLConfig, is_tf_available
|
from transformers import TransfoXLConfig, is_tf_available
|
||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
from .test_modeling_tf_common import TFCommonTestCases, ids_tensor
|
from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
|
||||||
from .utils import CACHE_DIR, require_tf, slow
|
from .utils import CACHE_DIR, require_tf, slow
|
||||||
|
|
||||||
|
|
||||||
@@ -33,7 +34,7 @@ if is_tf_available():
|
|||||||
|
|
||||||
|
|
||||||
@require_tf
|
@require_tf
|
||||||
class TFTransfoXLModelTest(TFCommonTestCases.TFCommonModelTester):
|
class TFTransfoXLModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||||
|
|
||||||
all_model_classes = (TFTransfoXLModel, TFTransfoXLLMHeadModel) if is_tf_available() else ()
|
all_model_classes = (TFTransfoXLModel, TFTransfoXLLMHeadModel) if is_tf_available() else ()
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
|
|||||||
@@ -14,10 +14,12 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from __future__ import absolute_import, division, print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
from transformers import is_tf_available
|
from transformers import is_tf_available
|
||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
from .test_modeling_tf_common import TFCommonTestCases, ids_tensor
|
from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
|
||||||
from .utils import CACHE_DIR, require_tf, slow
|
from .utils import CACHE_DIR, require_tf, slow
|
||||||
|
|
||||||
|
|
||||||
@@ -34,7 +36,7 @@ if is_tf_available():
|
|||||||
|
|
||||||
|
|
||||||
@require_tf
|
@require_tf
|
||||||
class TFXLMModelTest(TFCommonTestCases.TFCommonModelTester):
|
class TFXLMModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||||
|
|
||||||
all_model_classes = (
|
all_model_classes = (
|
||||||
(TFXLMModel, TFXLMWithLMHeadModel, TFXLMForSequenceClassification, TFXLMForQuestionAnsweringSimple)
|
(TFXLMModel, TFXLMWithLMHeadModel, TFXLMForSequenceClassification, TFXLMForQuestionAnsweringSimple)
|
||||||
|
|||||||
@@ -15,11 +15,12 @@
|
|||||||
from __future__ import absolute_import, division, print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
import random
|
import random
|
||||||
|
import unittest
|
||||||
|
|
||||||
from transformers import XLNetConfig, is_tf_available
|
from transformers import XLNetConfig, is_tf_available
|
||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
from .test_modeling_tf_common import TFCommonTestCases, ids_tensor
|
from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
|
||||||
from .utils import CACHE_DIR, require_tf, slow
|
from .utils import CACHE_DIR, require_tf, slow
|
||||||
|
|
||||||
|
|
||||||
@@ -37,7 +38,7 @@ if is_tf_available():
|
|||||||
|
|
||||||
|
|
||||||
@require_tf
|
@require_tf
|
||||||
class TFXLNetModelTest(TFCommonTestCases.TFCommonModelTester):
|
class TFXLNetModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||||
|
|
||||||
all_model_classes = (
|
all_model_classes = (
|
||||||
(
|
(
|
||||||
|
|||||||
@@ -15,11 +15,12 @@
|
|||||||
from __future__ import absolute_import, division, print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
import random
|
import random
|
||||||
|
import unittest
|
||||||
|
|
||||||
from transformers import is_torch_available
|
from transformers import is_torch_available
|
||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
from .test_modeling_common import CommonTestCases, ids_tensor
|
from .test_modeling_common import ModelTesterMixin, ids_tensor
|
||||||
from .utils import CACHE_DIR, require_torch, slow, torch_device
|
from .utils import CACHE_DIR, require_torch, slow, torch_device
|
||||||
|
|
||||||
|
|
||||||
@@ -30,7 +31,7 @@ if is_torch_available():
|
|||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class TransfoXLModelTest(CommonTestCases.CommonModelTester):
|
class TransfoXLModelTest(ModelTesterMixin, unittest.TestCase):
|
||||||
|
|
||||||
all_model_classes = (TransfoXLModel, TransfoXLLMHeadModel) if is_torch_available() else ()
|
all_model_classes = (TransfoXLModel, TransfoXLLMHeadModel) if is_torch_available() else ()
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
|
|||||||
@@ -14,10 +14,12 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from __future__ import absolute_import, division, print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
from transformers import is_torch_available
|
from transformers import is_torch_available
|
||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
from .test_modeling_common import CommonTestCases, ids_tensor
|
from .test_modeling_common import ModelTesterMixin, ids_tensor
|
||||||
from .utils import CACHE_DIR, require_torch, slow, torch_device
|
from .utils import CACHE_DIR, require_torch, slow, torch_device
|
||||||
|
|
||||||
|
|
||||||
@@ -34,7 +36,7 @@ if is_torch_available():
|
|||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class XLMModelTest(CommonTestCases.CommonModelTester):
|
class XLMModelTest(ModelTesterMixin, unittest.TestCase):
|
||||||
|
|
||||||
all_model_classes = (
|
all_model_classes = (
|
||||||
(
|
(
|
||||||
|
|||||||
@@ -15,11 +15,12 @@
|
|||||||
from __future__ import absolute_import, division, print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
import random
|
import random
|
||||||
|
import unittest
|
||||||
|
|
||||||
from transformers import is_torch_available
|
from transformers import is_torch_available
|
||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
from .test_modeling_common import CommonTestCases, ids_tensor
|
from .test_modeling_common import ModelTesterMixin, ids_tensor
|
||||||
from .utils import CACHE_DIR, require_torch, slow, torch_device
|
from .utils import CACHE_DIR, require_torch, slow, torch_device
|
||||||
|
|
||||||
|
|
||||||
@@ -38,7 +39,7 @@ if is_torch_available():
|
|||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class XLNetModelTest(CommonTestCases.CommonModelTester):
|
class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
|
||||||
|
|
||||||
all_model_classes = (
|
all_model_classes = (
|
||||||
(
|
(
|
||||||
|
|||||||
Reference in New Issue
Block a user