Black preview (#17217)
* Black preview * Fixup too! * Fix check copies * Use the same version as the CI * Bump black
This commit is contained in:
@@ -505,7 +505,8 @@ class TFModelTesterMixin:
|
||||
self.assertLessEqual(max_diff, tol, f"{name}: Difference between torch and tf is {max_diff} (>= {tol}).")
|
||||
else:
|
||||
raise ValueError(
|
||||
f"`tf_outputs` should be an instance of `tf.Tensor`, a `tuple`, or an instance of `tf.Tensor`. Got {type(tf_outputs)} instead."
|
||||
"`tf_outputs` should be an instance of `tf.Tensor`, a `tuple`, or an instance of `tf.Tensor`. Got"
|
||||
f" {type(tf_outputs)} instead."
|
||||
)
|
||||
|
||||
def prepare_pt_inputs_from_tf_inputs(self, tf_inputs_dict):
|
||||
@@ -956,7 +957,10 @@ class TFModelTesterMixin:
|
||||
else:
|
||||
self.assertTrue(
|
||||
all(tf.equal(tuple_object, dict_object)),
|
||||
msg=f"Tuple and dict output are not equal. Difference: {tf.math.reduce_max(tf.abs(tuple_object - dict_object))}",
|
||||
msg=(
|
||||
"Tuple and dict output are not equal. Difference:"
|
||||
f" {tf.math.reduce_max(tf.abs(tuple_object - dict_object))}"
|
||||
),
|
||||
)
|
||||
|
||||
recursive_check(tuple_output, dict_output)
|
||||
|
||||
Reference in New Issue
Block a user