up (#14046)
This commit is contained in:
committed by
GitHub
parent
82b62fa607
commit
7c6cd0ac28
@@ -480,8 +480,6 @@ class FlaxCLIPModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
||||||
# PyTorch CLIPModel returns loss, we skip it here as we don't return loss in JAX/Flax models
|
|
||||||
pt_outputs = pt_outputs[1:]
|
|
||||||
|
|
||||||
fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
|
fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
|
||||||
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||||
@@ -525,8 +523,6 @@ class FlaxCLIPModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
||||||
# PyTorch CLIPModel returns loss, we skip it here as we don't return loss in JAX/Flax models
|
|
||||||
pt_outputs = pt_outputs[1:]
|
|
||||||
|
|
||||||
fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
|
fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
|
||||||
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||||
@@ -539,7 +535,6 @@ class FlaxCLIPModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()
|
pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()
|
||||||
pt_outputs_loaded = pt_outputs_loaded[1:]
|
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch"
|
len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch"
|
||||||
|
|||||||
Reference in New Issue
Block a user