Expose get_config() on ModelTesters (#12812)

* Expose get_config() on ModelTesters

* Typo
This commit is contained in:
Lysandre Debut
2021-07-21 10:13:11 +02:00
committed by GitHub
parent cabcc75171
commit c3d9ac7607
53 changed files with 1249 additions and 1193 deletions

View File

@@ -13,10 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch LUKE model. """
import unittest
from transformers import is_torch_available
from transformers import LukeConfig, is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
@@ -27,7 +26,6 @@ if is_torch_available():
import torch
from transformers import (
LukeConfig,
LukeForEntityClassification,
LukeForEntityPairClassification,
LukeForEntitySpanClassification,
@@ -154,7 +152,25 @@ class LukeModelTester:
[self.batch_size, self.entity_length], self.num_entity_span_classification_labels
)
config = LukeConfig(
config = self.get_config()
return (
config,
input_ids,
attention_mask,
token_type_ids,
entity_ids,
entity_attention_mask,
entity_token_type_ids,
entity_position_ids,
sequence_labels,
entity_classification_labels,
entity_pair_classification_labels,
entity_span_classification_labels,
)
def get_config(self):
return LukeConfig(
vocab_size=self.vocab_size,
entity_vocab_size=self.entity_vocab_size,
entity_emb_size=self.entity_emb_size,
@@ -172,21 +188,6 @@ class LukeModelTester:
use_entity_aware_attention=self.use_entity_aware_attention,
)
return (
config,
input_ids,
attention_mask,
token_type_ids,
entity_ids,
entity_attention_mask,
entity_token_type_ids,
entity_position_ids,
sequence_labels,
entity_classification_labels,
entity_pair_classification_labels,
entity_span_classification_labels,
)
def create_and_check_model(
self,
config,