[TF Longformer] Improve Speed for TF Longformer (#6447)
* add tf graph compile tests
* fix conflict
* remove more tf transpose statements
* fix conflicts
* fix comment typos
* move function to class function
* fix black
* fix black
* make style
This commit is contained in:
committed by
GitHub
parent
a75c64d80c
commit
858b7d5873
@@ -110,15 +110,6 @@ class TFModelTesterMixin:
|
||||
|
||||
def test_initialization(self):
|
||||
pass
|
||||
# config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
# configs_no_init = _config_zero_init(config)
|
||||
# for model_class in self.all_model_classes:
|
||||
# model = model_class(config=configs_no_init)
|
||||
# for name, param in model.named_parameters():
|
||||
# if param.requires_grad:
|
||||
# self.assertIn(param.data.mean().item(), [0.0, 1.0],
|
||||
# msg="Parameter {} of model {} seems not properly initialized".format(name, model_class))
|
||||
|
||||
def test_save_load(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
@@ -134,6 +125,19 @@ class TFModelTesterMixin:
|
||||
|
||||
self.assert_outputs_same(after_outputs, outputs)
|
||||
|
||||
def test_graph_mode(self):
    """Call every model class from inside a ``tf.function`` and check it returns something.

    For each class in ``self.all_model_classes`` the common inputs are
    adapted with ``self._prepare_for_class``, a model is built from the
    shared config, and the forward call is wrapped in ``tf.function`` so it
    is traced rather than executed eagerly.
    """
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

    for model_class in self.all_model_classes:
        class_inputs = self._prepare_for_class(inputs_dict, model_class)
        model = model_class(config)

        # The wrapper is invoked right away, so it always sees the model and
        # inputs of the current loop iteration.
        @tf.function
        def forward_in_graph():
            return model(class_inputs)

        self.assertIsNotNone(forward_in_graph())
|
||||
|
||||
@slow
|
||||
def test_saved_model_with_hidden_states_output(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
@@ -385,15 +385,16 @@ class TFLongformerModelIntegrationTest(unittest.TestCase):
|
||||
self.assertTrue(shape_list(hidden_states), [1, 8, 4])
|
||||
|
||||
# pad along seq length dim
|
||||
paddings = tf.constant([[0, 0], [0, 1], [0, 0]], dtype=tf.dtypes.int32)
|
||||
paddings = tf.constant([[0, 0], [0, 0], [0, 1], [0, 0]], dtype=tf.dtypes.int32)
|
||||
|
||||
hidden_states = TFLongformerSelfAttention._chunk(hidden_states, window_overlap=2)
|
||||
padded_hidden_states = TFLongformerSelfAttention._pad_and_transpose_last_two_dims(hidden_states, paddings)
|
||||
self.assertTrue(shape_list(padded_hidden_states) == [1, 8, 5])
|
||||
self.assertTrue(shape_list(padded_hidden_states) == [1, 1, 8, 5])
|
||||
|
||||
expected_added_dim = tf.zeros((5,), dtype=tf.dtypes.float32)
|
||||
tf.debugging.assert_near(expected_added_dim, padded_hidden_states[0, -1, :], rtol=1e-6)
|
||||
tf.debugging.assert_near(expected_added_dim, padded_hidden_states[0, 0, -1, :], rtol=1e-6)
|
||||
tf.debugging.assert_near(
|
||||
hidden_states[0, -1, :], tf.reshape(padded_hidden_states, (1, -1))[0, 24:32], rtol=1e-6
|
||||
hidden_states[0, 0, -1, :], tf.reshape(padded_hidden_states, (1, -1))[0, 24:32], rtol=1e-6
|
||||
)
|
||||
|
||||
def test_mask_invalid_locations(self):
|
||||
@@ -437,10 +438,16 @@ class TFLongformerModelIntegrationTest(unittest.TestCase):
|
||||
hidden_states = self._get_hidden_states()
|
||||
batch_size, seq_length, hidden_size = hidden_states.shape
|
||||
|
||||
attention_mask = tf.zeros((batch_size, 1, 1, seq_length), dtype=tf.dtypes.float32)
|
||||
attention_mask = tf.where(tf.range(4)[None, None, None, :] > 1, -10000.0, attention_mask)
|
||||
attention_mask = tf.zeros((batch_size, seq_length), dtype=tf.dtypes.float32)
|
||||
is_index_global_attn = tf.math.greater(attention_mask, 1)
|
||||
is_global_attn = tf.math.reduce_any(is_index_global_attn)
|
||||
|
||||
output_hidden_states = layer([hidden_states, attention_mask, None])[0]
|
||||
attention_mask = tf.where(tf.range(4)[None, :, None, None] > 1, -10000.0, attention_mask[:, :, None, None])
|
||||
is_index_masked = tf.math.less(attention_mask[:, :, 0, 0], 0)
|
||||
|
||||
output_hidden_states = layer(
|
||||
[hidden_states, attention_mask, is_index_masked, is_index_global_attn, is_global_attn, None]
|
||||
)[0]
|
||||
|
||||
expected_slice = tf.convert_to_tensor(
|
||||
[0.00188, 0.012196, -0.017051, -0.025571, -0.02996, 0.017297, -0.011521, 0.004848], dtype=tf.dtypes.float32
|
||||
@@ -461,12 +468,18 @@ class TFLongformerModelIntegrationTest(unittest.TestCase):
|
||||
attention_mask_1 = tf.zeros((1, 1, 1, seq_length), dtype=tf.dtypes.float32)
|
||||
attention_mask_2 = tf.zeros((1, 1, 1, seq_length), dtype=tf.dtypes.float32)
|
||||
|
||||
attention_mask_1 = tf.where(tf.range(4)[None, None, None, :] > 1, 10000.0, attention_mask_1)
|
||||
attention_mask_1 = tf.where(tf.range(4)[None, None, None, :] > 2, -10000.0, attention_mask_1)
|
||||
attention_mask_2 = tf.where(tf.range(4)[None, None, None, :] > 0, 10000.0, attention_mask_2)
|
||||
attention_mask_1 = tf.where(tf.range(4)[None, :, None, None] > 1, 10000.0, attention_mask_1)
|
||||
attention_mask_1 = tf.where(tf.range(4)[None, :, None, None] > 2, -10000.0, attention_mask_1)
|
||||
attention_mask_2 = tf.where(tf.range(4)[None, :, None, None] > 0, 10000.0, attention_mask_2)
|
||||
attention_mask = tf.concat([attention_mask_1, attention_mask_2], axis=0)
|
||||
|
||||
output_hidden_states = layer([hidden_states, attention_mask, None])[0]
|
||||
is_index_masked = tf.math.less(attention_mask[:, :, 0, 0], 0)
|
||||
is_index_global_attn = tf.math.greater(attention_mask[:, :, 0, 0], 0)
|
||||
is_global_attn = tf.math.reduce_any(is_index_global_attn)
|
||||
|
||||
output_hidden_states = layer(
|
||||
[hidden_states, -tf.math.abs(attention_mask), is_index_masked, is_index_global_attn, is_global_attn, None]
|
||||
)[0]
|
||||
|
||||
self.assertTrue(output_hidden_states.shape, (2, 4, 8))
|
||||
expected_slice_0 = tf.convert_to_tensor(
|
||||
|
||||
Reference in New Issue
Block a user