Black 20 release
This commit is contained in:
@@ -36,7 +36,7 @@ if is_tf_available():
|
||||
|
||||
def shape_list(x):
|
||||
"""
|
||||
copied from transformers.modeling_tf_utils
|
||||
copied from transformers.modeling_tf_utils
|
||||
"""
|
||||
static = x.shape.as_list()
|
||||
dynamic = tf.shape(x)
|
||||
@@ -45,7 +45,8 @@ if is_tf_available():
|
||||
|
||||
class TFLongformerModelTester:
|
||||
def __init__(
|
||||
self, parent,
|
||||
self,
|
||||
parent,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = 13
|
||||
@@ -228,7 +229,8 @@ class TFLongformerModelTester:
|
||||
# global attention mask has to be partly defined
|
||||
# to trace all weights
|
||||
global_attention_mask = tf.concat(
|
||||
[tf.zeros_like(input_ids)[:, :-1], tf.ones_like(input_ids)[:, -1:]], axis=-1,
|
||||
[tf.zeros_like(input_ids)[:, :-1], tf.ones_like(input_ids)[:, -1:]],
|
||||
axis=-1,
|
||||
)
|
||||
|
||||
inputs_dict = {
|
||||
@@ -267,7 +269,13 @@ class TFLongformerModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
test_torchscript = False
|
||||
|
||||
all_model_classes = (
|
||||
(TFLongformerModel, TFLongformerForMaskedLM, TFLongformerForQuestionAnswering,) if is_tf_available() else ()
|
||||
(
|
||||
TFLongformerModel,
|
||||
TFLongformerForMaskedLM,
|
||||
TFLongformerForQuestionAnswering,
|
||||
)
|
||||
if is_tf_available()
|
||||
else ()
|
||||
)
|
||||
|
||||
def setUp(self):
|
||||
|
||||
Reference in New Issue
Block a user