tests: fix pytorch tensor placement errors (#33485)
This commit fixes the following errors: * Fix "expected all tensors to be on the same device" error * Fix "can't convert device type tensor to numpy" According to pytorch documentation torch.Tensor.numpy(force=False) performs conversion only if tensor is on CPU (plus few other restrictions) which is not the case. For our case we need force=True since we just need a data and don't care about tensors coherency. Fixes: #33517 See: https://pytorch.org/docs/2.4/generated/torch.Tensor.numpy.html Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com>
This commit is contained in:
@@ -241,7 +241,7 @@ class FlaxEncoderDecoderMixin:
|
||||
|
||||
# prepare inputs
|
||||
flax_inputs = inputs_dict
|
||||
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in flax_inputs.items()}
|
||||
pt_inputs = {k: torch.tensor(v.tolist()).to(torch_device) for k, v in flax_inputs.items()}
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
||||
@@ -249,7 +249,7 @@ class FlaxEncoderDecoderMixin:
|
||||
fx_outputs = fx_model(**inputs_dict).to_tuple()
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
|
||||
self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-5)
|
||||
self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 1e-5)
|
||||
|
||||
# PT -> Flax
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
@@ -259,7 +259,7 @@ class FlaxEncoderDecoderMixin:
|
||||
fx_outputs_loaded = fx_model_loaded(**inputs_dict).to_tuple()
|
||||
self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
|
||||
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 1e-5)
|
||||
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(force=True), 1e-5)
|
||||
|
||||
# Flax -> PT
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
@@ -274,7 +274,7 @@ class FlaxEncoderDecoderMixin:
|
||||
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
|
||||
for fx_output, pt_output_loaded in zip(fx_outputs, pt_outputs_loaded):
|
||||
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(), 1e-5)
|
||||
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(force=True), 1e-5)
|
||||
|
||||
def check_equivalence_pt_to_flax(self, config, decoder_config, inputs_dict):
|
||||
encoder_decoder_config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
|
||||
|
||||
Reference in New Issue
Block a user