[multiple models] skip saving/loading deterministic state_dict keys (#7878)

* make the save_load special key tests common

* handle mbart

* cleaner solution

* fix

* move test_save_load_missing_keys back into fstm for now

* restore

* style

* add marian

* add pegasus

* blenderbot

* revert - no static embed
This commit is contained in:
Stas Bekman
2020-10-21 05:06:07 -07:00
committed by GitHub
parent 006a16483f
commit 57516c0cc8
8 changed files with 144 additions and 26 deletions

View File

@@ -5,6 +5,7 @@ from transformers.file_utils import cached_property
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
from .test_modeling_bart import TOLERANCE, _long_tensor, assert_tensors_close
from .test_modeling_common import ModelTesterMixin
if is_torch_available():
@@ -23,6 +24,37 @@ EN_CODE = 250004
RO_CODE = 250020
@require_torch
class ModelTester:
def __init__(self, parent):
self.config = MBartConfig(
vocab_size=99,
d_model=24,
encoder_layers=2,
decoder_layers=2,
encoder_attention_heads=2,
decoder_attention_heads=2,
encoder_ffn_dim=32,
decoder_ffn_dim=32,
max_position_embeddings=48,
add_final_layer_norm=True,
return_dict=True,
)
def prepare_config_and_inputs_for_common(self):
return self.config, {}
@require_torch
class SelectiveCommonTest(unittest.TestCase):
all_model_classes = (MBartForConditionalGeneration,) if is_torch_available() else ()
test_save_load_keys_to_never_save = ModelTesterMixin.test_save_load_keys_to_never_save
def setUp(self):
self.model_tester = ModelTester(self)
@require_torch
@require_sentencepiece
@require_tokenizers