tests: fix pytorch tensor placement errors (#33485)

This commit fixes the following errors:
* Fix "expected all tensors to be on the same device" error
* Fix "can't convert device type tensor to numpy"

According to pytorch documentation torch.Tensor.numpy(force=False)
performs conversion only if tensor is on CPU (plus few other restrictions)
which is not the case. For our case we need force=True since we just
need a data and don't care about tensors coherency.

Fixes: #33517
See: https://pytorch.org/docs/2.4/generated/torch.Tensor.numpy.html

Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com>
This commit is contained in:
Dmitry Rogozhkin
2024-09-25 04:21:53 -07:00
committed by GitHub
parent 52daf4ec76
commit 5e2916bc14
8 changed files with 29 additions and 26 deletions

View File

@@ -160,7 +160,7 @@ class VisionTextDualEncoderMixin:
# prepare inputs
flax_inputs = inputs_dict
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in flax_inputs.items()}
pt_inputs = {k: torch.tensor(v.tolist()).to(torch_device) for k, v in flax_inputs.items()}
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()
@@ -168,7 +168,7 @@ class VisionTextDualEncoderMixin:
fx_outputs = fx_model(**inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]):
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 4e-2)
# PT -> Flax
with tempfile.TemporaryDirectory() as tmpdirname:
@@ -178,7 +178,7 @@ class VisionTextDualEncoderMixin:
fx_outputs_loaded = fx_model_loaded(**inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output_loaded, pt_output in zip(fx_outputs_loaded[:4], pt_outputs[:4]):
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 4e-2)
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(force=True), 4e-2)
# Flax -> PT
with tempfile.TemporaryDirectory() as tmpdirname:
@@ -193,7 +193,7 @@ class VisionTextDualEncoderMixin:
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output_loaded in zip(fx_outputs[:4], pt_outputs_loaded[:4]):
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(), 4e-2)
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(force=True), 4e-2)
def check_equivalence_pt_to_flax(self, vision_config, text_config, inputs_dict):
config = VisionTextDualEncoderConfig.from_vision_text_configs(vision_config, text_config)

View File

@@ -179,7 +179,7 @@ class VisionTextDualEncoderMixin:
# prepare inputs
inputs_dict = {"input_ids": input_ids, "attention_mask": attention_mask, "pixel_values": pixel_values}
pt_inputs = inputs_dict
flax_inputs = {k: v.numpy() for k, v in pt_inputs.items()}
flax_inputs = {k: v.numpy(force=True) for k, v in pt_inputs.items()}
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()
@@ -187,7 +187,7 @@ class VisionTextDualEncoderMixin:
fx_outputs = fx_model(**flax_inputs).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]):
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 4e-2)
# PT -> Flax
with tempfile.TemporaryDirectory() as tmpdirname:
@@ -197,7 +197,7 @@ class VisionTextDualEncoderMixin:
fx_outputs_loaded = fx_model_loaded(**flax_inputs).to_tuple()
self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output_loaded, pt_output in zip(fx_outputs_loaded[:4], pt_outputs[:4]):
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 4e-2)
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(force=True), 4e-2)
# Flax -> PT
with tempfile.TemporaryDirectory() as tmpdirname:
@@ -212,7 +212,7 @@ class VisionTextDualEncoderMixin:
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output_loaded in zip(fx_outputs[:4], pt_outputs_loaded[:4]):
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(), 4e-2)
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(force=True), 4e-2)
def check_equivalence_pt_to_flax(self, vision_config, text_config, inputs_dict):
config = VisionTextDualEncoderConfig.from_vision_text_configs(vision_config, text_config)