tests: fix pytorch tensor placement errors (#33485)

This commit fixes the following errors:
* Fix "expected all tensors to be on the same device" error
* Fix "can't convert device type tensor to numpy"

According to pytorch documentation torch.Tensor.numpy(force=False)
performs conversion only if tensor is on CPU (plus few other restrictions)
which is not the case. For our case we need force=True since we just
need a data and don't care about tensors coherency.

Fixes: #33517
See: https://pytorch.org/docs/2.4/generated/torch.Tensor.numpy.html

Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com>
This commit is contained in:
Dmitry Rogozhkin
2024-09-25 04:21:53 -07:00
committed by GitHub
parent 52daf4ec76
commit 5e2916bc14
8 changed files with 29 additions and 26 deletions

View File

@@ -163,7 +163,7 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
# numpy currently does not support bfloat16, need to go over float32 in this case to not lose precision # numpy currently does not support bfloat16, need to go over float32 in this case to not lose precision
if v.dtype == bfloat16: if v.dtype == bfloat16:
v = v.float() v = v.float()
pt_state_dict[k] = v.numpy() pt_state_dict[k] = v.cpu().numpy()
model_prefix = flax_model.base_model_prefix model_prefix = flax_model.base_model_prefix

View File

@@ -848,6 +848,7 @@ class CLIPModelTest(CLIPModelTesterMixin, PipelineTesterMixin, unittest.TestCase
with self.subTest(model_class.__name__): with self.subTest(model_class.__name__):
# load PyTorch class # load PyTorch class
pt_model = model_class(config).eval() pt_model = model_class(config).eval()
pt_model.to(torch_device)
# Flax models don't use the `use_cache` option and cache is not returned as a default. # Flax models don't use the `use_cache` option and cache is not returned as a default.
# So we disable `use_cache` here for PyTorch model. # So we disable `use_cache` here for PyTorch model.
pt_model.config.use_cache = False pt_model.config.use_cache = False
@@ -881,7 +882,7 @@ class CLIPModelTest(CLIPModelTesterMixin, PipelineTesterMixin, unittest.TestCase
fx_outputs = fx_model(**fx_inputs).to_tuple() fx_outputs = fx_model(**fx_inputs).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]): for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]):
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2) self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 4e-2)
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname) pt_model.save_pretrained(tmpdirname)
@@ -892,7 +893,7 @@ class CLIPModelTest(CLIPModelTesterMixin, PipelineTesterMixin, unittest.TestCase
len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch" len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch"
) )
for fx_output_loaded, pt_output in zip(fx_outputs_loaded[:4], pt_outputs[:4]): for fx_output_loaded, pt_output in zip(fx_outputs_loaded[:4], pt_outputs[:4]):
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 4e-2) self.assert_almost_equals(fx_output_loaded, pt_output.numpy(force=True), 4e-2)
# overwrite from common since FlaxCLIPModel returns nested output # overwrite from common since FlaxCLIPModel returns nested output
# which is not supported in the common test # which is not supported in the common test
@@ -921,6 +922,7 @@ class CLIPModelTest(CLIPModelTesterMixin, PipelineTesterMixin, unittest.TestCase
fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys() fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys()
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params) pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
pt_model.to(torch_device)
# make sure weights are tied in PyTorch # make sure weights are tied in PyTorch
pt_model.tie_weights() pt_model.tie_weights()
@@ -940,11 +942,12 @@ class CLIPModelTest(CLIPModelTesterMixin, PipelineTesterMixin, unittest.TestCase
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]): for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]):
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2) self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 4e-2)
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
fx_model.save_pretrained(tmpdirname) fx_model.save_pretrained(tmpdirname)
pt_model_loaded = model_class.from_pretrained(tmpdirname, from_flax=True) pt_model_loaded = model_class.from_pretrained(tmpdirname, from_flax=True)
pt_model_loaded.to(torch_device)
with torch.no_grad(): with torch.no_grad():
pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple() pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()
@@ -953,7 +956,7 @@ class CLIPModelTest(CLIPModelTesterMixin, PipelineTesterMixin, unittest.TestCase
len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch" len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch"
) )
for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs_loaded[:4]): for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs_loaded[:4]):
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2) self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 4e-2)
@slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):

View File

@@ -297,7 +297,7 @@ class FlaxEncoderDecoderMixin:
# prepare inputs # prepare inputs
flax_inputs = inputs_dict flax_inputs = inputs_dict
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in flax_inputs.items()} pt_inputs = {k: torch.tensor(v.tolist()).to(torch_device) for k, v in flax_inputs.items()}
with torch.no_grad(): with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple() pt_outputs = pt_model(**pt_inputs).to_tuple()
@@ -305,7 +305,7 @@ class FlaxEncoderDecoderMixin:
fx_outputs = fx_model(**inputs_dict).to_tuple() fx_outputs = fx_model(**inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs, pt_outputs): for fx_output, pt_output in zip(fx_outputs, pt_outputs):
self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-5) self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 1e-5)
# PT -> Flax # PT -> Flax
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
@@ -315,7 +315,7 @@ class FlaxEncoderDecoderMixin:
fx_outputs_loaded = fx_model_loaded(**inputs_dict).to_tuple() fx_outputs_loaded = fx_model_loaded(**inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch") self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs): for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 1e-5) self.assert_almost_equals(fx_output_loaded, pt_output.numpy(force=True), 1e-5)
# Flax -> PT # Flax -> PT
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
@@ -330,7 +330,7 @@ class FlaxEncoderDecoderMixin:
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch") self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output_loaded in zip(fx_outputs, pt_outputs_loaded): for fx_output, pt_output_loaded in zip(fx_outputs, pt_outputs_loaded):
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(), 1e-5) self.assert_almost_equals(fx_output, pt_output_loaded.numpy(force=True), 1e-5)
def check_equivalence_pt_to_flax(self, config, decoder_config, inputs_dict): def check_equivalence_pt_to_flax(self, config, decoder_config, inputs_dict):
encoder_decoder_config = EncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config) encoder_decoder_config = EncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)

View File

@@ -170,7 +170,7 @@ class InformerModelTester:
embed_positions = InformerSinusoidalPositionalEmbedding( embed_positions = InformerSinusoidalPositionalEmbedding(
config.context_length + config.prediction_length, config.d_model config.context_length + config.prediction_length, config.d_model
) ).to(torch_device)
self.parent.assertTrue(torch.equal(model.encoder.embed_positions.weight, embed_positions.weight)) self.parent.assertTrue(torch.equal(model.encoder.embed_positions.weight, embed_positions.weight))
self.parent.assertTrue(torch.equal(model.decoder.embed_positions.weight, embed_positions.weight)) self.parent.assertTrue(torch.equal(model.decoder.embed_positions.weight, embed_positions.weight))

View File

@@ -412,7 +412,7 @@ class FlaxEncoderDecoderMixin:
# prepare inputs # prepare inputs
flax_inputs = inputs_dict flax_inputs = inputs_dict
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in flax_inputs.items()} pt_inputs = {k: torch.tensor(v.tolist()).to(torch_device) for k, v in flax_inputs.items()}
with torch.no_grad(): with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple() pt_outputs = pt_model(**pt_inputs).to_tuple()
@@ -420,7 +420,7 @@ class FlaxEncoderDecoderMixin:
fx_outputs = fx_model(**inputs_dict).to_tuple() fx_outputs = fx_model(**inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs, pt_outputs): for fx_output, pt_output in zip(fx_outputs, pt_outputs):
self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-5) self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 1e-5)
# PT -> Flax # PT -> Flax
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
@@ -430,7 +430,7 @@ class FlaxEncoderDecoderMixin:
fx_outputs_loaded = fx_model_loaded(**inputs_dict).to_tuple() fx_outputs_loaded = fx_model_loaded(**inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch") self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs): for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 1e-5) self.assert_almost_equals(fx_output_loaded, pt_output.numpy(force=True), 1e-5)
# Flax -> PT # Flax -> PT
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
@@ -445,7 +445,7 @@ class FlaxEncoderDecoderMixin:
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch") self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output_loaded in zip(fx_outputs, pt_outputs_loaded): for fx_output, pt_output_loaded in zip(fx_outputs, pt_outputs_loaded):
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(), 1e-5) self.assert_almost_equals(fx_output, pt_output_loaded.numpy(force=True), 1e-5)
def check_equivalence_pt_to_flax(self, config, decoder_config, inputs_dict): def check_equivalence_pt_to_flax(self, config, decoder_config, inputs_dict):
encoder_decoder_config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config) encoder_decoder_config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)

View File

@@ -241,7 +241,7 @@ class FlaxEncoderDecoderMixin:
# prepare inputs # prepare inputs
flax_inputs = inputs_dict flax_inputs = inputs_dict
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in flax_inputs.items()} pt_inputs = {k: torch.tensor(v.tolist()).to(torch_device) for k, v in flax_inputs.items()}
with torch.no_grad(): with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple() pt_outputs = pt_model(**pt_inputs).to_tuple()
@@ -249,7 +249,7 @@ class FlaxEncoderDecoderMixin:
fx_outputs = fx_model(**inputs_dict).to_tuple() fx_outputs = fx_model(**inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs, pt_outputs): for fx_output, pt_output in zip(fx_outputs, pt_outputs):
self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-5) self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 1e-5)
# PT -> Flax # PT -> Flax
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
@@ -259,7 +259,7 @@ class FlaxEncoderDecoderMixin:
fx_outputs_loaded = fx_model_loaded(**inputs_dict).to_tuple() fx_outputs_loaded = fx_model_loaded(**inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch") self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs): for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 1e-5) self.assert_almost_equals(fx_output_loaded, pt_output.numpy(force=True), 1e-5)
# Flax -> PT # Flax -> PT
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
@@ -274,7 +274,7 @@ class FlaxEncoderDecoderMixin:
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch") self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output_loaded in zip(fx_outputs, pt_outputs_loaded): for fx_output, pt_output_loaded in zip(fx_outputs, pt_outputs_loaded):
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(), 1e-5) self.assert_almost_equals(fx_output, pt_output_loaded.numpy(force=True), 1e-5)
def check_equivalence_pt_to_flax(self, config, decoder_config, inputs_dict): def check_equivalence_pt_to_flax(self, config, decoder_config, inputs_dict):
encoder_decoder_config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config) encoder_decoder_config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)

View File

@@ -160,7 +160,7 @@ class VisionTextDualEncoderMixin:
# prepare inputs # prepare inputs
flax_inputs = inputs_dict flax_inputs = inputs_dict
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in flax_inputs.items()} pt_inputs = {k: torch.tensor(v.tolist()).to(torch_device) for k, v in flax_inputs.items()}
with torch.no_grad(): with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple() pt_outputs = pt_model(**pt_inputs).to_tuple()
@@ -168,7 +168,7 @@ class VisionTextDualEncoderMixin:
fx_outputs = fx_model(**inputs_dict).to_tuple() fx_outputs = fx_model(**inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]): for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]):
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2) self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 4e-2)
# PT -> Flax # PT -> Flax
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
@@ -178,7 +178,7 @@ class VisionTextDualEncoderMixin:
fx_outputs_loaded = fx_model_loaded(**inputs_dict).to_tuple() fx_outputs_loaded = fx_model_loaded(**inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch") self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output_loaded, pt_output in zip(fx_outputs_loaded[:4], pt_outputs[:4]): for fx_output_loaded, pt_output in zip(fx_outputs_loaded[:4], pt_outputs[:4]):
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 4e-2) self.assert_almost_equals(fx_output_loaded, pt_output.numpy(force=True), 4e-2)
# Flax -> PT # Flax -> PT
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
@@ -193,7 +193,7 @@ class VisionTextDualEncoderMixin:
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch") self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output_loaded in zip(fx_outputs[:4], pt_outputs_loaded[:4]): for fx_output, pt_output_loaded in zip(fx_outputs[:4], pt_outputs_loaded[:4]):
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(), 4e-2) self.assert_almost_equals(fx_output, pt_output_loaded.numpy(force=True), 4e-2)
def check_equivalence_pt_to_flax(self, vision_config, text_config, inputs_dict): def check_equivalence_pt_to_flax(self, vision_config, text_config, inputs_dict):
config = VisionTextDualEncoderConfig.from_vision_text_configs(vision_config, text_config) config = VisionTextDualEncoderConfig.from_vision_text_configs(vision_config, text_config)

View File

@@ -179,7 +179,7 @@ class VisionTextDualEncoderMixin:
# prepare inputs # prepare inputs
inputs_dict = {"input_ids": input_ids, "attention_mask": attention_mask, "pixel_values": pixel_values} inputs_dict = {"input_ids": input_ids, "attention_mask": attention_mask, "pixel_values": pixel_values}
pt_inputs = inputs_dict pt_inputs = inputs_dict
flax_inputs = {k: v.numpy() for k, v in pt_inputs.items()} flax_inputs = {k: v.numpy(force=True) for k, v in pt_inputs.items()}
with torch.no_grad(): with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple() pt_outputs = pt_model(**pt_inputs).to_tuple()
@@ -187,7 +187,7 @@ class VisionTextDualEncoderMixin:
fx_outputs = fx_model(**flax_inputs).to_tuple() fx_outputs = fx_model(**flax_inputs).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]): for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]):
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2) self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 4e-2)
# PT -> Flax # PT -> Flax
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
@@ -197,7 +197,7 @@ class VisionTextDualEncoderMixin:
fx_outputs_loaded = fx_model_loaded(**flax_inputs).to_tuple() fx_outputs_loaded = fx_model_loaded(**flax_inputs).to_tuple()
self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch") self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output_loaded, pt_output in zip(fx_outputs_loaded[:4], pt_outputs[:4]): for fx_output_loaded, pt_output in zip(fx_outputs_loaded[:4], pt_outputs[:4]):
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 4e-2) self.assert_almost_equals(fx_output_loaded, pt_output.numpy(force=True), 4e-2)
# Flax -> PT # Flax -> PT
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
@@ -212,7 +212,7 @@ class VisionTextDualEncoderMixin:
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch") self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output_loaded in zip(fx_outputs[:4], pt_outputs_loaded[:4]): for fx_output, pt_output_loaded in zip(fx_outputs[:4], pt_outputs_loaded[:4]):
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(), 4e-2) self.assert_almost_equals(fx_output, pt_output_loaded.numpy(force=True), 4e-2)
def check_equivalence_pt_to_flax(self, vision_config, text_config, inputs_dict): def check_equivalence_pt_to_flax(self, vision_config, text_config, inputs_dict):
config = VisionTextDualEncoderConfig.from_vision_text_configs(vision_config, text_config) config = VisionTextDualEncoderConfig.from_vision_text_configs(vision_config, text_config)