Make PT/Flax tests could be run on GPU (#24557)

fix

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2023-06-28 20:11:01 +02:00
committed by GitHub
parent faae8d8255
commit fd6735102a
4 changed files with 10 additions and 10 deletions

View File

@@ -2206,7 +2206,7 @@ class ModelTesterMixin:
}
# convert inputs to Flax
fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)}
fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)}
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
fx_model.params = fx_state
@@ -2278,7 +2278,7 @@ class ModelTesterMixin:
}
# convert inputs to Flax
fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)}
fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)}
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)