tests: fix pytorch tensor placement errors (#33485)

This commit fixes the following errors:
* Fix "expected all tensors to be on the same device" error
* Fix "can't convert device type tensor to numpy"

According to pytorch documentation torch.Tensor.numpy(force=False)
performs conversion only if tensor is on CPU (plus few other restrictions)
which is not the case. For our case we need force=True since we just
need a data and don't care about tensors coherency.

Fixes: #33517
See: https://pytorch.org/docs/2.4/generated/torch.Tensor.numpy.html

Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com>
This commit is contained in:
Dmitry Rogozhkin
2024-09-25 04:21:53 -07:00
committed by GitHub
parent 52daf4ec76
commit 5e2916bc14
8 changed files with 29 additions and 26 deletions

View File

@@ -412,7 +412,7 @@ class FlaxEncoderDecoderMixin:
# prepare inputs
flax_inputs = inputs_dict
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in flax_inputs.items()}
pt_inputs = {k: torch.tensor(v.tolist()).to(torch_device) for k, v in flax_inputs.items()}
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()
@@ -420,7 +420,7 @@ class FlaxEncoderDecoderMixin:
fx_outputs = fx_model(**inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-5)
self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 1e-5)
# PT -> Flax
with tempfile.TemporaryDirectory() as tmpdirname:
@@ -430,7 +430,7 @@ class FlaxEncoderDecoderMixin:
fx_outputs_loaded = fx_model_loaded(**inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 1e-5)
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(force=True), 1e-5)
# Flax -> PT
with tempfile.TemporaryDirectory() as tmpdirname:
@@ -445,7 +445,7 @@ class FlaxEncoderDecoderMixin:
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output_loaded in zip(fx_outputs, pt_outputs_loaded):
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(), 1e-5)
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(force=True), 1e-5)
def check_equivalence_pt_to_flax(self, config, decoder_config, inputs_dict):
encoder_decoder_config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)