From 7c6cd0ac28f1b760ccb4d6e4761f13185d05d90b Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 18 Oct 2021 12:59:18 +0200 Subject: [PATCH] up (#14046) --- tests/test_modeling_flax_clip.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/test_modeling_flax_clip.py b/tests/test_modeling_flax_clip.py index cab4b7b53d..fabc5fd25a 100644 --- a/tests/test_modeling_flax_clip.py +++ b/tests/test_modeling_flax_clip.py @@ -480,8 +480,6 @@ class FlaxCLIPModelTest(FlaxModelTesterMixin, unittest.TestCase): with torch.no_grad(): pt_outputs = pt_model(**pt_inputs).to_tuple() - # PyTorch CLIPModel returns loss, we skip it here as we don't return loss in JAX/Flax models - pt_outputs = pt_outputs[1:] fx_outputs = fx_model(**prepared_inputs_dict).to_tuple() self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") @@ -525,8 +523,6 @@ class FlaxCLIPModelTest(FlaxModelTesterMixin, unittest.TestCase): with torch.no_grad(): pt_outputs = pt_model(**pt_inputs).to_tuple() - # PyTorch CLIPModel returns loss, we skip it here as we don't return loss in JAX/Flax models - pt_outputs = pt_outputs[1:] fx_outputs = fx_model(**prepared_inputs_dict).to_tuple() self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") @@ -539,7 +535,6 @@ class FlaxCLIPModelTest(FlaxModelTesterMixin, unittest.TestCase): with torch.no_grad(): pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple() - pt_outputs_loaded = pt_outputs_loaded[1:] self.assertEqual( len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch"