[EncoderDecoder] Make tests more aggressive (#9256)

* add tests

* make style and fix bart bug

* fix bart past key value edge case

* correct tf bart test

* fix gpt2 tf

* fix t5 test
This commit is contained in:
Patrick von Platen
2020-12-22 17:00:04 +01:00
committed by GitHub
parent ec07da65e2
commit e9d77ccd5a
9 changed files with 95 additions and 59 deletions

View File

@@ -95,9 +95,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
def _expand_mask(
mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None, past_key_values_length: int = 0
):
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
@@ -106,16 +104,6 @@ def _expand_mask(
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
if past_key_values_length > 0:
# concat fully attended attention_mask to the beginning if `past_key_values` are used
expanded_mask = torch.cat(
[
torch.ones(bsz, 1, tgt_len, past_key_values_length, device=expanded_mask.device, dtype=dtype),
expanded_mask,
],
dim=-1,
)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
@@ -941,11 +929,21 @@ class BartDecoder(BartPretrainedModel):
attention_mask = input_ids.ne(self.config.pad_token_id).to(torch.long)
# never mask leading token, even if it is pad
attention_mask[:, 0] = attention_mask[:, 1]
if past_key_values_length > 0:
attention_mask = torch.cat(
[
torch.ones(
(input_shape[0], past_key_values_length), dtype=torch.long, device=input_ids.device
),
attention_mask,
],
dim=-1,
)
if attention_mask is not None and combined_attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = combined_attention_mask + _expand_mask(
attention_mask, inputs_embeds.dtype, past_key_values_length=past_key_values_length
attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
)
# expand encoder attention mask

View File

@@ -94,7 +94,7 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i
return tf.broadcast_to(mask[None, None, :, :], (bsz, 1, tgt_len, tgt_len + past_key_values_length))
def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values_length: int = 0):
def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
@@ -103,16 +103,6 @@ def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values
expanded_mask = tf.cast(tf.broadcast_to(mask[:, None, None, :], (bsz, 1, tgt_len, src_len)), tf.float32)
if past_key_values_length > 0:
# concat fully attended attention_mask to the beginning if `past_key_values` are used
expanded_mask = tf.concat(
[
tf.ones((bsz, 1, tgt_len, past_key_values_length), dtype=tf.float32),
expanded_mask,
],
axis=-1,
)
return (1.0 - expanded_mask) * LARGE_NEGATIVE
@@ -902,14 +892,16 @@ class TFBartDecoder(tf.keras.layers.Layer):
attention_mask = tf.cast(
tf.math.not_equal(inputs["input_ids"], self.config.pad_token_id), inputs["input_ids"].dtype
)
attention_mask = tf.concat(
[tf.ones((input_shape[0], past_key_values_length), dtype=attention_mask.dtype), attention_mask],
axis=-1,
)
else:
attention_mask = tf.ones(input_shape, dtype=tf.int32)
attention_mask = tf.ones((input_shape[0], input_shape[1] + past_key_values_length), dtype=tf.int32)
if attention_mask is not None and combined_attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = combined_attention_mask + _expand_mask(
attention_mask, past_key_values_length=past_key_values_length
)
combined_attention_mask = combined_attention_mask + _expand_mask(attention_mask, tgt_len=input_shape[-1])
encoder_hidden_states = inputs["encoder_hidden_states"]
if encoder_hidden_states is not None and inputs["encoder_attention_mask"] is not None: