[TF Longformer] Improve Speed for TF Longformer (#6447)
* add tf graph compile tests * fix conflict * remove more tf transpose statements * fix conflicts * fix comment typos * move function to class function * fix black * fix black * make style
This commit is contained in:
committed by
GitHub
parent
a75c64d80c
commit
858b7d5873
@@ -110,15 +110,6 @@ class TFModelTesterMixin:
|
||||
|
||||
def test_initialization(self):
|
||||
pass
|
||||
# config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
# configs_no_init = _config_zero_init(config)
|
||||
# for model_class in self.all_model_classes:
|
||||
# model = model_class(config=configs_no_init)
|
||||
# for name, param in model.named_parameters():
|
||||
# if param.requires_grad:
|
||||
# self.assertIn(param.data.mean().item(), [0.0, 1.0],
|
||||
# msg="Parameter {} of model {} seems not properly initialized".format(name, model_class))
|
||||
|
||||
def test_save_load(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
@@ -134,6 +125,19 @@ class TFModelTesterMixin:
|
||||
|
||||
self.assert_outputs_same(after_outputs, outputs)
|
||||
|
||||
def test_graph_mode(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
for model_class in self.all_model_classes:
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
model = model_class(config)
|
||||
|
||||
@tf.function
|
||||
def run_in_graph_mode():
|
||||
return model(inputs)
|
||||
|
||||
outputs = run_in_graph_mode()
|
||||
self.assertIsNotNone(outputs)
|
||||
|
||||
@slow
|
||||
def test_saved_model_with_hidden_states_output(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
Reference in New Issue
Block a user