fix conflicts

This commit is contained in:
Patrick von Platen
2020-03-06 11:28:10 +01:00
parent d6de6423ba
commit d8e2b3c547
4 changed files with 178 additions and 76 deletions

View File

@@ -54,13 +54,13 @@ class ModelTesterMixin:
model_tester = None
all_model_classes = ()
all_generative_model_classes = ()
test_torchscript = True
test_pruning = True
test_resize_embeddings = True
test_head_masking = True
_A_test_torchscript = True
_A_test_pruning = True
_A_test_resize_embeddings = True
_A_test_head_masking = True
is_encoder_decoder = False
def test_save_load(self):
def _A_test_save_load(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
@@ -85,7 +85,7 @@ class ModelTesterMixin:
max_diff = np.amax(np.abs(out_1 - out_2))
self.assertLessEqual(max_diff, 1e-5)
def test_initialization(self):
def _A_test_initialization(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
configs_no_init = _config_zero_init(config)
@@ -99,7 +99,7 @@ class ModelTesterMixin:
msg="Parameter {} of model {} seems not properly initialized".format(name, model_class),
)
def test_determinism(self):
def _A_test_determinism(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
@@ -116,7 +116,7 @@ class ModelTesterMixin:
max_diff = np.amax(np.abs(out_1 - out_2))
self.assertLessEqual(max_diff, 1e-5)
def test_attention_outputs(self):
def _A_test_attention_outputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
seq_len = getattr(self.model_tester, "seq_length", None)
decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_len)
@@ -179,25 +179,25 @@ class ModelTesterMixin:
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
)
def test_torchscript(self):
def _A_test_torchscript(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
self._create_and_check_torchscript(config, inputs_dict)
def test_torchscript_output_attentions(self):
def _A_test_torchscript_output_attentions(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.output_attentions = True
self._create_and_check_torchscript(config, inputs_dict)
def test_torchscript_output_hidden_state(self):
def _A_test_torchscript_output_hidden_state(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.output_hidden_states = True
self._create_and_check_torchscript(config, inputs_dict)
def _create_and_check_torchscript(self, config, inputs_dict):
if not self.test_torchscript:
if not self._A_test_torchscript:
return
configs_no_init = _config_zero_init(config) # To be sure we have no Nan
@@ -245,8 +245,8 @@ class ModelTesterMixin:
self.assertTrue(models_equal)
def test_headmasking(self):
if not self.test_head_masking:
def _A_test_headmasking(self):
if not self._A_test_head_masking:
return
global_rng.seed(42)
@@ -299,8 +299,8 @@ class ModelTesterMixin:
self.assertAlmostEqual(attentions[-1][..., -2, :, :].flatten().sum().item(), 0.0)
self.assertNotEqual(attentions[-1][..., -1, :, :].flatten().sum().item(), 0.0)
def test_head_pruning(self):
if not self.test_pruning:
def _A_test_head_pruning(self):
if not self._A_test_pruning:
return
for model_class in self.all_model_classes:
@@ -328,8 +328,8 @@ class ModelTesterMixin:
self.assertEqual(attentions[1].shape[-3], self.model_tester.num_attention_heads)
self.assertEqual(attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1)
def test_head_pruning_save_load_from_pretrained(self):
if not self.test_pruning:
def _A_test_head_pruning_save_load_from_pretrained(self):
if not self._A_test_pruning:
return
for model_class in self.all_model_classes:
@@ -361,8 +361,8 @@ class ModelTesterMixin:
self.assertEqual(attentions[1].shape[-3], self.model_tester.num_attention_heads)
self.assertEqual(attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1)
def test_head_pruning_save_load_from_config_init(self):
if not self.test_pruning:
def _A_test_head_pruning_save_load_from_config_init(self):
if not self._A_test_pruning:
return
for model_class in self.all_model_classes:
@@ -392,8 +392,8 @@ class ModelTesterMixin:
self.assertEqual(attentions[1].shape[-3], self.model_tester.num_attention_heads)
self.assertEqual(attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1)
def test_head_pruning_integration(self):
if not self.test_pruning:
def _A_test_head_pruning_integration(self):
if not self._A_test_pruning:
return
for model_class in self.all_model_classes:
@@ -449,7 +449,7 @@ class ModelTesterMixin:
self.assertDictEqual(model.config.pruned_heads, {0: [0], 1: [1, 2], 2: [1, 2]})
def test_hidden_states_output(self):
def _A_test_hidden_states_output(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
@@ -474,9 +474,9 @@ class ModelTesterMixin:
],
)
def test_resize_tokens_embeddings(self):
def _A_test_resize_tokens_embeddings(self):
(original_config, inputs_dict,) = self.model_tester.prepare_config_and_inputs_for_common()
if not self.test_resize_embeddings:
if not self._A_test_resize_embeddings:
return
for model_class in self.all_model_classes:
@@ -516,7 +516,7 @@ class ModelTesterMixin:
self.assertTrue(models_equal)
def test_model_common_attributes(self):
def _A_test_model_common_attributes(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
@@ -594,7 +594,7 @@ class ModelTesterMixin:
# self.assertTrue(model.transformer.wte.weight.shape, model.lm_head.weight.shape)
# self.assertTrue(check_same_values(model.transformer.wte, model.lm_head))
def test_inputs_embeds(self):
def _A_test_inputs_embeds(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
if not self.is_encoder_decoder:
@@ -621,7 +621,7 @@ class ModelTesterMixin:
with torch.no_grad():
model(**inputs_dict)
def test_lm_head_model_random_generate(self):
def _A_test_lm_head_model_random_generate(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
input_ids = inputs_dict.get(
@@ -711,7 +711,7 @@ def floats_tensor(shape, scale=1.0, rng=None, name=None):
@require_torch
class ModelUtilsTest(unittest.TestCase):
@slow
def test_model_from_pretrained(self):
def _A_test_model_from_pretrained(self):
logging.basicConfig(level=logging.INFO)
for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
config = BertConfig.from_pretrained(model_name)
@@ -736,7 +736,7 @@ class ModelUtilsTest(unittest.TestCase):
class UtilsFunctionsTest(unittest.TestCase):
# tests whether the top_k_top_p function behaves as expected
def test_top_k_top_p_filtering(self):
def _A_test_top_k_top_p_filtering(self):
logits = torch.tensor(
[
[