Apply ruff flake8-comprehensions (#21694)
This commit is contained in:
@@ -1643,7 +1643,7 @@ class ModelTesterMixin:
|
||||
params = dict(model_reloaded.named_parameters())
|
||||
params.update(dict(model_reloaded.named_buffers()))
|
||||
# param_names = set(k[len(prefix) :] if k.startswith(prefix) else k for k in params.keys())
|
||||
param_names = set(k[len(prefix) :] if k.startswith(prefix) else k for k in params.keys())
|
||||
param_names = {k[len(prefix) :] if k.startswith(prefix) else k for k in params.keys()}
|
||||
|
||||
missing_keys = set(infos["missing_keys"])
|
||||
|
||||
@@ -1770,8 +1770,8 @@ class ModelTesterMixin:
|
||||
def _postprocessing_to_ignore_test_cases(self, tf_outputs, pt_outputs, model_class):
|
||||
"""For temporarily ignoring some failed test cases (issues to be fixed)"""
|
||||
|
||||
tf_keys = set([k for k, v in tf_outputs.items() if v is not None])
|
||||
pt_keys = set([k for k, v in pt_outputs.items() if v is not None])
|
||||
tf_keys = {k for k, v in tf_outputs.items() if v is not None}
|
||||
pt_keys = {k for k, v in pt_outputs.items() if v is not None}
|
||||
|
||||
key_differences = tf_keys.symmetric_difference(pt_keys)
|
||||
|
||||
@@ -2995,7 +2995,7 @@ class ModelUtilsTest(TestCasePlus):
|
||||
index = json.loads(f.read())
|
||||
|
||||
all_shards = set(index["weight_map"].values())
|
||||
shards_found = set(f for f in os.listdir(tmp_dir) if f.endswith(".bin"))
|
||||
shards_found = {f for f in os.listdir(tmp_dir) if f.endswith(".bin")}
|
||||
self.assertSetEqual(all_shards, shards_found)
|
||||
|
||||
# Finally, check the model can be reloaded
|
||||
|
||||
Reference in New Issue
Block a user