Restore TF embeddings and attention layers to their previous version (#9890)
* Refactor BERT
* Restore all the concerned models
* Remove print
* Update template
* Apply Sylvain's and Morgan's comments
* Fix cast
* Put the cast inside call
* Remove cond in embeddings
* Fix funnel
* Restore previous dot product (attention_scores) computation
* Add ConvBERT and BART
* Make all the S2S models ONNX compliant
* Fix test
* Fix check copies
This commit is contained in:
@@ -94,7 +94,8 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i
|
||||
|
||||
if past_key_values_length > 0:
|
||||
mask = tf.concat([tf.zeros((tgt_len, past_key_values_length), dtype=tf.float32), mask], axis=-1)
|
||||
return tf.broadcast_to(mask[None, None, :, :], (bsz, 1, tgt_len, tgt_len + past_key_values_length))
|
||||
|
||||
return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1))
|
||||
|
||||
|
||||
# Copied from transformers.models.bart.modeling_tf_bart._expand_mask
|
||||
@@ -102,10 +103,9 @@ def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values
|
||||
"""
|
||||
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
||||
"""
|
||||
bsz, src_len = shape_list(mask)
|
||||
src_len = shape_list(mask)[1]
|
||||
tgt_len = tgt_len if tgt_len is not None else src_len
|
||||
|
||||
expanded_mask = tf.cast(tf.broadcast_to(mask[:, None, None, :], (bsz, 1, tgt_len, src_len)), tf.float32)
|
||||
expanded_mask = tf.cast(tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1)), tf.float32)
|
||||
|
||||
return (1.0 - expanded_mask) * LARGE_NEGATIVE
|
||||
|
||||
|
||||
Reference in New Issue
Block a user