Update quality tooling for formatting (#21480)
* Result of black 23.1 * Update target to Python 3.7 * Switch flake8 to ruff * Configure isort * Configure isort * Apply isort with line limit * Put the right black version * adapt black in check copies * Fix copies
This commit is contained in:
@@ -100,7 +100,6 @@ class VisionTextDualEncoderMixin:
|
||||
def check_vision_text_dual_encoder_from_pretrained(
|
||||
self, text_config, input_ids, attention_mask, vision_config, pixel_values=None, **kwargs
|
||||
):
|
||||
|
||||
vision_model, text_model = self.get_vision_text_model(vision_config, text_config)
|
||||
kwargs = {"vision_model": vision_model, "text_model": text_model}
|
||||
model = FlaxVisionTextDualEncoderModel.from_vision_text_pretrained(**kwargs)
|
||||
@@ -157,7 +156,6 @@ class VisionTextDualEncoderMixin:
|
||||
)
|
||||
|
||||
def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict):
|
||||
|
||||
pt_model.to(torch_device)
|
||||
pt_model.eval()
|
||||
|
||||
@@ -199,7 +197,6 @@ class VisionTextDualEncoderMixin:
|
||||
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(), 4e-2)
|
||||
|
||||
def check_equivalence_pt_to_flax(self, vision_config, text_config, inputs_dict):
|
||||
|
||||
config = VisionTextDualEncoderConfig.from_vision_text_configs(vision_config, text_config)
|
||||
|
||||
pt_model = VisionTextDualEncoderModel(config)
|
||||
@@ -211,7 +208,6 @@ class VisionTextDualEncoderMixin:
|
||||
self.check_pt_flax_equivalence(pt_model, fx_model, inputs_dict)
|
||||
|
||||
def check_equivalence_flax_to_pt(self, vision_config, text_config, inputs_dict):
|
||||
|
||||
config = VisionTextDualEncoderConfig.from_vision_text_configs(vision_config, text_config)
|
||||
|
||||
pt_model = VisionTextDualEncoderModel(config)
|
||||
@@ -239,7 +235,6 @@ class VisionTextDualEncoderMixin:
|
||||
|
||||
@is_pt_flax_cross_test
|
||||
def test_pt_flax_equivalence(self):
|
||||
|
||||
config_inputs_dict = self.prepare_config_and_inputs()
|
||||
vision_config = config_inputs_dict.pop("vision_config")
|
||||
text_config = config_inputs_dict.pop("text_config")
|
||||
@@ -311,7 +306,6 @@ class FlaxViTBertModelTest(VisionTextDualEncoderMixin, unittest.TestCase):
|
||||
"vision_config": vision_config,
|
||||
"pixel_values": pixel_values,
|
||||
"attention_mask": attention_mask,
|
||||
"text_config": text_config,
|
||||
"input_ids": input_ids,
|
||||
"token_type_ids": token_type_ids,
|
||||
}
|
||||
@@ -362,7 +356,6 @@ class FlaxCLIPVisionBertModelTest(VisionTextDualEncoderMixin, unittest.TestCase)
|
||||
"vision_config": vision_config,
|
||||
"pixel_values": pixel_values,
|
||||
"attention_mask": attention_mask,
|
||||
"text_config": text_config,
|
||||
"input_ids": input_ids,
|
||||
"token_type_ids": token_type_ids,
|
||||
}
|
||||
|
||||
@@ -108,7 +108,6 @@ class VisionTextDualEncoderMixin:
|
||||
def check_vision_text_dual_encoder_from_pretrained(
|
||||
self, text_config, input_ids, attention_mask, vision_config, pixel_values=None, **kwargs
|
||||
):
|
||||
|
||||
vision_model, text_model = self.get_vision_text_model(vision_config, text_config)
|
||||
kwargs = {"vision_model": vision_model, "text_model": text_model}
|
||||
model = VisionTextDualEncoderModel.from_vision_text_pretrained(**kwargs)
|
||||
@@ -175,7 +174,6 @@ class VisionTextDualEncoderMixin:
|
||||
self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")
|
||||
|
||||
def check_pt_flax_equivalence(self, pt_model, fx_model, input_ids, attention_mask, pixel_values, **kwargs):
|
||||
|
||||
pt_model.to(torch_device)
|
||||
pt_model.eval()
|
||||
|
||||
@@ -218,7 +216,6 @@ class VisionTextDualEncoderMixin:
|
||||
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(), 4e-2)
|
||||
|
||||
def check_equivalence_pt_to_flax(self, vision_config, text_config, inputs_dict):
|
||||
|
||||
config = VisionTextDualEncoderConfig.from_vision_text_configs(vision_config, text_config)
|
||||
|
||||
pt_model = VisionTextDualEncoderModel(config)
|
||||
@@ -230,7 +227,6 @@ class VisionTextDualEncoderMixin:
|
||||
self.check_pt_flax_equivalence(pt_model, fx_model, **inputs_dict)
|
||||
|
||||
def check_equivalence_flax_to_pt(self, vision_config, text_config, inputs_dict):
|
||||
|
||||
config = VisionTextDualEncoderConfig.from_vision_text_configs(vision_config, text_config)
|
||||
|
||||
pt_model = VisionTextDualEncoderModel(config)
|
||||
@@ -262,7 +258,6 @@ class VisionTextDualEncoderMixin:
|
||||
|
||||
@is_pt_flax_cross_test
|
||||
def test_pt_flax_equivalence(self):
|
||||
|
||||
config_inputs_dict = self.prepare_config_and_inputs()
|
||||
vision_config = config_inputs_dict.pop("vision_config")
|
||||
text_config = config_inputs_dict.pop("text_config")
|
||||
@@ -341,7 +336,6 @@ class ViTBertModelTest(VisionTextDualEncoderMixin, unittest.TestCase):
|
||||
"vision_config": vision_config,
|
||||
"pixel_values": pixel_values,
|
||||
"attention_mask": input_mask,
|
||||
"text_config": text_config,
|
||||
"input_ids": input_ids,
|
||||
"text_token_type_ids": token_type_ids,
|
||||
"text_sequence_labels": sequence_labels,
|
||||
@@ -429,7 +423,6 @@ class DeiTRobertaModelTest(VisionTextDualEncoderMixin, unittest.TestCase):
|
||||
"vision_config": vision_config,
|
||||
"pixel_values": pixel_values,
|
||||
"attention_mask": input_mask,
|
||||
"text_config": text_config,
|
||||
"input_ids": input_ids,
|
||||
"text_token_type_ids": token_type_ids,
|
||||
"text_sequence_labels": sequence_labels,
|
||||
@@ -491,7 +484,6 @@ class CLIPVisionBertModelTest(VisionTextDualEncoderMixin, unittest.TestCase):
|
||||
"vision_config": vision_config,
|
||||
"pixel_values": pixel_values,
|
||||
"attention_mask": input_mask,
|
||||
"text_config": text_config,
|
||||
"input_ids": input_ids,
|
||||
"text_token_type_ids": token_type_ids,
|
||||
"text_sequence_labels": sequence_labels,
|
||||
|
||||
Reference in New Issue
Block a user