Apply ruff flake8-comprehensions (#21694)

This commit is contained in:
Aaron Gokaslan
2023-02-22 03:14:54 -05:00
committed by GitHub
parent df06fb1f0b
commit 5e8c8eb5ba
230 changed files with 971 additions and 955 deletions

View File

@@ -1643,7 +1643,7 @@ class ModelTesterMixin:
params = dict(model_reloaded.named_parameters())
params.update(dict(model_reloaded.named_buffers()))
# param_names = set(k[len(prefix) :] if k.startswith(prefix) else k for k in params.keys())
param_names = set(k[len(prefix) :] if k.startswith(prefix) else k for k in params.keys())
param_names = {k[len(prefix) :] if k.startswith(prefix) else k for k in params.keys()}
missing_keys = set(infos["missing_keys"])
@@ -1770,8 +1770,8 @@ class ModelTesterMixin:
def _postprocessing_to_ignore_test_cases(self, tf_outputs, pt_outputs, model_class):
"""For temporarily ignoring some failed test cases (issues to be fixed)"""
tf_keys = set([k for k, v in tf_outputs.items() if v is not None])
pt_keys = set([k for k, v in pt_outputs.items() if v is not None])
tf_keys = {k for k, v in tf_outputs.items() if v is not None}
pt_keys = {k for k, v in pt_outputs.items() if v is not None}
key_differences = tf_keys.symmetric_difference(pt_keys)
@@ -2995,7 +2995,7 @@ class ModelUtilsTest(TestCasePlus):
index = json.loads(f.read())
all_shards = set(index["weight_map"].values())
shards_found = set(f for f in os.listdir(tmp_dir) if f.endswith(".bin"))
shards_found = {f for f in os.listdir(tmp_dir) if f.endswith(".bin")}
self.assertSetEqual(all_shards, shards_found)
# Finally, check the model can be reloaded