remove "inputs" in tf common test script (no longer required) (#15262)
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -377,10 +377,6 @@ class TFModelTesterMixin:
|
||||
else:
|
||||
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.long)
|
||||
|
||||
# need to rename encoder-decoder "inputs" for PyTorch
|
||||
if "inputs" in pt_inputs_dict and self.is_encoder_decoder:
|
||||
pt_inputs_dict["input_ids"] = pt_inputs_dict.pop("inputs")
|
||||
|
||||
with torch.no_grad():
|
||||
pto = pt_model(**pt_inputs_dict)
|
||||
tfo = tf_model(self._prepare_for_class(inputs_dict, model_class), training=False)
|
||||
@@ -422,9 +418,6 @@ class TFModelTesterMixin:
|
||||
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32)
|
||||
else:
|
||||
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.long)
|
||||
# need to rename encoder-decoder "inputs" for PyTorch
|
||||
if "inputs" in pt_inputs_dict and self.is_encoder_decoder:
|
||||
pt_inputs_dict["input_ids"] = pt_inputs_dict.pop("inputs")
|
||||
|
||||
with torch.no_grad():
|
||||
pto = pt_model(**pt_inputs_dict)
|
||||
|
||||
Reference in New Issue
Block a user