[EncoderDecoder] Make tests more aggressive (#9256)
* add tests
* make style and fix bart bug
* fix bart past key value edge case
* correct tf bart test
* fix gpt2 tf
* fix t5 test
This commit is contained in:
committed by
GitHub
parent
ec07da65e2
commit
e9d77ccd5a
@@ -95,9 +95,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
|
||||
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
|
||||
|
||||
|
||||
def _expand_mask(
|
||||
mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None, past_key_values_length: int = 0
|
||||
):
|
||||
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
|
||||
"""
|
||||
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
||||
"""
|
||||
@@ -106,16 +104,6 @@ def _expand_mask(
|
||||
|
||||
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
|
||||
|
||||
if past_key_values_length > 0:
|
||||
# concat fully attended attention_mask to the beginning if `past_key_values` are used
|
||||
expanded_mask = torch.cat(
|
||||
[
|
||||
torch.ones(bsz, 1, tgt_len, past_key_values_length, device=expanded_mask.device, dtype=dtype),
|
||||
expanded_mask,
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
inverted_mask = 1.0 - expanded_mask
|
||||
|
||||
return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
|
||||
@@ -941,11 +929,21 @@ class BartDecoder(BartPretrainedModel):
|
||||
attention_mask = input_ids.ne(self.config.pad_token_id).to(torch.long)
|
||||
# never mask leading token, even if it is pad
|
||||
attention_mask[:, 0] = attention_mask[:, 1]
|
||||
if past_key_values_length > 0:
|
||||
attention_mask = torch.cat(
|
||||
[
|
||||
torch.ones(
|
||||
(input_shape[0], past_key_values_length), dtype=torch.long, device=input_ids.device
|
||||
),
|
||||
attention_mask,
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
if attention_mask is not None and combined_attention_mask is not None:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
combined_attention_mask = combined_attention_mask + _expand_mask(
|
||||
attention_mask, inputs_embeds.dtype, past_key_values_length=past_key_values_length
|
||||
attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
|
||||
)
|
||||
|
||||
# expand encoder attention mask
|
||||
|
||||
@@ -94,7 +94,7 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i
|
||||
return tf.broadcast_to(mask[None, None, :, :], (bsz, 1, tgt_len, tgt_len + past_key_values_length))
|
||||
|
||||
|
||||
def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values_length: int = 0):
|
||||
def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None):
|
||||
"""
|
||||
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
||||
"""
|
||||
@@ -103,16 +103,6 @@ def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values
|
||||
|
||||
expanded_mask = tf.cast(tf.broadcast_to(mask[:, None, None, :], (bsz, 1, tgt_len, src_len)), tf.float32)
|
||||
|
||||
if past_key_values_length > 0:
|
||||
# concat fully attended attention_mask to the beginning if `past_key_values` are used
|
||||
expanded_mask = tf.concat(
|
||||
[
|
||||
tf.ones((bsz, 1, tgt_len, past_key_values_length), dtype=tf.float32),
|
||||
expanded_mask,
|
||||
],
|
||||
axis=-1,
|
||||
)
|
||||
|
||||
return (1.0 - expanded_mask) * LARGE_NEGATIVE
|
||||
|
||||
|
||||
@@ -902,14 +892,16 @@ class TFBartDecoder(tf.keras.layers.Layer):
|
||||
attention_mask = tf.cast(
|
||||
tf.math.not_equal(inputs["input_ids"], self.config.pad_token_id), inputs["input_ids"].dtype
|
||||
)
|
||||
attention_mask = tf.concat(
|
||||
[tf.ones((input_shape[0], past_key_values_length), dtype=attention_mask.dtype), attention_mask],
|
||||
axis=-1,
|
||||
)
|
||||
else:
|
||||
attention_mask = tf.ones(input_shape, dtype=tf.int32)
|
||||
attention_mask = tf.ones((input_shape[0], input_shape[1] + past_key_values_length), dtype=tf.int32)
|
||||
|
||||
if attention_mask is not None and combined_attention_mask is not None:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
combined_attention_mask = combined_attention_mask + _expand_mask(
|
||||
attention_mask, past_key_values_length=past_key_values_length
|
||||
)
|
||||
combined_attention_mask = combined_attention_mask + _expand_mask(attention_mask, tgt_len=input_shape[-1])
|
||||
|
||||
encoder_hidden_states = inputs["encoder_hidden_states"]
|
||||
if encoder_hidden_states is not None and inputs["encoder_attention_mask"] is not None:
|
||||
|
||||
Reference in New Issue
Block a user