Expose get_config() on ModelTesters (#12812)

* Expose get_config() on ModelTesters

* Typo
This commit is contained in:
Lysandre Debut
2021-07-21 10:13:11 +02:00
committed by GitHub
parent cabcc75171
commit c3d9ac7607
53 changed files with 1249 additions and 1193 deletions

View File

@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import unittest
@@ -29,6 +28,7 @@ from transformers import (
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
TapasConfig,
is_torch_available,
)
from transformers.file_utils import cached_property
@@ -43,7 +43,6 @@ if is_torch_available():
import torch
from transformers import (
TapasConfig,
TapasForMaskedLM,
TapasForQuestionAnswering,
TapasForSequenceClassification,
@@ -183,7 +182,24 @@ class TapasModelTester:
float_answer = floats_tensor([self.batch_size]).to(torch_device)
aggregation_labels = ids_tensor([self.batch_size], self.num_aggregation_labels).to(torch_device)
config = TapasConfig(
config = self.get_config()
return (
config,
input_ids,
input_mask,
token_type_ids,
sequence_labels,
token_labels,
labels,
numeric_values,
numeric_values_scale,
float_answer,
aggregation_labels,
)
def get_config(self):
return TapasConfig(
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
@@ -220,20 +236,6 @@ class TapasModelTester:
disable_per_token_loss=self.disable_per_token_loss,
)
return (
config,
input_ids,
input_mask,
token_type_ids,
sequence_labels,
token_labels,
labels,
numeric_values,
numeric_values_scale,
float_answer,
aggregation_labels,
)
def create_and_check_model(
self,
config,