Fix some TF slow tests (#9728)
* Fix saved model tests + fix a graph issue in longformer * Apply style
This commit is contained in:
@@ -2438,10 +2438,16 @@ class TFLongformerForSequenceClassification(TFLongformerPreTrainedModel, TFSeque
|
||||
logger.info("Initializing global attention on CLS token...")
|
||||
# global attention on cls token
|
||||
inputs["global_attention_mask"] = tf.zeros_like(inputs["input_ids"])
|
||||
updates = tf.ones(shape_list(inputs["input_ids"])[0], dtype=tf.int32)
|
||||
indices = tf.pad(
|
||||
tensor=tf.expand_dims(tf.range(shape_list(inputs["input_ids"])[0]), axis=1),
|
||||
paddings=[[0, 0], [0, 1]],
|
||||
constant_values=0,
|
||||
)
|
||||
inputs["global_attention_mask"] = tf.tensor_scatter_nd_update(
|
||||
inputs["global_attention_mask"],
|
||||
[[i, 0] for i in range(shape_list(inputs["input_ids"])[0])],
|
||||
[1 for _ in range(shape_list(inputs["input_ids"])[0])],
|
||||
indices,
|
||||
updates,
|
||||
)
|
||||
|
||||
outputs = self.longformer(
|
||||
|
||||
Reference in New Issue
Block a user