Apply ruff flake8-comprehensions (#21694)
This commit is contained in:
@@ -398,7 +398,7 @@ class TFModelTesterMixin:
|
||||
def test_keras_save_load(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
tf_main_layer_classes = set(
|
||||
tf_main_layer_classes = {
|
||||
module_member
|
||||
for model_class in self.all_model_classes
|
||||
for module in (import_module(model_class.__module__),)
|
||||
@@ -410,7 +410,7 @@ class TFModelTesterMixin:
|
||||
if isinstance(module_member, type)
|
||||
and tf.keras.layers.Layer in module_member.__bases__
|
||||
and getattr(module_member, "_keras_serializable", False)
|
||||
)
|
||||
}
|
||||
for main_layer_class in tf_main_layer_classes:
|
||||
# T5MainLayer needs an embed_tokens parameter when called without the inputs_embeds parameter
|
||||
if "T5" in main_layer_class.__name__:
|
||||
@@ -498,8 +498,8 @@ class TFModelTesterMixin:
|
||||
def _postprocessing_to_ignore_test_cases(self, tf_outputs, pt_outputs, model_class):
|
||||
"""For temporarily ignoring some failed test cases (issues to be fixed)"""
|
||||
|
||||
tf_keys = set([k for k, v in tf_outputs.items() if v is not None])
|
||||
pt_keys = set([k for k, v in pt_outputs.items() if v is not None])
|
||||
tf_keys = {k for k, v in tf_outputs.items() if v is not None}
|
||||
pt_keys = {k for k, v in pt_outputs.items() if v is not None}
|
||||
|
||||
key_differences = tf_keys.symmetric_difference(pt_keys)
|
||||
|
||||
@@ -1455,7 +1455,7 @@ class TFModelTesterMixin:
|
||||
continue
|
||||
# The number of elements in the loss should be the same as the number of elements in the label
|
||||
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
|
||||
added_label_names = sorted(list(prepared_for_class.keys() - inputs_dict.keys()), reverse=True)
|
||||
added_label_names = sorted(prepared_for_class.keys() - inputs_dict.keys(), reverse=True)
|
||||
if not added_label_names:
|
||||
continue # This test is only for models with easily-separable labels
|
||||
added_label = prepared_for_class[added_label_names[0]]
|
||||
@@ -1713,7 +1713,7 @@ class TFModelTesterMixin:
|
||||
}
|
||||
|
||||
signature = inspect.signature(model.call)
|
||||
if set(head_masking.keys()) < set([*signature.parameters.keys()]):
|
||||
if set(head_masking.keys()) < {*signature.parameters.keys()}:
|
||||
continue
|
||||
|
||||
for attn_name, (name, mask) in zip(attention_names, head_masking.items()):
|
||||
@@ -2274,7 +2274,7 @@ class UtilsFunctionsTest(unittest.TestCase):
|
||||
index = json.loads(f.read())
|
||||
|
||||
all_shards = set(index["weight_map"].values())
|
||||
shards_found = set(f for f in os.listdir(tmp_dir) if f.endswith(".h5"))
|
||||
shards_found = {f for f in os.listdir(tmp_dir) if f.endswith(".h5")}
|
||||
self.assertSetEqual(all_shards, shards_found)
|
||||
|
||||
# Finally, check the model can be reloaded
|
||||
|
||||
Reference in New Issue
Block a user