Restore TF embeddings and attention layers to their previous version (#9890)

* Refacto BERT

* Restore all the concerned models

* Remove print

* Update template

* Apply Sylvain's and Morgan's comments

* Fix cast

* Put the cast inside call

* Remove cond in ebds

* Fix funnel

* Restore previous dot product (attention_scores) computation

* Add ConvBERT and BART

* Make all the S2S models ONNX compliant

* Fix test

* Fix check copies
This commit is contained in:
Julien Plu
2021-02-08 12:36:30 +01:00
committed by GitHub
parent 8bb52bd240
commit 31563e056d
20 changed files with 754 additions and 1966 deletions

View File

@@ -94,7 +94,8 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i
if past_key_values_length > 0:
mask = tf.concat([tf.zeros((tgt_len, past_key_values_length), dtype=tf.float32), mask], axis=-1)
return tf.broadcast_to(mask[None, None, :, :], (bsz, 1, tgt_len, tgt_len + past_key_values_length))
return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1))
# Copied from transformers.models.bart.modeling_tf_bart._expand_mask
@@ -102,10 +103,9 @@ def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = shape_list(mask)
src_len = shape_list(mask)[1]
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = tf.cast(tf.broadcast_to(mask[:, None, None, :], (bsz, 1, tgt_len, src_len)), tf.float32)
expanded_mask = tf.cast(tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1)), tf.float32)
return (1.0 - expanded_mask) * LARGE_NEGATIVE