[FlaxSpeechEncoderDecoderModel] More Rigorous PT-Flax Equivalence Tests (#16589)
This commit is contained in:
@@ -413,28 +413,22 @@ class FlaxEncoderDecoderMixin:
|
|||||||
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in flax_inputs.items()}
|
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in flax_inputs.items()}
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
pt_outputs = pt_model(**pt_inputs)
|
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
||||||
pt_logits = pt_outputs.logits
|
|
||||||
pt_outputs = pt_outputs.to_tuple()
|
|
||||||
|
|
||||||
fx_outputs = fx_model(**inputs_dict)
|
|
||||||
fx_logits = fx_outputs.logits
|
|
||||||
fx_outputs = fx_outputs.to_tuple()
|
|
||||||
|
|
||||||
|
fx_outputs = fx_model(**inputs_dict).to_tuple()
|
||||||
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||||
self.assert_almost_equals(fx_logits, pt_logits.numpy(), 4e-2)
|
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
|
||||||
|
self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-5)
|
||||||
|
|
||||||
# PT -> Flax
|
# PT -> Flax
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
pt_model.save_pretrained(tmpdirname)
|
pt_model.save_pretrained(tmpdirname)
|
||||||
fx_model_loaded = FlaxSpeechEncoderDecoderModel.from_pretrained(tmpdirname, from_pt=True)
|
fx_model_loaded = FlaxSpeechEncoderDecoderModel.from_pretrained(tmpdirname, from_pt=True)
|
||||||
|
|
||||||
fx_outputs_loaded = fx_model_loaded(**inputs_dict)
|
fx_outputs_loaded = fx_model_loaded(**inputs_dict).to_tuple()
|
||||||
fx_logits_loaded = fx_outputs_loaded.logits
|
|
||||||
fx_outputs_loaded = fx_outputs_loaded.to_tuple()
|
|
||||||
|
|
||||||
self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||||
self.assert_almost_equals(fx_logits_loaded, pt_logits.numpy(), 4e-2)
|
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
|
||||||
|
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 1e-5)
|
||||||
|
|
||||||
# Flax -> PT
|
# Flax -> PT
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
@@ -445,12 +439,11 @@ class FlaxEncoderDecoderMixin:
|
|||||||
pt_model_loaded.eval()
|
pt_model_loaded.eval()
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
pt_outputs_loaded = pt_model_loaded(**pt_inputs)
|
pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()
|
||||||
pt_logits_loaded = pt_outputs_loaded.logits
|
|
||||||
pt_outputs_loaded = pt_outputs_loaded.to_tuple()
|
|
||||||
|
|
||||||
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
|
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
|
||||||
self.assert_almost_equals(fx_logits, pt_logits_loaded.numpy(), 4e-2)
|
for fx_output, pt_output_loaded in zip(fx_outputs, pt_outputs_loaded):
|
||||||
|
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(), 1e-5)
|
||||||
|
|
||||||
def check_equivalence_pt_to_flax(self, config, decoder_config, inputs_dict):
|
def check_equivalence_pt_to_flax(self, config, decoder_config, inputs_dict):
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user