Apply ruff flake8-comprehensions (#21694)

This commit is contained in:
Aaron Gokaslan
2023-02-22 03:14:54 -05:00
committed by GitHub
parent df06fb1f0b
commit 5e8c8eb5ba
230 changed files with 971 additions and 955 deletions

View File

@@ -398,7 +398,7 @@ class TFModelTesterMixin:
def test_keras_save_load(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
tf_main_layer_classes = set(
tf_main_layer_classes = {
module_member
for model_class in self.all_model_classes
for module in (import_module(model_class.__module__),)
@@ -410,7 +410,7 @@ class TFModelTesterMixin:
if isinstance(module_member, type)
and tf.keras.layers.Layer in module_member.__bases__
and getattr(module_member, "_keras_serializable", False)
)
}
for main_layer_class in tf_main_layer_classes:
# T5MainLayer needs an embed_tokens parameter when called without the inputs_embeds parameter
if "T5" in main_layer_class.__name__:
@@ -498,8 +498,8 @@ class TFModelTesterMixin:
def _postprocessing_to_ignore_test_cases(self, tf_outputs, pt_outputs, model_class):
"""For temporarily ignoring some failed test cases (issues to be fixed)"""
tf_keys = set([k for k, v in tf_outputs.items() if v is not None])
pt_keys = set([k for k, v in pt_outputs.items() if v is not None])
tf_keys = {k for k, v in tf_outputs.items() if v is not None}
pt_keys = {k for k, v in pt_outputs.items() if v is not None}
key_differences = tf_keys.symmetric_difference(pt_keys)
@@ -1455,7 +1455,7 @@ class TFModelTesterMixin:
continue
# The number of elements in the loss should be the same as the number of elements in the label
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
added_label_names = sorted(list(prepared_for_class.keys() - inputs_dict.keys()), reverse=True)
added_label_names = sorted(prepared_for_class.keys() - inputs_dict.keys(), reverse=True)
if not added_label_names:
continue # This test is only for models with easily-separable labels
added_label = prepared_for_class[added_label_names[0]]
@@ -1713,7 +1713,7 @@ class TFModelTesterMixin:
}
signature = inspect.signature(model.call)
if set(head_masking.keys()) < set([*signature.parameters.keys()]):
if set(head_masking.keys()) < {*signature.parameters.keys()}:
continue
for attn_name, (name, mask) in zip(attention_names, head_masking.items()):
@@ -2274,7 +2274,7 @@ class UtilsFunctionsTest(unittest.TestCase):
index = json.loads(f.read())
all_shards = set(index["weight_map"].values())
shards_found = set(f for f in os.listdir(tmp_dir) if f.endswith(".h5"))
shards_found = {f for f in os.listdir(tmp_dir) if f.endswith(".h5")}
self.assertSetEqual(all_shards, shards_found)
# Finally, check the model can be reloaded