up (#14046)
This commit is contained in:
committed by
GitHub
parent
82b62fa607
commit
7c6cd0ac28
@@ -480,8 +480,6 @@ class FlaxCLIPModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
||||
# PyTorch CLIPModel returns loss, we skip it here as we don't return loss in JAX/Flax models
|
||||
pt_outputs = pt_outputs[1:]
|
||||
|
||||
fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
@@ -525,8 +523,6 @@ class FlaxCLIPModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
||||
# PyTorch CLIPModel returns loss, we skip it here as we don't return loss in JAX/Flax models
|
||||
pt_outputs = pt_outputs[1:]
|
||||
|
||||
fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
@@ -539,7 +535,6 @@ class FlaxCLIPModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()
|
||||
pt_outputs_loaded = pt_outputs_loaded[1:]
|
||||
|
||||
self.assertEqual(
|
||||
len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch"
|
||||
|
||||
Reference in New Issue
Block a user