[tests] remove pt_tf equivalence tests (#36253)
This commit is contained in:
@@ -309,10 +309,6 @@ class Data2VecVisionModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Te
|
||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||
)
|
||||
|
||||
def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=2e-4, name="outputs", attributes=None):
|
||||
# We override with a slightly higher tol value, as semseg models tend to diverge a bit more
|
||||
super().check_pt_tf_outputs(tf_outputs, pt_outputs, model_class, tol, name, attributes)
|
||||
|
||||
def test_for_image_classification(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
|
||||
|
||||
@@ -385,10 +385,6 @@ class TFData2VecVisionModelTest(TFModelTesterMixin, PipelineTesterMixin, unittes
|
||||
val_loss2 = history2.history["val_loss"][0]
|
||||
self.assertTrue(np.allclose(val_loss1, val_loss2, atol=1e-2, rtol=1e-3))
|
||||
|
||||
def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=2e-4, name="outputs", attributes=None):
|
||||
# We override with a slightly higher tol value, as semseg models tend to diverge a bit more
|
||||
super().check_pt_tf_outputs(tf_outputs, pt_outputs, model_class, tol, name, attributes)
|
||||
|
||||
# Overriding this method since the base method won't be compatible with Data2VecVision.
|
||||
def test_loss_computation(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
Reference in New Issue
Block a user