Fix TF Longformer (#9348)
* Fix longformer * Apply style * Remove serving content * Forgot a condition * Apply style * Address Patrick's comments * Fix dtype
This commit is contained in:
@@ -390,7 +390,7 @@ def _compute_global_attention_mask(input_ids_shape, sep_token_indices, before_se
|
||||
True` else after `sep_token_id`.
|
||||
"""
|
||||
|
||||
assert sep_token_indices.shape[1] == 2, "`input_ids` should have two dimensions"
|
||||
assert shape_list(sep_token_indices)[1] == 2, "`input_ids` should have two dimensions"
|
||||
question_end_index = tf.reshape(sep_token_indices, (input_ids_shape[0], 3, 2))[:, 0, 1]
|
||||
question_end_index = tf.cast(question_end_index[:, None], tf.dtypes.int32) # size: batch_size x 1
|
||||
# bool attention mask with True in locations of global attention
|
||||
@@ -1028,7 +1028,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
)
|
||||
|
||||
# pad to full matrix
|
||||
padding = tf.constant(
|
||||
padding = tf.convert_to_tensor(
|
||||
[[0, shape_list(input_tensor)[1] - window_overlap], [0, shape_list(input_tensor)[3] - window_overlap - 1]]
|
||||
)
|
||||
|
||||
@@ -1523,8 +1523,7 @@ class TFLongformerEncoder(tf.keras.layers.Layer):
|
||||
training=False,
|
||||
):
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_attentions = () if output_attentions else None
|
||||
all_global_attentions = () if (output_attentions and is_global_attn) else None
|
||||
all_attentions = all_global_attentions = () if output_attentions else None
|
||||
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
if output_hidden_states:
|
||||
@@ -1547,7 +1546,6 @@ class TFLongformerEncoder(tf.keras.layers.Layer):
|
||||
# bzs x seq_len x num_attn_heads x (num_global_attn + attention_window_len + 1) => bzs x num_attn_heads x seq_len x (num_global_attn + attention_window_len + 1)
|
||||
all_attentions = all_attentions + (tf.transpose(layer_outputs[1], (0, 2, 1, 3)),)
|
||||
|
||||
if is_global_attn:
|
||||
# bzs x num_attn_heads x num_global_attn x seq_len => bzs x num_attn_heads x seq_len x num_global_attn
|
||||
all_global_attentions = all_global_attentions + (tf.transpose(layer_outputs[2], (0, 1, 3, 2)))
|
||||
|
||||
@@ -1766,7 +1764,7 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
|
||||
)
|
||||
)
|
||||
|
||||
paddings = tf.constant([[0, 0], [0, padding_len]])
|
||||
paddings = tf.convert_to_tensor([[0, 0], [0, padding_len]])
|
||||
|
||||
if input_ids is not None:
|
||||
input_ids = tf.pad(input_ids, paddings, constant_values=pad_token_id)
|
||||
@@ -1776,13 +1774,15 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
|
||||
position_ids = tf.pad(position_ids, paddings, constant_values=pad_token_id)
|
||||
|
||||
if inputs_embeds is not None:
|
||||
|
||||
def pad_embeddings():
|
||||
input_ids_padding = tf.fill((batch_size, padding_len), self.pad_token_id)
|
||||
inputs_embeds_padding = self.embeddings(input_ids_padding)
|
||||
inputs_embeds = tf.concat([inputs_embeds, inputs_embeds_padding], axis=-2)
|
||||
return tf.concat([inputs_embeds, inputs_embeds_padding], axis=-2)
|
||||
|
||||
attention_mask = tf.pad(
|
||||
attention_mask, paddings, constant_values=False
|
||||
) # no attention on the padding tokens
|
||||
inputs_embeds = tf.cond(padding_len > 0, pad_embeddings, lambda: inputs_embeds)
|
||||
|
||||
attention_mask = tf.pad(attention_mask, paddings, constant_values=False) # no attention on the padding tokens
|
||||
token_type_ids = tf.pad(token_type_ids, paddings, constant_values=0) # pad with token_type_id = 0
|
||||
|
||||
return (
|
||||
@@ -2171,16 +2171,14 @@ class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAn
|
||||
|
||||
# set global attention on question tokens
|
||||
if inputs["global_attention_mask"] is None and inputs["input_ids"] is not None:
|
||||
if inputs["input_ids"] is None:
|
||||
logger.warning(
|
||||
"It is not possible to automatically generate the `global_attention_mask`. Please make sure that it is correctly set."
|
||||
)
|
||||
elif (
|
||||
tf.where(inputs["input_ids"] == self.config.sep_token_id).shape[0] != 3 * inputs["input_ids"].shape[0]
|
||||
if (
|
||||
shape_list(tf.where(inputs["input_ids"] == self.config.sep_token_id))[0]
|
||||
!= 3 * shape_list(inputs["input_ids"])[0]
|
||||
):
|
||||
logger.warning(
|
||||
f"There should be exactly three separator tokens: {self.config.sep_token_id} in every sample for questions answering. You might also consider to set `global_attention_mask` manually in the forward function to avoid this. This is most likely an error."
|
||||
f"There should be exactly three separator tokens: {self.config.sep_token_id} in every sample for questions answering. You might also consider to set `global_attention_mask` manually in the forward function to avoid this. This is most likely an error. The global attention is disabled for this forward pass."
|
||||
)
|
||||
inputs["global_attention_mask"] = tf.fill(shape_list(inputs["input_ids"]), value=0)
|
||||
else:
|
||||
logger.info("Initializing global attention on question tokens...")
|
||||
# put global attention on all tokens until `config.sep_token_id` is reached
|
||||
@@ -2317,8 +2315,8 @@ class TFLongformerForSequenceClassification(TFLongformerPreTrainedModel, TFSeque
|
||||
inputs["global_attention_mask"] = tf.zeros_like(inputs["input_ids"])
|
||||
inputs["global_attention_mask"] = tf.tensor_scatter_nd_update(
|
||||
inputs["global_attention_mask"],
|
||||
[[i, 0] for i in range(inputs["input_ids"].shape[0])],
|
||||
[1 for _ in range(inputs["input_ids"].shape[0])],
|
||||
[[i, 0] for i in range(shape_list(inputs["input_ids"])[0])],
|
||||
[1 for _ in range(shape_list(inputs["input_ids"])[0])],
|
||||
)
|
||||
|
||||
outputs = self.longformer(
|
||||
@@ -2443,7 +2441,7 @@ class TFLongformerForMultipleChoice(TFLongformerPreTrainedModel, TFMultipleChoic
|
||||
tf.reshape(inputs["position_ids"], (-1, seq_length)) if inputs["position_ids"] is not None else None
|
||||
)
|
||||
flat_global_attention_mask = (
|
||||
tf.reshape(inputs["global_attention_mask"], (-1, inputs["global_attention_mask"].shape[-1]))
|
||||
tf.reshape(inputs["global_attention_mask"], (-1, shape_list(inputs["global_attention_mask"])[-1]))
|
||||
if inputs["global_attention_mask"] is not None
|
||||
else None
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user