adding more tests on TF and pytorch serialization - updating configuration for better serialization
This commit is contained in:
@@ -17,8 +17,10 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import copy
|
||||
import sys
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import json
|
||||
import random
|
||||
import uuid
|
||||
@@ -31,6 +33,7 @@ from transformers import is_torch_available
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from transformers import (PretrainedConfig, PreTrainedModel,
|
||||
BertModel, BertConfig, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
@@ -38,6 +41,20 @@ if is_torch_available():
|
||||
else:
|
||||
pytestmark = pytest.mark.skip("Require Torch")
|
||||
|
||||
if sys.version_info[0] == 2:
|
||||
import cPickle as pickle
|
||||
|
||||
class TemporaryDirectory(object):
|
||||
"""Context manager for tempfile.mkdtemp() so it's usable with "with" statement."""
|
||||
def __enter__(self):
|
||||
self.name = tempfile.mkdtemp()
|
||||
return self.name
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
shutil.rmtree(self.name)
|
||||
else:
|
||||
import pickle
|
||||
TemporaryDirectory = tempfile.TemporaryDirectory
|
||||
unicode = str
|
||||
|
||||
def _config_zero_init(config):
|
||||
configs_no_init = copy.deepcopy(config)
|
||||
@@ -57,6 +74,23 @@ class CommonTestCases:
|
||||
test_resize_embeddings = True
|
||||
test_head_masking = True
|
||||
|
||||
def test_save_load(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs_dict)
|
||||
|
||||
with TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
model = model_class.from_pretrained(tmpdirname)
|
||||
with torch.no_grad():
|
||||
after_outputs = model(**inputs_dict)
|
||||
max_diff = np.amax(np.abs(after_outputs[0].numpy() - outputs[0].numpy()))
|
||||
self.assertLessEqual(max_diff, 1e-5)
|
||||
|
||||
def test_initialization(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user