[TF Longformer] Improve Speed for TF Longformer (#6447)

* add tf graph compile tests

* fix conflict

* remove more tf transpose statements

* fix conflicts

* fix comment typos

* move function to class function

* fix black

* fix black

* make style
This commit is contained in:
Patrick von Platen
2020-08-26 20:55:41 +02:00
committed by GitHub
parent a75c64d80c
commit 858b7d5873
5 changed files with 258 additions and 135 deletions

View File

@@ -110,15 +110,6 @@ class TFModelTesterMixin:
def test_initialization(self):
pass
# config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
# configs_no_init = _config_zero_init(config)
# for model_class in self.all_model_classes:
# model = model_class(config=configs_no_init)
# for name, param in model.named_parameters():
# if param.requires_grad:
# self.assertIn(param.data.mean().item(), [0.0, 1.0],
# msg="Parameter {} of model {} seems not properly initialized".format(name, model_class))
def test_save_load(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -134,6 +125,19 @@ class TFModelTesterMixin:
self.assert_outputs_same(after_outputs, outputs)
def test_graph_mode(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
inputs = self._prepare_for_class(inputs_dict, model_class)
model = model_class(config)
@tf.function
def run_in_graph_mode():
return model(inputs)
outputs = run_in_graph_mode()
self.assertIsNotNone(outputs)
@slow
def test_saved_model_with_hidden_states_output(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()