[TF-PT-Tests] Fix PyTorch - TF tests for different GPU devices (#15846)
This commit is contained in:
committed by
GitHub
parent
97f9b8a27b
commit
ddbb485c41
@@ -1493,9 +1493,8 @@ class ModelTesterMixin:
|
||||
tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=tf_inputs_dict)
|
||||
pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model).to(torch_device)
|
||||
|
||||
# need to rename encoder-decoder "inputs" for PyTorch
|
||||
# if "inputs" in pt_inputs_dict and self.is_encoder_decoder:
|
||||
# pt_inputs_dict["input_ids"] = pt_inputs_dict.pop("inputs")
|
||||
# Make sure PyTorch tensors are on same device as model
|
||||
pt_inputs = {k: v.to(torch_device) if torch.is_tensor(v) else v for k, v in pt_inputs.items()}
|
||||
|
||||
with torch.no_grad():
|
||||
pto = pt_model(**pt_inputs)
|
||||
|
||||
Reference in New Issue
Block a user