[multiple models] skip saving/loading deterministic state_dict keys (#7878)

* make the save_load special key tests common

* handle mbart

* cleaner solution

* fix

* move test_save_load_missing_keys back into fstm for now

* restore

* style

* add marian

* add pegasus

* blenderbot

* revert - no static embed
This commit is contained in:
Stas Bekman
2020-10-21 05:06:07 -07:00
committed by GitHub
parent 006a16483f
commit 57516c0cc8
8 changed files with 144 additions and 26 deletions

View File

@@ -21,6 +21,8 @@ from transformers.file_utils import cached_property
from transformers.hf_api import HfApi
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
from .test_modeling_common import ModelTesterMixin
if is_torch_available():
import torch
@@ -35,6 +37,37 @@ if is_torch_available():
from transformers.pipelines import TranslationPipeline
@require_torch
class ModelTester:
def __init__(self, parent):
self.config = MarianConfig(
vocab_size=99,
d_model=24,
encoder_layers=2,
decoder_layers=2,
encoder_attention_heads=2,
decoder_attention_heads=2,
encoder_ffn_dim=32,
decoder_ffn_dim=32,
max_position_embeddings=48,
add_final_layer_norm=True,
return_dict=True,
)
def prepare_config_and_inputs_for_common(self):
return self.config, {}
@require_torch
class SelectiveCommonTest(unittest.TestCase):
all_model_classes = (MarianMTModel,) if is_torch_available() else ()
test_save_load_keys_to_never_save = ModelTesterMixin.test_save_load_keys_to_never_save
def setUp(self):
self.model_tester = ModelTester(self)
class ModelManagementTests(unittest.TestCase):
@slow
def test_model_names(self):