Expose get_config() on ModelTesters (#12812)

* Expose get_config() on ModelTesters

* Typo
This commit is contained in:
Lysandre Debut
2021-07-21 10:13:11 +02:00
committed by GitHub
parent cabcc75171
commit c3d9ac7607
53 changed files with 1249 additions and 1193 deletions

View File

@@ -19,7 +19,7 @@ import unittest
import timeout_decorator # noqa
from parameterized import parameterized
from transformers import is_torch_available
from transformers import FSMTConfig, is_torch_available
from transformers.file_utils import cached_property
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
@@ -32,7 +32,7 @@ if is_torch_available():
import torch
from torch import nn
from transformers import FSMTConfig, FSMTForConditionalGeneration, FSMTModel, FSMTTokenizer
from transformers import FSMTForConditionalGeneration, FSMTModel, FSMTTokenizer
from transformers.models.fsmt.modeling_fsmt import (
SinusoidalPositionalEmbedding,
_prepare_fsmt_decoder_inputs,
@@ -42,8 +42,7 @@ if is_torch_available():
from transformers.pipelines import TranslationPipeline
@require_torch
class ModelTester:
class FSMTModelTester:
def __init__(
self,
parent,
@@ -78,7 +77,12 @@ class ModelTester:
)
input_ids[:, -1] = 2 # Eos Token
config = FSMTConfig(
config = self.get_config()
inputs_dict = prepare_fsmt_inputs_dict(config, input_ids)
return config, inputs_dict
def get_config(self):
return FSMTConfig(
vocab_size=self.src_vocab_size, # hack needed for common tests
src_vocab_size=self.src_vocab_size,
tgt_vocab_size=self.tgt_vocab_size,
@@ -97,8 +101,6 @@ class ModelTester:
bos_token_id=self.bos_token_id,
pad_token_id=self.pad_token_id,
)
inputs_dict = prepare_fsmt_inputs_dict(config, input_ids)
return config, inputs_dict
def prepare_config_and_inputs_for_common(self):
config, inputs_dict = self.prepare_config_and_inputs()
@@ -141,7 +143,7 @@ class FSMTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
test_missing_keys = False
def setUp(self):
self.model_tester = ModelTester(self)
self.model_tester = FSMTModelTester(self)
self.langs = ["en", "ru"]
config = {
"langs": self.langs,