remove "inputs" in tf common test script (no longer required) (#15262)

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2022-02-01 11:09:49 +01:00
committed by GitHub
parent d12ae81664
commit af5c3329d7

View File

@@ -377,10 +377,6 @@ class TFModelTesterMixin:
else:
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.long)
# need to rename encoder-decoder "inputs" for PyTorch
if "inputs" in pt_inputs_dict and self.is_encoder_decoder:
pt_inputs_dict["input_ids"] = pt_inputs_dict.pop("inputs")
with torch.no_grad():
pto = pt_model(**pt_inputs_dict)
tfo = tf_model(self._prepare_for_class(inputs_dict, model_class), training=False)
@@ -422,9 +418,6 @@ class TFModelTesterMixin:
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32)
else:
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.long)
# need to rename encoder-decoder "inputs" for PyTorch
if "inputs" in pt_inputs_dict and self.is_encoder_decoder:
pt_inputs_dict["input_ids"] = pt_inputs_dict.pop("inputs")
with torch.no_grad():
pto = pt_model(**pt_inputs_dict)