Deberta tf (#12972)
* TFDeberta moved weights to build and fixed name scope added missing , bug fixes to enable graph mode execution updated setup.py fixing typo fix imports embedding mask fix added layer names avoid autmatic incremental names +XSoftmax cleanup added names to layer disable keras_serializable Distangled attention output shape hidden_size==None using symbolic inputs test for Deberta tf make style Update src/transformers/models/deberta/modeling_tf_deberta.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Update src/transformers/models/deberta/modeling_tf_deberta.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Update src/transformers/models/deberta/modeling_tf_deberta.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Update src/transformers/models/deberta/modeling_tf_deberta.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Update src/transformers/models/deberta/modeling_tf_deberta.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Update src/transformers/models/deberta/modeling_tf_deberta.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Update src/transformers/models/deberta/modeling_tf_deberta.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> removed tensorflow-probability removed blank line * removed tf experimental api +torch_gather tf implementation from @Rocketknight1 * layername DeBERTa --> deberta * copyright fix * added docs for TFDeberta & make style * layer_name change to fix load from pt model * layer_name change as pt model * SequenceClassification layername change, to same as pt model * switched to keras built-in LayerNormalization * added `TFDeberta` prefix most layer classes * updated to tf.Tensor in the docstring
This commit is contained in:
@@ -343,7 +343,7 @@ Flax), PyTorch, and/or TensorFlow.
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| DPR | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| DeBERTa | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| DeBERTa | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| DeBERTa-v2 | ✅ | ❌ | ✅ | ❌ | ❌ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
|
||||
@@ -38,7 +38,8 @@ the training data performs consistently better on a wide range of NLP tasks, ach
|
||||
pre-trained models will be made publicly available at https://github.com/microsoft/DeBERTa.*
|
||||
|
||||
|
||||
This model was contributed by `DeBERTa <https://huggingface.co/DeBERTa>`__. The original code can be found `here
|
||||
This model was contributed by `DeBERTa <https://huggingface.co/DeBERTa>`__. The TF 2.0 implementation of this model was
|
||||
contributed by `kamalkraj <https://huggingface.co/kamalkraj>`__. The original code can be found `here
|
||||
<https://github.com/microsoft/DeBERTa>`__.
|
||||
|
||||
|
||||
@@ -103,3 +104,45 @@ DebertaForQuestionAnswering
|
||||
|
||||
.. autoclass:: transformers.DebertaForQuestionAnswering
|
||||
:members: forward
|
||||
|
||||
|
||||
TFDebertaModel
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.TFDebertaModel
|
||||
:members: call
|
||||
|
||||
|
||||
TFDebertaPreTrainedModel
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.TFDebertaPreTrainedModel
|
||||
:members: call
|
||||
|
||||
|
||||
TFDebertaForMaskedLM
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.TFDebertaForMaskedLM
|
||||
:members: call
|
||||
|
||||
|
||||
TFDebertaForSequenceClassification
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.TFDebertaForSequenceClassification
|
||||
:members: call
|
||||
|
||||
|
||||
TFDebertaForTokenClassification
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.TFDebertaForTokenClassification
|
||||
:members: call
|
||||
|
||||
|
||||
TFDebertaForQuestionAnswering
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.TFDebertaForQuestionAnswering
|
||||
:members: call
|
||||
|
||||
@@ -1297,6 +1297,17 @@ if is_tf_available():
|
||||
"TFCTRLPreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.deberta"].extend(
|
||||
[
|
||||
"TF_DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
"TFDebertaForMaskedLM",
|
||||
"TFDebertaForQuestionAnswering",
|
||||
"TFDebertaForSequenceClassification",
|
||||
"TFDebertaForTokenClassification",
|
||||
"TFDebertaModel",
|
||||
"TFDebertaPreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.distilbert"].extend(
|
||||
[
|
||||
"TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
@@ -2820,6 +2831,15 @@ if TYPE_CHECKING:
|
||||
TFCTRLModel,
|
||||
TFCTRLPreTrainedModel,
|
||||
)
|
||||
from .models.deberta import (
|
||||
TF_DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
TFDebertaForMaskedLM,
|
||||
TFDebertaForQuestionAnswering,
|
||||
TFDebertaForSequenceClassification,
|
||||
TFDebertaForTokenClassification,
|
||||
TFDebertaModel,
|
||||
TFDebertaPreTrainedModel,
|
||||
)
|
||||
from .models.distilbert import (
|
||||
TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
TFDistilBertForMaskedLM,
|
||||
|
||||
@@ -29,6 +29,7 @@ logger = logging.get_logger(__name__)
|
||||
TF_MODEL_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
# Base model mapping
|
||||
("deberta", "TFDebertaModel"),
|
||||
("rembert", "TFRemBertModel"),
|
||||
("roformer", "TFRoFormerModel"),
|
||||
("convbert", "TFConvBertModel"),
|
||||
@@ -144,6 +145,7 @@ TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
|
||||
TF_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
# Model for Masked LM mapping
|
||||
("deberta", "TFDebertaForMaskedLM"),
|
||||
("rembert", "TFRemBertForMaskedLM"),
|
||||
("roformer", "TFRoFormerForMaskedLM"),
|
||||
("convbert", "TFConvBertForMaskedLM"),
|
||||
@@ -183,6 +185,7 @@ TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
|
||||
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
# Model for Sequence Classification mapping
|
||||
("deberta", "TFDebertaForSequenceClassification"),
|
||||
("rembert", "TFRemBertForSequenceClassification"),
|
||||
("roformer", "TFRoFormerForSequenceClassification"),
|
||||
("convbert", "TFConvBertForSequenceClassification"),
|
||||
@@ -211,6 +214,7 @@ TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
||||
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
# Model for Question Answering mapping
|
||||
("deberta", "TFDebertaForQuestionAnswering"),
|
||||
("rembert", "TFRemBertForQuestionAnswering"),
|
||||
("roformer", "TFRoFormerForQuestionAnswering"),
|
||||
("convbert", "TFConvBertForQuestionAnswering"),
|
||||
@@ -234,6 +238,7 @@ TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
|
||||
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
# Model for Token Classification mapping
|
||||
("deberta", "TFDebertaForTokenClassification"),
|
||||
("rembert", "TFRemBertForTokenClassification"),
|
||||
("roformer", "TFRoFormerForTokenClassification"),
|
||||
("convbert", "TFConvBertForTokenClassification"),
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...file_utils import _LazyModule, is_tokenizers_available, is_torch_available
|
||||
from ...file_utils import _LazyModule, is_tf_available, is_tokenizers_available, is_torch_available
|
||||
|
||||
|
||||
_import_structure = {
|
||||
@@ -40,6 +40,17 @@ if is_torch_available():
|
||||
"DebertaPreTrainedModel",
|
||||
]
|
||||
|
||||
if is_tf_available():
|
||||
_import_structure["modeling_tf_deberta"] = [
|
||||
"TF_DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
"TFDebertaForMaskedLM",
|
||||
"TFDebertaForQuestionAnswering",
|
||||
"TFDebertaForSequenceClassification",
|
||||
"TFDebertaForTokenClassification",
|
||||
"TFDebertaModel",
|
||||
"TFDebertaPreTrainedModel",
|
||||
]
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_deberta import DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, DebertaConfig
|
||||
@@ -59,6 +70,18 @@ if TYPE_CHECKING:
|
||||
DebertaPreTrainedModel,
|
||||
)
|
||||
|
||||
if is_tf_available():
|
||||
from .modeling_tf_deberta import (
|
||||
TF_DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
TFDebertaForMaskedLM,
|
||||
TFDebertaForQuestionAnswering,
|
||||
TFDebertaForSequenceClassification,
|
||||
TFDebertaForTokenClassification,
|
||||
TFDebertaModel,
|
||||
TFDebertaPreTrainedModel,
|
||||
)
|
||||
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
|
||||
1555
src/transformers/models/deberta/modeling_tf_deberta.py
Normal file
1555
src/transformers/models/deberta/modeling_tf_deberta.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -634,6 +634,63 @@ class TFCTRLPreTrainedModel:
|
||||
requires_backends(cls, ["tf"])
|
||||
|
||||
|
||||
TF_DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||
|
||||
|
||||
class TFDebertaForMaskedLM:
    """Stand-in for the TensorFlow DeBERTa masked-LM model.

    Both the constructor and ``from_pretrained`` hand control straight to
    ``requires_backends`` with the ``"tf"`` backend (presumably so that a clear
    error is raised when TensorFlow is missing -- confirm against
    ``requires_backends``).
    """

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["tf"])
|
||||
|
||||
|
||||
class TFDebertaForQuestionAnswering:
    """Stand-in for the TensorFlow DeBERTa question-answering model.

    Both the constructor and ``from_pretrained`` hand control straight to
    ``requires_backends`` with the ``"tf"`` backend (presumably so that a clear
    error is raised when TensorFlow is missing -- confirm against
    ``requires_backends``).
    """

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["tf"])
|
||||
|
||||
|
||||
class TFDebertaForSequenceClassification:
    """Stand-in for the TensorFlow DeBERTa sequence-classification model.

    Both the constructor and ``from_pretrained`` hand control straight to
    ``requires_backends`` with the ``"tf"`` backend (presumably so that a clear
    error is raised when TensorFlow is missing -- confirm against
    ``requires_backends``).
    """

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["tf"])
|
||||
|
||||
|
||||
class TFDebertaForTokenClassification:
    """Stand-in for the TensorFlow DeBERTa token-classification model.

    Both the constructor and ``from_pretrained`` hand control straight to
    ``requires_backends`` with the ``"tf"`` backend (presumably so that a clear
    error is raised when TensorFlow is missing -- confirm against
    ``requires_backends``).
    """

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["tf"])
|
||||
|
||||
|
||||
class TFDebertaModel:
    """Stand-in for the bare TensorFlow DeBERTa model.

    Both the constructor and ``from_pretrained`` hand control straight to
    ``requires_backends`` with the ``"tf"`` backend (presumably so that a clear
    error is raised when TensorFlow is missing -- confirm against
    ``requires_backends``).
    """

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["tf"])
|
||||
|
||||
|
||||
class TFDebertaPreTrainedModel:
    """Stand-in for the TensorFlow DeBERTa pre-trained model base class.

    Both the constructor and ``from_pretrained`` hand control straight to
    ``requires_backends`` with the ``"tf"`` backend (presumably so that a clear
    error is raised when TensorFlow is missing -- confirm against
    ``requires_backends``).
    """

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["tf"])
|
||||
|
||||
|
||||
TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||
|
||||
|
||||
|
||||
282
tests/test_modeling_tf_deberta.py
Normal file
282
tests/test_modeling_tf_deberta.py
Normal file
@@ -0,0 +1,282 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers import DebertaConfig, is_tf_available
|
||||
from transformers.testing_utils import require_tf, slow
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
|
||||
|
||||
|
||||
if is_tf_available():
|
||||
import tensorflow as tf
|
||||
|
||||
from transformers import (
|
||||
TFDebertaForMaskedLM,
|
||||
TFDebertaForQuestionAnswering,
|
||||
TFDebertaForSequenceClassification,
|
||||
TFDebertaForTokenClassification,
|
||||
TFDebertaModel,
|
||||
)
|
||||
|
||||
|
||||
class TFDebertaModelTester:
    """Helper that builds a small ``DebertaConfig`` plus matching inputs and checks output shapes.

    The owning test case is passed in as ``parent``; every ``create_and_check_*``
    method reports through ``parent.assertEqual``.
    """

    def __init__(
        self,
        parent,
        batch_size=13,
        seq_length=7,
        is_training=True,
        use_input_mask=True,
        use_token_type_ids=True,
        use_labels=True,
        vocab_size=99,
        hidden_size=32,
        num_hidden_layers=5,
        num_attention_heads=4,
        intermediate_size=37,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=16,
        type_sequence_label_size=2,
        initializer_range=0.02,
        num_labels=3,
        num_choices=4,
        scope=None,
    ):
        """Store the test dimensions and flags.

        Every keyword argument is honoured; the defaults reproduce the values
        that were previously hard-coded, so ``TFDebertaModelTester(parent)``
        behaves as before.
        """
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_input_mask = use_input_mask
        self.use_token_type_ids = use_token_type_ids
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        # The next three settings are not constructor arguments; they are fixed for all tests.
        self.relative_attention = False
        self.max_relative_positions = -1
        self.position_biased_input = True
        self.type_sequence_label_size = type_sequence_label_size
        self.initializer_range = initializer_range
        self.num_labels = num_labels
        self.num_choices = num_choices
        self.scope = scope

    def prepare_config_and_inputs(self):
        """Return ``(config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels)``.

        Optional tensors are ``None`` when the corresponding ``use_*`` flag is
        off. ``choice_labels`` is always ``None``: no multiple-choice head is
        tested here.
        """
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        input_mask = None
        if self.use_input_mask:
            # vocab_size=2 restricts the mask values to {0, 1}.
            input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)

        token_type_ids = None
        if self.use_token_type_ids:
            token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)

        sequence_labels = None
        token_labels = None
        choice_labels = None
        if self.use_labels:
            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)

        config = DebertaConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            hidden_act=self.hidden_act,
            hidden_dropout_prob=self.hidden_dropout_prob,
            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
            type_vocab_size=self.type_vocab_size,
            relative_attention=self.relative_attention,
            max_relative_positions=self.max_relative_positions,
            position_biased_input=self.position_biased_input,
            initializer_range=self.initializer_range,
            return_dict=True,
        )

        return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels

    def create_and_check_model(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        """Run ``TFDebertaModel`` with dict, list and bare-tensor inputs and check the hidden-state shape."""
        model = TFDebertaModel(config=config)

        # Dict input: previously this dict was built and then overwritten without being used.
        inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
        result = model(inputs)

        # Positional list input.
        inputs = [input_ids, input_mask]
        result = model(inputs)

        # Bare tensor input.
        result = model(input_ids)

        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))

    def create_and_check_for_masked_lm(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        """Check that the masked-LM logits have shape ``(batch_size, seq_length, vocab_size)``."""
        model = TFDebertaForMaskedLM(config=config)
        inputs = {
            "input_ids": input_ids,
            "attention_mask": input_mask,
            "token_type_ids": token_type_ids,
        }
        result = model(inputs)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

    def create_and_check_for_sequence_classification(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        """Check that the sequence-classification logits have shape ``(batch_size, num_labels)``."""
        config.num_labels = self.num_labels
        model = TFDebertaForSequenceClassification(config=config)
        inputs = {
            "input_ids": input_ids,
            "attention_mask": input_mask,
            "token_type_ids": token_type_ids,
        }

        result = model(inputs)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))

    def create_and_check_for_token_classification(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        """Check that the token-classification logits have shape ``(batch_size, seq_length, num_labels)``."""
        config.num_labels = self.num_labels
        model = TFDebertaForTokenClassification(config=config)
        inputs = {
            "input_ids": input_ids,
            "attention_mask": input_mask,
            "token_type_ids": token_type_ids,
        }
        result = model(inputs)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))

    def create_and_check_for_question_answering(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        """Check that start and end logits both have shape ``(batch_size, seq_length)``."""
        model = TFDebertaForQuestionAnswering(config=config)
        inputs = {
            "input_ids": input_ids,
            "attention_mask": input_mask,
            "token_type_ids": token_type_ids,
        }

        result = model(inputs)
        self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
        self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))

    def prepare_config_and_inputs_for_common(self):
        """Return ``(config, inputs_dict)`` where the dict holds ids, token types and attention mask."""
        config_and_inputs = self.prepare_config_and_inputs()
        (
            config,
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
        ) = config_and_inputs
        inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask}
        return config, inputs_dict
|
||||
|
||||
|
||||
@require_tf
class TFDebertaModelTest(TFModelTesterMixin, unittest.TestCase):
    """Test case for the TF DeBERTa model classes, built on ``TFModelTesterMixin``.

    Each ``test_*`` method asks the model tester for a fresh config and inputs
    and forwards them to the matching ``create_and_check_*`` helper.
    """

    # Left empty when TensorFlow is unavailable, because the model classes are
    # only imported under ``is_tf_available()``.
    all_model_classes = ()
    if is_tf_available():
        all_model_classes = (
            TFDebertaModel,
            TFDebertaForMaskedLM,
            TFDebertaForQuestionAnswering,
            TFDebertaForSequenceClassification,
            TFDebertaForTokenClassification,
        )

    test_head_masking = False
    test_onnx = False

    def setUp(self):
        self.model_tester = TFDebertaModelTester(self)
        self.config_tester = ConfigTester(self, config_class=DebertaConfig, hidden_size=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_model(self):
        prepared = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*prepared)

    def test_for_masked_lm(self):
        prepared = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_masked_lm(*prepared)

    def test_for_question_answering(self):
        prepared = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_question_answering(*prepared)

    def test_for_sequence_classification(self):
        prepared = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_sequence_classification(*prepared)

    def test_for_token_classification(self):
        prepared = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_token_classification(*prepared)

    @slow
    def test_model_from_pretrained(self):
        # Only checks that loading the checkpoint yields an object.
        loaded = TFDebertaModel.from_pretrained("kamalkraj/deberta-base")
        self.assertIsNotNone(loaded)
|
||||
|
||||
|
||||
@require_tf
class TFDeBERTaModelIntegrationTest(unittest.TestCase):
    """Integration checks that run the pretrained ``kamalkraj/deberta-base`` checkpoint."""

    @unittest.skip(reason="Model not available yet")
    def test_inference_masked_lm(self):
        pass

    @slow
    def test_inference_no_head(self):
        """Compare a 3x3 slice of the first model output against stored reference values."""
        model = TFDebertaModel.from_pretrained("kamalkraj/deberta-base")

        token_ids = tf.constant([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
        mask = tf.constant([[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
        # Element 0 of the output is presumably the last hidden state -- confirm against the model.
        first_output = model(token_ids, attention_mask=mask)[0]

        reference = tf.constant(
            [
                [
                    [-0.59855896, -0.80552566, -0.8462135],
                    [1.4484025, -0.93483794, -0.80593085],
                    [0.3122741, 0.00316059, -1.4131377],
                ]
            ]
        )
        tf.debugging.assert_near(first_output[:, 1:4, 1:4], reference, atol=1e-4)
|
||||
Reference in New Issue
Block a user