[TF-PT-Tests] Fix PyTorch - TF tests for different GPU devices (#15846)
This commit is contained in:
committed by
GitHub
parent
97f9b8a27b
commit
ddbb485c41
@@ -1493,9 +1493,8 @@ class ModelTesterMixin:
|
|||||||
tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=tf_inputs_dict)
|
tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=tf_inputs_dict)
|
||||||
pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model).to(torch_device)
|
pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model).to(torch_device)
|
||||||
|
|
||||||
# need to rename encoder-decoder "inputs" for PyTorch
|
# Make sure PyTorch tensors are on same device as model
|
||||||
# if "inputs" in pt_inputs_dict and self.is_encoder_decoder:
|
pt_inputs = {k: v.to(torch_device) if torch.is_tensor(v) else v for k, v in pt_inputs.items()}
|
||||||
# pt_inputs_dict["input_ids"] = pt_inputs_dict.pop("inputs")
|
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
pto = pt_model(**pt_inputs)
|
pto = pt_model(**pt_inputs)
|
||||||
|
|||||||
Reference in New Issue
Block a user