tests: fix pytorch tensor placement errors (#33485)
This commit fixes the following errors: * Fix "expected all tensors to be on the same device" error * Fix "can't convert device type tensor to numpy" According to pytorch documentation torch.Tensor.numpy(force=False) performs conversion only if tensor is on CPU (plus few other restrictions) which is not the case. For our case we need force=True since we just need a data and don't care about tensors coherency. Fixes: #33517 See: https://pytorch.org/docs/2.4/generated/torch.Tensor.numpy.html Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com>
This commit is contained in:
@@ -160,7 +160,7 @@ class VisionTextDualEncoderMixin:
|
||||
|
||||
# prepare inputs
|
||||
flax_inputs = inputs_dict
|
||||
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in flax_inputs.items()}
|
||||
pt_inputs = {k: torch.tensor(v.tolist()).to(torch_device) for k, v in flax_inputs.items()}
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
||||
@@ -168,7 +168,7 @@ class VisionTextDualEncoderMixin:
|
||||
fx_outputs = fx_model(**inputs_dict).to_tuple()
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]):
|
||||
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
|
||||
self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 4e-2)
|
||||
|
||||
# PT -> Flax
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
@@ -178,7 +178,7 @@ class VisionTextDualEncoderMixin:
|
||||
fx_outputs_loaded = fx_model_loaded(**inputs_dict).to_tuple()
|
||||
self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
for fx_output_loaded, pt_output in zip(fx_outputs_loaded[:4], pt_outputs[:4]):
|
||||
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 4e-2)
|
||||
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(force=True), 4e-2)
|
||||
|
||||
# Flax -> PT
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
@@ -193,7 +193,7 @@ class VisionTextDualEncoderMixin:
|
||||
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
|
||||
for fx_output, pt_output_loaded in zip(fx_outputs[:4], pt_outputs_loaded[:4]):
|
||||
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(), 4e-2)
|
||||
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(force=True), 4e-2)
|
||||
|
||||
def check_equivalence_pt_to_flax(self, vision_config, text_config, inputs_dict):
|
||||
config = VisionTextDualEncoderConfig.from_vision_text_configs(vision_config, text_config)
|
||||
|
||||
@@ -179,7 +179,7 @@ class VisionTextDualEncoderMixin:
|
||||
# prepare inputs
|
||||
inputs_dict = {"input_ids": input_ids, "attention_mask": attention_mask, "pixel_values": pixel_values}
|
||||
pt_inputs = inputs_dict
|
||||
flax_inputs = {k: v.numpy() for k, v in pt_inputs.items()}
|
||||
flax_inputs = {k: v.numpy(force=True) for k, v in pt_inputs.items()}
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
||||
@@ -187,7 +187,7 @@ class VisionTextDualEncoderMixin:
|
||||
fx_outputs = fx_model(**flax_inputs).to_tuple()
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]):
|
||||
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
|
||||
self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 4e-2)
|
||||
|
||||
# PT -> Flax
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
@@ -197,7 +197,7 @@ class VisionTextDualEncoderMixin:
|
||||
fx_outputs_loaded = fx_model_loaded(**flax_inputs).to_tuple()
|
||||
self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
for fx_output_loaded, pt_output in zip(fx_outputs_loaded[:4], pt_outputs[:4]):
|
||||
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 4e-2)
|
||||
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(force=True), 4e-2)
|
||||
|
||||
# Flax -> PT
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
@@ -212,7 +212,7 @@ class VisionTextDualEncoderMixin:
|
||||
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
|
||||
for fx_output, pt_output_loaded in zip(fx_outputs[:4], pt_outputs_loaded[:4]):
|
||||
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(), 4e-2)
|
||||
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(force=True), 4e-2)
|
||||
|
||||
def check_equivalence_pt_to_flax(self, vision_config, text_config, inputs_dict):
|
||||
config = VisionTextDualEncoderConfig.from_vision_text_configs(vision_config, text_config)
|
||||
|
||||
Reference in New Issue
Block a user