This commit is contained in:
Patrick von Platen
2021-10-18 12:59:18 +02:00
committed by GitHub
parent 82b62fa607
commit 7c6cd0ac28

View File

@@ -480,8 +480,6 @@ class FlaxCLIPModelTest(FlaxModelTesterMixin, unittest.TestCase):
with torch.no_grad(): with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple() pt_outputs = pt_model(**pt_inputs).to_tuple()
# PyTorch CLIPModel returns loss, we skip it here as we don't return loss in JAX/Flax models
pt_outputs = pt_outputs[1:]
fx_outputs = fx_model(**prepared_inputs_dict).to_tuple() fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
@@ -525,8 +523,6 @@ class FlaxCLIPModelTest(FlaxModelTesterMixin, unittest.TestCase):
with torch.no_grad(): with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple() pt_outputs = pt_model(**pt_inputs).to_tuple()
# PyTorch CLIPModel returns loss, we skip it here as we don't return loss in JAX/Flax models
pt_outputs = pt_outputs[1:]
fx_outputs = fx_model(**prepared_inputs_dict).to_tuple() fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
@@ -539,7 +535,6 @@ class FlaxCLIPModelTest(FlaxModelTesterMixin, unittest.TestCase):
with torch.no_grad(): with torch.no_grad():
pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple() pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()
pt_outputs_loaded = pt_outputs_loaded[1:]
self.assertEqual( self.assertEqual(
len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch" len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch"