Fix TFEncoderDecoderModelTest - Pytorch device (#15979)
* fix device Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -323,6 +323,9 @@ class TFEncoderDecoderMixin:
|
|||||||
if "labels" in pt_inputs:
|
if "labels" in pt_inputs:
|
||||||
pt_inputs["labels"] = pt_inputs["labels"].type(torch.LongTensor)
|
pt_inputs["labels"] = pt_inputs["labels"].type(torch.LongTensor)
|
||||||
|
|
||||||
|
# send pytorch inputs to the correct device
|
||||||
|
pt_inputs = {k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs.items()}
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
||||||
|
|
||||||
@@ -333,7 +336,7 @@ class TFEncoderDecoderMixin:
|
|||||||
self.assertEqual(len(tf_outputs), len(pt_outputs), "Output lengths differ between TF and PyTorch")
|
self.assertEqual(len(tf_outputs), len(pt_outputs), "Output lengths differ between TF and PyTorch")
|
||||||
|
|
||||||
for tf_output, pt_output in zip(tf_outputs, pt_outputs):
|
for tf_output, pt_output in zip(tf_outputs, pt_outputs):
|
||||||
self.assert_almost_equals(tf_output.numpy(), pt_output.numpy(), 1e-3)
|
self.assert_almost_equals(tf_output.numpy(), pt_output.detach().to("cpu").numpy(), 1e-3)
|
||||||
|
|
||||||
# PT -> TF
|
# PT -> TF
|
||||||
with tempfile.TemporaryDirectory() as encoder_tmp_dirname, tempfile.TemporaryDirectory() as decoder_tmp_dirname:
|
with tempfile.TemporaryDirectory() as encoder_tmp_dirname, tempfile.TemporaryDirectory() as decoder_tmp_dirname:
|
||||||
@@ -353,7 +356,7 @@ class TFEncoderDecoderMixin:
|
|||||||
self.assertEqual(len(tf_outputs_loaded), len(pt_outputs), "Output lengths differ between TF and PyTorch")
|
self.assertEqual(len(tf_outputs_loaded), len(pt_outputs), "Output lengths differ between TF and PyTorch")
|
||||||
|
|
||||||
for tf_output_loaded, pt_output in zip(tf_outputs_loaded, pt_outputs):
|
for tf_output_loaded, pt_output in zip(tf_outputs_loaded, pt_outputs):
|
||||||
self.assert_almost_equals(tf_output_loaded.numpy(), pt_output.numpy(), 1e-3)
|
self.assert_almost_equals(tf_output_loaded.numpy(), pt_output.detach().to("cpu").numpy(), 1e-3)
|
||||||
|
|
||||||
def check_equivalence_pt_to_tf(self, config, decoder_config, inputs_dict):
|
def check_equivalence_pt_to_tf(self, config, decoder_config, inputs_dict):
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user