Enable some ruff checks for performance and readability (#39383)
* Fix inefficient sequence tests Signed-off-by: cyy <cyyever@outlook.com> * Enable PERF102 Signed-off-by: cyy <cyyever@outlook.com> * Enable PLC1802 Signed-off-by: cyy <cyyever@outlook.com> * Enable PLC0208 Signed-off-by: cyy <cyyever@outlook.com> --------- Signed-off-by: cyy <cyyever@outlook.com>
This commit is contained in:
@@ -754,9 +754,9 @@ def get_parameters(model: nn.Module) -> Iterable[torch.Tensor]:
|
|||||||
Returns:
|
Returns:
|
||||||
Iterable[torch.Tensor]: An iterator over all parameters in the model
|
Iterable[torch.Tensor]: An iterator over all parameters in the model
|
||||||
"""
|
"""
|
||||||
for name, module in model._modules.items():
|
for module in model._modules.values():
|
||||||
# Look for parameters in module attributes
|
# Look for parameters in module attributes
|
||||||
for attr_name, attr in module.__dict__.items():
|
for attr in module.__dict__.values():
|
||||||
if isinstance(attr, torch.Tensor) and attr.requires_grad:
|
if isinstance(attr, torch.Tensor) and attr.requires_grad:
|
||||||
yield attr
|
yield attr
|
||||||
# Recursively get parameters from submodules
|
# Recursively get parameters from submodules
|
||||||
|
|||||||
@@ -19,10 +19,10 @@ line-length = 119
|
|||||||
|
|
||||||
[tool.ruff.lint]
|
[tool.ruff.lint]
|
||||||
# Never enforce `E501` (line length violations).
|
# Never enforce `E501` (line length violations).
|
||||||
ignore = ["C901", "E501", "E741", "F402", "F823" ]
|
ignore = ["C901", "E501", "E741", "F402", "F823"]
|
||||||
# RUF013: Checks for the use of implicit Optional
|
# RUF013: Checks for the use of implicit Optional
|
||||||
# in type annotations when the default parameter value is None.
|
# in type annotations when the default parameter value is None.
|
||||||
select = ["C", "E", "F", "I", "W", "RUF013", "UP006"]
|
select = ["C", "E", "F", "I", "W", "RUF013", "UP006", "PERF102", "PLC1802", "PLC0208"]
|
||||||
extend-safe-fixes = ["UP006"]
|
extend-safe-fixes = ["UP006"]
|
||||||
|
|
||||||
# Ignore import violations in all `__init__.py` files.
|
# Ignore import violations in all `__init__.py` files.
|
||||||
|
|||||||
@@ -607,7 +607,7 @@ class PretrainedConfig(PushToHubMixin):
|
|||||||
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
|
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
|
||||||
# sometimes the config has no `base_config_key` if the config is used in several composite models
|
# sometimes the config has no `base_config_key` if the config is used in several composite models
|
||||||
# e.g. LlamaConfig. In that case we try to see if there is match in `model_type` before raising a warning
|
# e.g. LlamaConfig. In that case we try to see if there is match in `model_type` before raising a warning
|
||||||
for k, v in config_dict.items():
|
for v in config_dict.values():
|
||||||
if isinstance(v, dict) and v.get("model_type") == cls.model_type:
|
if isinstance(v, dict) and v.get("model_type") == cls.model_type:
|
||||||
config_dict = v
|
config_dict = v
|
||||||
|
|
||||||
|
|||||||
@@ -2166,7 +2166,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
|||||||
self._tp_plan.update({f"{name}.{k}": v for k, v in plan.copy().items()})
|
self._tp_plan.update({f"{name}.{k}": v for k, v in plan.copy().items()})
|
||||||
|
|
||||||
if self._tp_plan is not None and is_torch_greater_or_equal("2.5") and _torch_distributed_available:
|
if self._tp_plan is not None and is_torch_greater_or_equal("2.5") and _torch_distributed_available:
|
||||||
for _, v in self._tp_plan.items():
|
for v in self._tp_plan.values():
|
||||||
if v not in ALL_PARALLEL_STYLES:
|
if v not in ALL_PARALLEL_STYLES:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unsupported tensor parallel style {v}. Supported styles are {ALL_PARALLEL_STYLES}"
|
f"Unsupported tensor parallel style {v}. Supported styles are {ALL_PARALLEL_STYLES}"
|
||||||
@@ -2845,7 +2845,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
|||||||
|
|
||||||
all_encoder_weights = {module_name + "/" + sub_name for sub_name in encoder_modules.keys()}
|
all_encoder_weights = {module_name + "/" + sub_name for sub_name in encoder_modules.keys()}
|
||||||
encoder_layer_pos = 0
|
encoder_layer_pos = 0
|
||||||
for name, module in decoder_modules.items():
|
for name in decoder_modules.keys():
|
||||||
if name.isdigit():
|
if name.isdigit():
|
||||||
encoder_name = str(int(name) + encoder_layer_pos)
|
encoder_name = str(int(name) + encoder_layer_pos)
|
||||||
decoder_name = name
|
decoder_name = name
|
||||||
@@ -5830,7 +5830,7 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: dict,
|
|||||||
accelerator_device_map = {
|
accelerator_device_map = {
|
||||||
param: torch.device(device) for param, device in expanded_device_map.items() if is_accelerator_device(device)
|
param: torch.device(device) for param, device in expanded_device_map.items() if is_accelerator_device(device)
|
||||||
}
|
}
|
||||||
if not len(accelerator_device_map):
|
if not accelerator_device_map:
|
||||||
return
|
return
|
||||||
|
|
||||||
tp_plan_regex = (
|
tp_plan_regex = (
|
||||||
|
|||||||
@@ -133,7 +133,7 @@ def feature_extractor_class_from_name(class_name: str):
|
|||||||
except AttributeError:
|
except AttributeError:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for _, extractor in FEATURE_EXTRACTOR_MAPPING._extra_content.items():
|
for extractor in FEATURE_EXTRACTOR_MAPPING._extra_content.values():
|
||||||
if getattr(extractor, "__name__", None) == class_name:
|
if getattr(extractor, "__name__", None) == class_name:
|
||||||
return extractor
|
return extractor
|
||||||
|
|
||||||
|
|||||||
@@ -212,7 +212,7 @@ def get_image_processor_class_from_name(class_name: str):
|
|||||||
except AttributeError:
|
except AttributeError:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for _, extractors in IMAGE_PROCESSOR_MAPPING._extra_content.items():
|
for extractors in IMAGE_PROCESSOR_MAPPING._extra_content.values():
|
||||||
for extractor in extractors:
|
for extractor in extractors:
|
||||||
if getattr(extractor, "__name__", None) == class_name:
|
if getattr(extractor, "__name__", None) == class_name:
|
||||||
return extractor
|
return extractor
|
||||||
@@ -533,7 +533,7 @@ class AutoImageProcessor:
|
|||||||
)
|
)
|
||||||
use_fast = False
|
use_fast = False
|
||||||
if use_fast:
|
if use_fast:
|
||||||
for _, image_processors in IMAGE_PROCESSOR_MAPPING_NAMES.items():
|
for image_processors in IMAGE_PROCESSOR_MAPPING_NAMES.values():
|
||||||
if image_processor_type in image_processors:
|
if image_processor_type in image_processors:
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -744,7 +744,7 @@ def tokenizer_class_from_name(class_name: str) -> Union[type[Any], None]:
|
|||||||
except AttributeError:
|
except AttributeError:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for config, tokenizers in TOKENIZER_MAPPING._extra_content.items():
|
for tokenizers in TOKENIZER_MAPPING._extra_content.values():
|
||||||
for tokenizer in tokenizers:
|
for tokenizer in tokenizers:
|
||||||
if getattr(tokenizer, "__name__", None) == class_name:
|
if getattr(tokenizer, "__name__", None) == class_name:
|
||||||
return tokenizer
|
return tokenizer
|
||||||
|
|||||||
@@ -84,7 +84,7 @@ def video_processor_class_from_name(class_name: str):
|
|||||||
except AttributeError:
|
except AttributeError:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for _, extractor in VIDEO_PROCESSOR_MAPPING._extra_content.items():
|
for extractor in VIDEO_PROCESSOR_MAPPING._extra_content.values():
|
||||||
if getattr(extractor, "__name__", None) == class_name:
|
if getattr(extractor, "__name__", None) == class_name:
|
||||||
return extractor
|
return extractor
|
||||||
|
|
||||||
|
|||||||
@@ -140,7 +140,7 @@ class BridgeTowerResidualAttention(nn.Module):
|
|||||||
def forward(self, hidden_state: torch.Tensor, attention_mask: Optional[torch.Tensor] = None):
|
def forward(self, hidden_state: torch.Tensor, attention_mask: Optional[torch.Tensor] = None):
|
||||||
residual_state = hidden_state + self.attention(self.ln_1(hidden_state), attention_mask)
|
residual_state = hidden_state + self.attention(self.ln_1(hidden_state), attention_mask)
|
||||||
hidden_state = self.ln_2(residual_state)
|
hidden_state = self.ln_2(residual_state)
|
||||||
for _, layer in self.mlp.items():
|
for layer in self.mlp.values():
|
||||||
hidden_state = layer(hidden_state)
|
hidden_state = layer(hidden_state)
|
||||||
hidden_state = residual_state + hidden_state
|
hidden_state = residual_state + hidden_state
|
||||||
return hidden_state
|
return hidden_state
|
||||||
|
|||||||
@@ -199,7 +199,7 @@ class DonutProcessor(ProcessorMixin):
|
|||||||
if tokens[:6] == r"<sep/>": # non-leaf nodes
|
if tokens[:6] == r"<sep/>": # non-leaf nodes
|
||||||
return [output] + self.token2json(tokens[6:], is_inner_value=True, added_vocab=added_vocab)
|
return [output] + self.token2json(tokens[6:], is_inner_value=True, added_vocab=added_vocab)
|
||||||
|
|
||||||
if len(output):
|
if output:
|
||||||
return [output] if is_inner_value else output
|
return [output] if is_inner_value else output
|
||||||
else:
|
else:
|
||||||
return [] if is_inner_value else {"text_sequence": tokens}
|
return [] if is_inner_value else {"text_sequence": tokens}
|
||||||
|
|||||||
@@ -239,7 +239,7 @@ def create_rename_keys(state_dict, config):
|
|||||||
########################################## DECODER - END
|
########################################## DECODER - END
|
||||||
|
|
||||||
########################################## Additional - START
|
########################################## Additional - START
|
||||||
for layer_name, params in state_dict.items():
|
for layer_name in state_dict.keys():
|
||||||
#### TEXT BACKBONE
|
#### TEXT BACKBONE
|
||||||
if "bert" in layer_name:
|
if "bert" in layer_name:
|
||||||
rename_keys.append((layer_name, layer_name.replace("bert", "model.text_backbone")))
|
rename_keys.append((layer_name, layer_name.replace("bert", "model.text_backbone")))
|
||||||
|
|||||||
@@ -177,7 +177,7 @@ def find_supported_resolutions(max_num_chunks: int, patch_size: SizeDict) -> tor
|
|||||||
|
|
||||||
# get the resolutions multiplied by the patch_size
|
# get the resolutions multiplied by the patch_size
|
||||||
possible_resolutions = []
|
possible_resolutions = []
|
||||||
for key, value in asp_dict.items():
|
for value in asp_dict.values():
|
||||||
for height, depth in value:
|
for height, depth in value:
|
||||||
possible_resolutions.append((height * patch_size, depth * patch_size))
|
possible_resolutions.append((height * patch_size, depth * patch_size))
|
||||||
|
|
||||||
|
|||||||
@@ -100,7 +100,7 @@ def convert_luke_checkpoint(checkpoint_path, metadata_path, entity_vocab_path, p
|
|||||||
state_dict.pop("lm_head.decoder.weight")
|
state_dict.pop("lm_head.decoder.weight")
|
||||||
state_dict.pop("lm_head.decoder.bias")
|
state_dict.pop("lm_head.decoder.bias")
|
||||||
state_dict_for_hugging_face = OrderedDict()
|
state_dict_for_hugging_face = OrderedDict()
|
||||||
for key, value in state_dict.items():
|
for key in state_dict.keys():
|
||||||
if not (key.startswith("lm_head") or key.startswith("entity_predictions")):
|
if not (key.startswith("lm_head") or key.startswith("entity_predictions")):
|
||||||
state_dict_for_hugging_face[f"luke.{key}"] = state_dict[key]
|
state_dict_for_hugging_face[f"luke.{key}"] = state_dict[key]
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -100,7 +100,7 @@ def create_rename_keys_vision(state_dict, config):
|
|||||||
########################################## VISION BACKBONE - END
|
########################################## VISION BACKBONE - END
|
||||||
|
|
||||||
########################################## ENCODER - START
|
########################################## ENCODER - START
|
||||||
for layer_name, params in state_dict.items():
|
for layer_name in state_dict.keys():
|
||||||
if "neck" in layer_name:
|
if "neck" in layer_name:
|
||||||
layer_name_replace = layer_name.replace("neck", "encoder")
|
layer_name_replace = layer_name.replace("neck", "encoder")
|
||||||
layer_name_replace = layer_name_replace.replace("input_proj", "channel_projection_layers")
|
layer_name_replace = layer_name_replace.replace("input_proj", "channel_projection_layers")
|
||||||
@@ -117,7 +117,7 @@ def create_rename_keys_vision(state_dict, config):
|
|||||||
########################################## ENCODER - END
|
########################################## ENCODER - END
|
||||||
|
|
||||||
########################################## DECODER - START
|
########################################## DECODER - START
|
||||||
for layer_name, params in state_dict.items():
|
for layer_name in state_dict.keys():
|
||||||
if layer_name.startswith("decoder"):
|
if layer_name.startswith("decoder"):
|
||||||
layer_name_replace = layer_name.replace("decoder.decoder.layers", "decoder.layers")
|
layer_name_replace = layer_name.replace("decoder.decoder.layers", "decoder.layers")
|
||||||
layer_name_replace = layer_name_replace.replace("input_proj", "channel_projection_layers")
|
layer_name_replace = layer_name_replace.replace("input_proj", "channel_projection_layers")
|
||||||
|
|||||||
@@ -144,7 +144,7 @@ def convert_xmod_checkpoint_to_pytorch(
|
|||||||
|
|
||||||
if sorted(bert_output.adapter_modules.keys()) != sorted(xmod_layer.adapter_modules.keys()):
|
if sorted(bert_output.adapter_modules.keys()) != sorted(xmod_layer.adapter_modules.keys()):
|
||||||
raise AssertionError("Lists of language adapters do not match.")
|
raise AssertionError("Lists of language adapters do not match.")
|
||||||
for lang_code, adapter in xmod_layer.adapter_modules.items():
|
for lang_code in xmod_layer.adapter_modules.keys():
|
||||||
to_adapter = bert_output.adapter_modules[lang_code]
|
to_adapter = bert_output.adapter_modules[lang_code]
|
||||||
from_adapter = xmod_layer.adapter_modules[lang_code]
|
from_adapter = xmod_layer.adapter_modules[lang_code]
|
||||||
to_adapter.dense1.weight = from_adapter.fc1.weight
|
to_adapter.dense1.weight = from_adapter.fc1.weight
|
||||||
|
|||||||
@@ -266,7 +266,7 @@ def convert_state_dict(orig_state_dict):
|
|||||||
|
|
||||||
|
|
||||||
def remove_ignore_keys(state_dict):
|
def remove_ignore_keys(state_dict):
|
||||||
for key, _ in state_dict.copy().items():
|
for key in state_dict.copy().keys():
|
||||||
if (
|
if (
|
||||||
"fc_norm" in key
|
"fc_norm" in key
|
||||||
or "relative_position_index" in key
|
or "relative_position_index" in key
|
||||||
|
|||||||
@@ -1288,14 +1288,14 @@ class Pipeline(_ScikitCompat, PushToHubMixin):
|
|||||||
if self.task in SUPPORTED_PEFT_TASKS:
|
if self.task in SUPPORTED_PEFT_TASKS:
|
||||||
supported_models_names.extend(SUPPORTED_PEFT_TASKS[self.task])
|
supported_models_names.extend(SUPPORTED_PEFT_TASKS[self.task])
|
||||||
|
|
||||||
for _, model_name in supported_models.items():
|
for model_name in supported_models.values():
|
||||||
# Mapping can now contain tuples of models for the same configuration.
|
# Mapping can now contain tuples of models for the same configuration.
|
||||||
if isinstance(model_name, tuple):
|
if isinstance(model_name, tuple):
|
||||||
supported_models_names.extend(list(model_name))
|
supported_models_names.extend(list(model_name))
|
||||||
else:
|
else:
|
||||||
supported_models_names.append(model_name)
|
supported_models_names.append(model_name)
|
||||||
if hasattr(supported_models, "_model_mapping"):
|
if hasattr(supported_models, "_model_mapping"):
|
||||||
for _, model in supported_models._model_mapping._extra_content.items():
|
for model in supported_models._model_mapping._extra_content.values():
|
||||||
if isinstance(model_name, tuple):
|
if isinstance(model_name, tuple):
|
||||||
supported_models_names.extend([m.__name__ for m in model])
|
supported_models_names.extend([m.__name__ for m in model])
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -232,7 +232,7 @@ class BatchEncoding(UserDict):
|
|||||||
|
|
||||||
self._encodings = encoding
|
self._encodings = encoding
|
||||||
|
|
||||||
if n_sequences is None and encoding is not None and len(encoding):
|
if n_sequences is None and encoding is not None and encoding:
|
||||||
n_sequences = encoding[0].n_sequences
|
n_sequences = encoding[0].n_sequences
|
||||||
|
|
||||||
self._n_sequences = n_sequences
|
self._n_sequences = n_sequences
|
||||||
|
|||||||
@@ -149,7 +149,7 @@ def find_batch_size(tensors):
|
|||||||
if result is not None:
|
if result is not None:
|
||||||
return result
|
return result
|
||||||
elif isinstance(tensors, Mapping):
|
elif isinstance(tensors, Mapping):
|
||||||
for key, value in tensors.items():
|
for value in tensors.values():
|
||||||
result = find_batch_size(value)
|
result = find_batch_size(value)
|
||||||
if result is not None:
|
if result is not None:
|
||||||
return result
|
return result
|
||||||
|
|||||||
@@ -2183,12 +2183,12 @@ class _LazyModule(ModuleType):
|
|||||||
self._modules = self._modules.union(module_keys)
|
self._modules = self._modules.union(module_keys)
|
||||||
|
|
||||||
for key, values in module.items():
|
for key, values in module.items():
|
||||||
if len(missing_backends):
|
if missing_backends:
|
||||||
self._object_missing_backend[key] = missing_backends
|
self._object_missing_backend[key] = missing_backends
|
||||||
|
|
||||||
for value in values:
|
for value in values:
|
||||||
self._class_to_module[value] = key
|
self._class_to_module[value] = key
|
||||||
if len(missing_backends):
|
if missing_backends:
|
||||||
self._object_missing_backend[value] = missing_backends
|
self._object_missing_backend[value] = missing_backends
|
||||||
_import_structure.setdefault(key, []).extend(values)
|
_import_structure.setdefault(key, []).extend(values)
|
||||||
|
|
||||||
|
|||||||
@@ -1199,7 +1199,7 @@ class VptqConfig(QuantizationConfigMixin):
|
|||||||
r"""
|
r"""
|
||||||
Safety checker that arguments are correct
|
Safety checker that arguments are correct
|
||||||
"""
|
"""
|
||||||
for layer_name, layer_param in self.config_for_layers.items():
|
for layer_param in self.config_for_layers.values():
|
||||||
VptqLayerConfig(**layer_param)
|
VptqLayerConfig(**layer_param)
|
||||||
if self.enable_proxy_error is True:
|
if self.enable_proxy_error is True:
|
||||||
raise ValueError("enable_proxy_error should always be False until we support training")
|
raise ValueError("enable_proxy_error should always be False until we support training")
|
||||||
|
|||||||
@@ -125,7 +125,7 @@ class AutoModelTest(unittest.TestCase):
|
|||||||
self.assertIsNotNone(model)
|
self.assertIsNotNone(model)
|
||||||
self.assertIsInstance(model, BertForPreTraining)
|
self.assertIsInstance(model, BertForPreTraining)
|
||||||
# Only one value should not be initialized and in the missing keys.
|
# Only one value should not be initialized and in the missing keys.
|
||||||
for key, value in loading_info.items():
|
for value in loading_info.values():
|
||||||
self.assertEqual(len(value), 0)
|
self.assertEqual(len(value), 0)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
|
|||||||
@@ -70,7 +70,7 @@ class AutoTokenizerTest(unittest.TestCase):
|
|||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_tokenizer_from_pretrained(self):
|
def test_tokenizer_from_pretrained(self):
|
||||||
for model_name in {"google-bert/bert-base-uncased", "google-bert/bert-base-cased"}:
|
for model_name in ("google-bert/bert-base-uncased", "google-bert/bert-base-cased"):
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||||
self.assertIsNotNone(tokenizer)
|
self.assertIsNotNone(tokenizer)
|
||||||
self.assertIsInstance(tokenizer, (BertTokenizer, BertTokenizerFast))
|
self.assertIsInstance(tokenizer, (BertTokenizer, BertTokenizerFast))
|
||||||
|
|||||||
@@ -897,7 +897,7 @@ class LukeModelIntegrationTests(unittest.TestCase):
|
|||||||
encoding = tokenizer(text, entity_spans=[span], add_prefix_space=True, return_tensors="pt")
|
encoding = tokenizer(text, entity_spans=[span], add_prefix_space=True, return_tensors="pt")
|
||||||
|
|
||||||
# move all values to device
|
# move all values to device
|
||||||
for key, value in encoding.items():
|
for key in encoding.keys():
|
||||||
encoding[key] = encoding[key].to(torch_device)
|
encoding[key] = encoding[key].to(torch_device)
|
||||||
|
|
||||||
outputs = model(**encoding)
|
outputs = model(**encoding)
|
||||||
@@ -932,7 +932,7 @@ class LukeModelIntegrationTests(unittest.TestCase):
|
|||||||
encoding = tokenizer(text, entity_spans=[span], add_prefix_space=True, return_tensors="pt")
|
encoding = tokenizer(text, entity_spans=[span], add_prefix_space=True, return_tensors="pt")
|
||||||
|
|
||||||
# move all values to device
|
# move all values to device
|
||||||
for key, value in encoding.items():
|
for key in encoding.keys():
|
||||||
encoding[key] = encoding[key].to(torch_device)
|
encoding[key] = encoding[key].to(torch_device)
|
||||||
|
|
||||||
outputs = model(**encoding)
|
outputs = model(**encoding)
|
||||||
|
|||||||
@@ -757,7 +757,7 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
|
|||||||
model.load_adapter(tmpdirname, is_trainable=True)
|
model.load_adapter(tmpdirname, is_trainable=True)
|
||||||
|
|
||||||
for name, module in model.named_modules():
|
for name, module in model.named_modules():
|
||||||
if len(list(module.children())):
|
if list(module.children()):
|
||||||
# only check leaf modules
|
# only check leaf modules
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|||||||
@@ -2535,7 +2535,7 @@ class ModelTesterMixin:
|
|||||||
|
|
||||||
shared_ptrs = {k: v for k, v in ptrs.items() if len(v) > 1}
|
shared_ptrs = {k: v for k, v in ptrs.items() if len(v) > 1}
|
||||||
|
|
||||||
for _, shared_names in shared_ptrs.items():
|
for shared_names in shared_ptrs.values():
|
||||||
reloaded_ptrs = {reloaded_state[k].data_ptr() for k in shared_names}
|
reloaded_ptrs = {reloaded_state[k].data_ptr() for k in shared_names}
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
len(reloaded_ptrs),
|
len(reloaded_ptrs),
|
||||||
|
|||||||
@@ -139,7 +139,7 @@ task_to_pipeline_and_spec_mapping = {
|
|||||||
"zero-shot-image-classification": (ZeroShotImageClassificationPipeline, ZeroShotImageClassificationInput),
|
"zero-shot-image-classification": (ZeroShotImageClassificationPipeline, ZeroShotImageClassificationInput),
|
||||||
}
|
}
|
||||||
|
|
||||||
for task, task_info in pipeline_test_mapping.items():
|
for task_info in pipeline_test_mapping.values():
|
||||||
test = task_info["test"]
|
test = task_info["test"]
|
||||||
task_info["mapping"] = {
|
task_info["mapping"] = {
|
||||||
"pt": getattr(test, "model_mapping", None),
|
"pt": getattr(test, "model_mapping", None),
|
||||||
|
|||||||
@@ -96,7 +96,7 @@ class TestImportStructures(unittest.TestCase):
|
|||||||
with self.subTest(f"Testing arch {architecture}"):
|
with self.subTest(f"Testing arch {architecture}"):
|
||||||
import_structure = define_import_structure(self.models_path / architecture)
|
import_structure = define_import_structure(self.models_path / architecture)
|
||||||
backend_agnostic_import_structure = {}
|
backend_agnostic_import_structure = {}
|
||||||
for requirement, module_object_mapping in import_structure.items():
|
for module_object_mapping in import_structure.values():
|
||||||
for module, objects in module_object_mapping.items():
|
for module, objects in module_object_mapping.items():
|
||||||
if module not in backend_agnostic_import_structure:
|
if module not in backend_agnostic_import_structure:
|
||||||
backend_agnostic_import_structure[module] = []
|
backend_agnostic_import_structure[module] = []
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ from tests.test_pipeline_mixin import pipeline_test_mapping
|
|||||||
|
|
||||||
|
|
||||||
PIPELINE_TEST_MAPPING = {}
|
PIPELINE_TEST_MAPPING = {}
|
||||||
for task, _ in pipeline_test_mapping.items():
|
for task in pipeline_test_mapping.keys():
|
||||||
PIPELINE_TEST_MAPPING[task] = {"pt": None, "tf": None}
|
PIPELINE_TEST_MAPPING[task] = {"pt": None, "tf": None}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -790,7 +790,7 @@ def check_all_auto_object_names_being_defined():
|
|||||||
mappings_to_check.update({name: getattr(module, name) for name in mapping_names})
|
mappings_to_check.update({name: getattr(module, name) for name in mapping_names})
|
||||||
|
|
||||||
for name, mapping in mappings_to_check.items():
|
for name, mapping in mappings_to_check.items():
|
||||||
for _, class_names in mapping.items():
|
for class_names in mapping.values():
|
||||||
if not isinstance(class_names, tuple):
|
if not isinstance(class_names, tuple):
|
||||||
class_names = (class_names,)
|
class_names = (class_names,)
|
||||||
for class_name in class_names:
|
for class_name in class_names:
|
||||||
|
|||||||
@@ -332,7 +332,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
doc_test_results = {}
|
doc_test_results = {}
|
||||||
# `artifact_key` is the artifact path
|
# `artifact_key` is the artifact path
|
||||||
for artifact_key, artifact_obj in available_artifacts.items():
|
for artifact_obj in available_artifacts.values():
|
||||||
artifact_path = artifact_obj.paths[0]
|
artifact_path = artifact_obj.paths[0]
|
||||||
if not artifact_path["path"].startswith("doc_tests_gpu_test_reports_"):
|
if not artifact_path["path"].startswith("doc_tests_gpu_test_reports_"):
|
||||||
continue
|
continue
|
||||||
|
|||||||
Reference in New Issue
Block a user