Fix Longformer and LED (#9942)

* Fix Longformer and LED

* Add a test for graph execution with inputs_embeds

* Apply style
This commit is contained in:
Julien Plu
2021-02-03 12:26:32 +01:00
committed by GitHub
parent d55e10beab
commit 3f77c26d74
3 changed files with 42 additions and 12 deletions

View File

@@ -1665,7 +1665,6 @@ class TFLEDEncoder(tf.keras.layers.Layer):
def compute_hidden_states(self, hidden_states, padding_len):
return hidden_states[:, :-padding_len] if padding_len > 0 else hidden_states
@tf.function
def _pad_to_window_size(
self,
input_ids,
@@ -1685,26 +1684,28 @@ class TFLEDEncoder(tf.keras.layers.Layer):
batch_size, seq_len = input_shape[:2]
padding_len = (attention_window - seq_len % attention_window) % attention_window
if padding_len > 0:
if tf.math.greater(padding_len, 0):
logger.info(
"Input ids are automatically padded from {} to {} to be a multiple of `config.attention_window`: {}".format(
seq_len, seq_len + padding_len, attention_window
)
)
paddings = tf.convert_to_tensor([[0, 0], [0, padding_len]])
paddings = tf.convert_to_tensor([[0, 0], [0, padding_len]])
if input_ids is not None:
input_ids = tf.pad(input_ids, paddings, constant_values=pad_token_id)
if input_ids is not None:
input_ids = tf.pad(input_ids, paddings, constant_values=pad_token_id)
if inputs_embeds is not None:
if inputs_embeds is not None:
def pad_embeddings():
input_ids_padding = tf.fill((batch_size, padding_len), pad_token_id)
inputs_embeds_padding = self.embed_tokens(input_ids_padding)
inputs_embeds = tf.concat([inputs_embeds, inputs_embeds_padding], axis=-2)
return tf.concat([inputs_embeds, inputs_embeds_padding], axis=-2)
attention_mask = tf.pad(
attention_mask, paddings, constant_values=False
) # no attention on the padding tokens
inputs_embeds = tf.cond(tf.math.greater(padding_len, 0), pad_embeddings, lambda: inputs_embeds)
attention_mask = tf.pad(attention_mask, paddings, constant_values=False) # no attention on the padding tokens
return (
padding_len,

View File

@@ -1836,7 +1836,7 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
batch_size, seq_len = input_shape[:2]
padding_len = (attention_window - seq_len % attention_window) % attention_window
if padding_len > 0:
if tf.math.greater(padding_len, 0):
logger.info(
"Input ids are automatically padded from {} to {} to be a multiple of `config.attention_window`: {}".format(
seq_len, seq_len + padding_len, attention_window
@@ -1859,7 +1859,7 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
inputs_embeds_padding = self.embeddings(input_ids_padding)
return tf.concat([inputs_embeds, inputs_embeds_padding], axis=-2)
inputs_embeds = tf.cond(padding_len > 0, pad_embeddings, lambda: inputs_embeds)
inputs_embeds = tf.cond(tf.math.greater(padding_len, 0), pad_embeddings, lambda: inputs_embeds)
attention_mask = tf.pad(attention_mask, paddings, constant_values=False) # no attention on the padding tokens
token_type_ids = tf.pad(token_type_ids, paddings, constant_values=0) # pad with token_type_id = 0