Switch return_dict to True by default. (#8530)
* Use the CI to identify failing tests * Remove from all examples and tests * More default switch * Fixes * More test fixes * More fixes * Last fixes hopefully * Use the CI to identify failing tests * Remove from all examples and tests * More default switch * Fixes * More test fixes * More fixes * Last fixes hopefully * Run on the real suite * Fix slow tests
This commit is contained in:
@@ -29,7 +29,7 @@ class FlaxBertModelTest(unittest.TestCase):
|
||||
# Check for simple input
|
||||
pt_inputs = tokenizer.encode_plus("This is a simple input", return_tensors=TensorType.PYTORCH)
|
||||
fx_inputs = tokenizer.encode_plus("This is a simple input", return_tensors=TensorType.JAX)
|
||||
pt_outputs = pt_model(**pt_inputs)
|
||||
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
||||
fx_outputs = fx_model(**fx_inputs)
|
||||
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
|
||||
Reference in New Issue
Block a user