[TF Longformer] Improve Speed for TF Longformer (#6447)

* add tf graph compile tests

* fix conflict

* remove more tf transpose statements

* fix conflicts

* fix comment typos

* move function to class function

* fix black

* fix black

* make style
This commit is contained in:
Patrick von Platen
2020-08-26 20:55:41 +02:00
committed by GitHub
parent a75c64d80c
commit 858b7d5873
5 changed files with 258 additions and 135 deletions

View File

@@ -385,15 +385,16 @@ class TFLongformerModelIntegrationTest(unittest.TestCase):
self.assertTrue(shape_list(hidden_states), [1, 8, 4])
# pad along seq length dim
paddings = tf.constant([[0, 0], [0, 1], [0, 0]], dtype=tf.dtypes.int32)
paddings = tf.constant([[0, 0], [0, 0], [0, 1], [0, 0]], dtype=tf.dtypes.int32)
hidden_states = TFLongformerSelfAttention._chunk(hidden_states, window_overlap=2)
padded_hidden_states = TFLongformerSelfAttention._pad_and_transpose_last_two_dims(hidden_states, paddings)
self.assertTrue(shape_list(padded_hidden_states) == [1, 8, 5])
self.assertTrue(shape_list(padded_hidden_states) == [1, 1, 8, 5])
expected_added_dim = tf.zeros((5,), dtype=tf.dtypes.float32)
tf.debugging.assert_near(expected_added_dim, padded_hidden_states[0, -1, :], rtol=1e-6)
tf.debugging.assert_near(expected_added_dim, padded_hidden_states[0, 0, -1, :], rtol=1e-6)
tf.debugging.assert_near(
hidden_states[0, -1, :], tf.reshape(padded_hidden_states, (1, -1))[0, 24:32], rtol=1e-6
hidden_states[0, 0, -1, :], tf.reshape(padded_hidden_states, (1, -1))[0, 24:32], rtol=1e-6
)
def test_mask_invalid_locations(self):
@@ -437,10 +438,16 @@ class TFLongformerModelIntegrationTest(unittest.TestCase):
hidden_states = self._get_hidden_states()
batch_size, seq_length, hidden_size = hidden_states.shape
attention_mask = tf.zeros((batch_size, 1, 1, seq_length), dtype=tf.dtypes.float32)
attention_mask = tf.where(tf.range(4)[None, None, None, :] > 1, -10000.0, attention_mask)
attention_mask = tf.zeros((batch_size, seq_length), dtype=tf.dtypes.float32)
is_index_global_attn = tf.math.greater(attention_mask, 1)
is_global_attn = tf.math.reduce_any(is_index_global_attn)
output_hidden_states = layer([hidden_states, attention_mask, None])[0]
attention_mask = tf.where(tf.range(4)[None, :, None, None] > 1, -10000.0, attention_mask[:, :, None, None])
is_index_masked = tf.math.less(attention_mask[:, :, 0, 0], 0)
output_hidden_states = layer(
[hidden_states, attention_mask, is_index_masked, is_index_global_attn, is_global_attn, None]
)[0]
expected_slice = tf.convert_to_tensor(
[0.00188, 0.012196, -0.017051, -0.025571, -0.02996, 0.017297, -0.011521, 0.004848], dtype=tf.dtypes.float32
@@ -461,12 +468,18 @@ class TFLongformerModelIntegrationTest(unittest.TestCase):
attention_mask_1 = tf.zeros((1, 1, 1, seq_length), dtype=tf.dtypes.float32)
attention_mask_2 = tf.zeros((1, 1, 1, seq_length), dtype=tf.dtypes.float32)
attention_mask_1 = tf.where(tf.range(4)[None, None, None, :] > 1, 10000.0, attention_mask_1)
attention_mask_1 = tf.where(tf.range(4)[None, None, None, :] > 2, -10000.0, attention_mask_1)
attention_mask_2 = tf.where(tf.range(4)[None, None, None, :] > 0, 10000.0, attention_mask_2)
attention_mask_1 = tf.where(tf.range(4)[None, :, None, None] > 1, 10000.0, attention_mask_1)
attention_mask_1 = tf.where(tf.range(4)[None, :, None, None] > 2, -10000.0, attention_mask_1)
attention_mask_2 = tf.where(tf.range(4)[None, :, None, None] > 0, 10000.0, attention_mask_2)
attention_mask = tf.concat([attention_mask_1, attention_mask_2], axis=0)
output_hidden_states = layer([hidden_states, attention_mask, None])[0]
is_index_masked = tf.math.less(attention_mask[:, :, 0, 0], 0)
is_index_global_attn = tf.math.greater(attention_mask[:, :, 0, 0], 0)
is_global_attn = tf.math.reduce_any(is_index_global_attn)
output_hidden_states = layer(
[hidden_states, -tf.math.abs(attention_mask), is_index_masked, is_index_global_attn, is_global_attn, None]
)[0]
self.assertTrue(output_hidden_states.shape, (2, 4, 8))
expected_slice_0 = tf.convert_to_tensor(