From 4cd7ed4b3b7360aef3a9fb16dfcc105001188717 Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Thu, 3 Mar 2022 13:21:31 +0100 Subject: [PATCH] Fix a TF Vision Encoder Decoder test (#15896) * send PyTorch inputs to the correct device * Fix: TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first. Co-authored-by: ydshieh --- .../test_modeling_tf_vision_encoder_decoder.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/vision_encoder_decoder/test_modeling_tf_vision_encoder_decoder.py b/tests/vision_encoder_decoder/test_modeling_tf_vision_encoder_decoder.py index 891c1aa55d..a0fcbfaea3 100644 --- a/tests/vision_encoder_decoder/test_modeling_tf_vision_encoder_decoder.py +++ b/tests/vision_encoder_decoder/test_modeling_tf_vision_encoder_decoder.py @@ -311,6 +311,9 @@ class TFVisionEncoderDecoderMixin: if "labels" in pt_inputs: pt_inputs["labels"] = pt_inputs["labels"].type(torch.LongTensor) + # send pytorch inputs to the correct device + pt_inputs = {k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs.items()} + with torch.no_grad(): pt_outputs = pt_model(**pt_inputs).to_tuple() @@ -321,7 +324,7 @@ class TFVisionEncoderDecoderMixin: self.assertEqual(len(tf_outputs), len(pt_outputs), "Output lengths differ between TF and PyTorch") for tf_output, pt_output in zip(tf_outputs, pt_outputs): - self.assert_almost_equals(tf_output.numpy(), pt_output.numpy(), 1e-3) + self.assert_almost_equals(tf_output.numpy(), pt_output.detach().to("cpu").numpy(), 1e-3) # PT -> TF with tempfile.TemporaryDirectory() as encoder_tmp_dirname, tempfile.TemporaryDirectory() as decoder_tmp_dirname: @@ -341,7 +344,7 @@ class TFVisionEncoderDecoderMixin: self.assertEqual(len(tf_outputs_loaded), len(pt_outputs), "Output lengths differ between TF and PyTorch") for tf_output_loaded, pt_output in zip(tf_outputs_loaded, pt_outputs): - self.assert_almost_equals(tf_output_loaded.numpy(), pt_output.numpy(), 1e-3) + self.assert_almost_equals(tf_output_loaded.numpy(), pt_output.detach().to("cpu").numpy(), 1e-3) def check_equivalence_pt_to_tf(self, config, decoder_config, inputs_dict):