From d4886173b26d772b34afe20efa029b54a0f356a0 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Tue, 7 Jul 2020 10:06:48 -0400 Subject: [PATCH] [Bart] enable test_torchscript, update test_tie_weights (#5457) * Passing all but one torchscript test * Style * move comment * remove unneeded assert --- tests/test_modeling_bart.py | 3 +-- tests/test_modeling_common.py | 5 ----- 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/tests/test_modeling_bart.py b/tests/test_modeling_bart.py index 4dbfdd1ccd..26c25806f5 100644 --- a/tests/test_modeling_bart.py +++ b/tests/test_modeling_bart.py @@ -120,7 +120,7 @@ class BARTModelTest(ModelTesterMixin, unittest.TestCase): is_encoder_decoder = True # TODO(SS): fix the below in a separate PR test_pruning = False - test_torchscript = False + test_torchscript = True test_head_masking = False test_resize_embeddings = True # This requires inputs_dict['input_ids'] test_missing_keys = False # because BartForConditionalGeneration and BartModel now have identical state_dict @@ -133,7 +133,6 @@ class BARTModelTest(ModelTesterMixin, unittest.TestCase): self.config_tester.run_common_tests() def test_initialization_more(self): - # (config, input_ids, token_type_ids, input_mask, *unused) = \ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() model = BartModel(config) model.to(torch_device) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index c9b8eecc27..7be7ae87d5 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -612,15 +612,11 @@ class ModelTesterMixin: if model_not_tied.get_output_embeddings() is None: continue - params_not_tied = list(model_not_tied.parameters()) - config_tied = copy.deepcopy(config) config_tied.torchscript = False model_tied = model_class(config_tied) params_tied = list(model_tied.parameters()) - # Check that the embedding layer and decoding layer are the same in size and in value - self.assertGreater(len(params_not_tied), len(params_tied)) # self.assertTrue(check_same_values(embeddings, decoding)) # # Check that after modification, they remain the same. @@ -638,7 +634,6 @@ class ModelTesterMixin: # Check that after resize they remain tied. model_tied.resize_token_embeddings(config.vocab_size + 10) params_tied_2 = list(model_tied.parameters()) - self.assertGreater(len(params_not_tied), len(params_tied)) self.assertEqual(len(params_tied_2), len(params_tied)) # decoding.weight.data.mul_(20)