Update TF(Vision)EncoderDecoderModel PT/TF equivalence tests (#18073)
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -1616,7 +1616,7 @@ class ModelTesterMixin:
|
||||
|
||||
# Copied from tests.test_modeling_tf_common.TFModelTesterMixin.check_pt_tf_outputs
|
||||
def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=1e-5, name="outputs", attributes=None):
|
||||
"""Check the outputs from PyTorch and TensorFlow models are closed enough. Checks are done in a recursive way.
|
||||
"""Check the outputs from PyTorch and TensorFlow models are close enough. Checks are done in a recursive way.
|
||||
|
||||
Args:
|
||||
model_class: The class of the model that is currently testing. For example, `TFBertModel`,
|
||||
@@ -1642,8 +1642,8 @@ class ModelTesterMixin:
|
||||
# TODO: remove this method and this line after issues are fixed
|
||||
tf_outputs, pt_outputs = self._postprocessing_to_ignore_test_cases(tf_outputs, pt_outputs, model_class)
|
||||
|
||||
tf_keys = tuple([k for k, v in tf_outputs.items() if v is not None])
|
||||
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
|
||||
tf_keys = [k for k, v in tf_outputs.items() if v is not None]
|
||||
pt_keys = [k for k, v in pt_outputs.items() if v is not None]
|
||||
|
||||
self.assertEqual(tf_keys, pt_keys, f"{name}: Output keys differ between TF and PyTorch")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user