From a2c8e516c2fb55cf37844d434ae04b5abcfbc81b Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 9 Mar 2020 19:47:45 +0100 Subject: [PATCH] fix torch to tf translation --- src/transformers/modeling_tf_utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index 3feafcbfa3..dffed85cdb 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -641,7 +641,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): # create attention mask if necessary # TODO (PVP): this should later be handled by the forward fn() in each model in the future see PR 3140 - if (attention_mask is None) and (pad_token_id is not None) and (pad_token_id in input_ids): + import ipdb + + ipdb.set_trace() + if (attention_mask is None) and (pad_token_id is not None) and (pad_token_id in input_ids.numpy()): attention_mask = tf.cast(tf.math.not_equal(input_ids, pad_token_id), dtype=tf.int32) elif attention_mask is None: attention_mask = tf.ones_like(input_ids)