Fix LongformerModel hidden states (#15537)
* add undo padding * fix * fix tuple issue * make style and quality * move unpad logic to LongformerEncoder + unpad attentions + update tests * move unpad logic to TFLongformerEncoder Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -1246,6 +1246,7 @@ class LongformerEncoder(nn.Module):
|
|||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
|
padding_len=0,
|
||||||
output_attentions=False,
|
output_attentions=False,
|
||||||
output_hidden_states=False,
|
output_hidden_states=False,
|
||||||
return_dict=True,
|
return_dict=True,
|
||||||
@@ -1308,6 +1309,16 @@ class LongformerEncoder(nn.Module):
|
|||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
|
# undo padding
|
||||||
|
if padding_len > 0:
|
||||||
|
# unpad `hidden_states` because the calling function is expecting a length == input_ids.size(1)
|
||||||
|
hidden_states = hidden_states[:, :-padding_len]
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states = tuple([state[:, :-padding_len] for state in all_hidden_states])
|
||||||
|
|
||||||
|
if output_attentions:
|
||||||
|
all_attentions = tuple([state[:, :, :-padding_len, :] for state in all_attentions])
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
return tuple(
|
return tuple(
|
||||||
v for v in [hidden_states, all_hidden_states, all_attentions, all_global_attentions] if v is not None
|
v for v in [hidden_states, all_hidden_states, all_attentions, all_global_attentions] if v is not None
|
||||||
@@ -1697,6 +1708,7 @@ class LongformerModel(LongformerPreTrainedModel):
|
|||||||
embedding_output,
|
embedding_output,
|
||||||
attention_mask=extended_attention_mask,
|
attention_mask=extended_attention_mask,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
|
padding_len=padding_len,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
@@ -1704,11 +1716,6 @@ class LongformerModel(LongformerPreTrainedModel):
|
|||||||
sequence_output = encoder_outputs[0]
|
sequence_output = encoder_outputs[0]
|
||||||
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
||||||
|
|
||||||
# undo padding
|
|
||||||
if padding_len > 0:
|
|
||||||
# unpad `sequence_output` because the calling function is expecting a length == input_ids.size(1)
|
|
||||||
sequence_output = sequence_output[:, :-padding_len]
|
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
||||||
|
|
||||||
|
|||||||
@@ -1587,13 +1587,23 @@ class TFLongformerEncoder(tf.keras.layers.Layer):
|
|||||||
all_attentions = all_attentions + (tf.transpose(layer_outputs[1], (0, 2, 1, 3)),)
|
all_attentions = all_attentions + (tf.transpose(layer_outputs[1], (0, 2, 1, 3)),)
|
||||||
|
|
||||||
# bzs x num_attn_heads x num_global_attn x seq_len => bzs x num_attn_heads x seq_len x num_global_attn
|
# bzs x num_attn_heads x num_global_attn x seq_len => bzs x num_attn_heads x seq_len x num_global_attn
|
||||||
all_global_attentions = all_global_attentions + (tf.transpose(layer_outputs[2], (0, 1, 3, 2)))
|
all_global_attentions = all_global_attentions + (tf.transpose(layer_outputs[2], (0, 1, 3, 2)),)
|
||||||
|
|
||||||
# Add last layer
|
# Add last layer
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
hidden_states_to_add = hidden_states[:, :-padding_len] if padding_len > 0 else hidden_states
|
hidden_states_to_add = hidden_states[:, :-padding_len] if padding_len > 0 else hidden_states
|
||||||
all_hidden_states = all_hidden_states + (hidden_states_to_add,)
|
all_hidden_states = all_hidden_states + (hidden_states_to_add,)
|
||||||
|
|
||||||
|
# undo padding
|
||||||
|
# unpad `hidden_states` because the calling function is expecting a length == input_ids.size(1)
|
||||||
|
hidden_states = hidden_states[:, :-padding_len] if padding_len > 0 else hidden_states
|
||||||
|
if output_attentions:
|
||||||
|
all_attentions = (
|
||||||
|
tuple([state[:, :, :-padding_len, :] for state in all_attentions])
|
||||||
|
if padding_len > 0
|
||||||
|
else all_attentions
|
||||||
|
)
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
return tuple(
|
return tuple(
|
||||||
v for v in [hidden_states, all_hidden_states, all_attentions, all_global_attentions] if v is not None
|
v for v in [hidden_states, all_hidden_states, all_attentions, all_global_attentions] if v is not None
|
||||||
@@ -1763,11 +1773,6 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
|
|||||||
sequence_output = encoder_outputs[0]
|
sequence_output = encoder_outputs[0]
|
||||||
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
||||||
|
|
||||||
# undo padding
|
|
||||||
if padding_len > 0:
|
|
||||||
# unpad `sequence_output` because the calling function is expecting a length == input_ids.size(1)
|
|
||||||
sequence_output = sequence_output[:, :-padding_len]
|
|
||||||
|
|
||||||
if not inputs["return_dict"]:
|
if not inputs["return_dict"]:
|
||||||
return (
|
return (
|
||||||
sequence_output,
|
sequence_output,
|
||||||
|
|||||||
@@ -74,12 +74,6 @@ class LongformerModelTester:
|
|||||||
# is x + self.attention_window + 1, where x is the number of tokens with global attention)
|
# is x + self.attention_window + 1, where x is the number of tokens with global attention)
|
||||||
self.key_length = self.attention_window + 2
|
self.key_length = self.attention_window + 2
|
||||||
|
|
||||||
# because of padding `encoder_seq_length`, is different from `seq_length`. Relevant for
|
|
||||||
# the `test_attention_outputs` and `test_hidden_states_output` tests
|
|
||||||
self.encoder_seq_length = (
|
|
||||||
self.seq_length + (self.attention_window - self.seq_length % self.attention_window) % self.attention_window
|
|
||||||
)
|
|
||||||
|
|
||||||
def prepare_config_and_inputs(self):
|
def prepare_config_and_inputs(self):
|
||||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||||
|
|
||||||
|
|||||||
@@ -74,12 +74,6 @@ class TFLongformerModelTester:
|
|||||||
# because its local attention only attends to `self.attention_window` and one before and one after
|
# because its local attention only attends to `self.attention_window` and one before and one after
|
||||||
self.key_length = self.attention_window + 2
|
self.key_length = self.attention_window + 2
|
||||||
|
|
||||||
# because of padding `encoder_seq_length`, is different from `seq_length`. Relevant for
|
|
||||||
# the `test_attention_outputs` and `test_hidden_states_output` tests
|
|
||||||
self.encoder_seq_length = (
|
|
||||||
self.seq_length + (self.attention_window - self.seq_length % self.attention_window) % self.attention_window
|
|
||||||
)
|
|
||||||
|
|
||||||
def prepare_config_and_inputs(self):
|
def prepare_config_and_inputs(self):
|
||||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user