remove "inputs" in tf common test script (no longer required) (#15262)
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -377,10 +377,6 @@ class TFModelTesterMixin:
|
|||||||
else:
|
else:
|
||||||
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.long)
|
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.long)
|
||||||
|
|
||||||
# need to rename encoder-decoder "inputs" for PyTorch
|
|
||||||
if "inputs" in pt_inputs_dict and self.is_encoder_decoder:
|
|
||||||
pt_inputs_dict["input_ids"] = pt_inputs_dict.pop("inputs")
|
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
pto = pt_model(**pt_inputs_dict)
|
pto = pt_model(**pt_inputs_dict)
|
||||||
tfo = tf_model(self._prepare_for_class(inputs_dict, model_class), training=False)
|
tfo = tf_model(self._prepare_for_class(inputs_dict, model_class), training=False)
|
||||||
@@ -422,9 +418,6 @@ class TFModelTesterMixin:
|
|||||||
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32)
|
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32)
|
||||||
else:
|
else:
|
||||||
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.long)
|
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.long)
|
||||||
# need to rename encoder-decoder "inputs" for PyTorch
|
|
||||||
if "inputs" in pt_inputs_dict and self.is_encoder_decoder:
|
|
||||||
pt_inputs_dict["input_ids"] = pt_inputs_dict.pop("inputs")
|
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
pto = pt_model(**pt_inputs_dict)
|
pto = pt_model(**pt_inputs_dict)
|
||||||
|
|||||||
Reference in New Issue
Block a user