Make PT/Flax tests could be run on GPU (#24557)
fix Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -611,7 +611,7 @@ class CLIPModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
|||||||
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
||||||
|
|
||||||
# convert inputs to Flax
|
# convert inputs to Flax
|
||||||
fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)}
|
fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)}
|
||||||
fx_outputs = fx_model(**fx_inputs).to_tuple()
|
fx_outputs = fx_model(**fx_inputs).to_tuple()
|
||||||
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||||
for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]):
|
for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]):
|
||||||
@@ -669,7 +669,7 @@ class CLIPModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
||||||
|
|
||||||
fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)}
|
fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)}
|
||||||
|
|
||||||
fx_outputs = fx_model(**fx_inputs).to_tuple()
|
fx_outputs = fx_model(**fx_inputs).to_tuple()
|
||||||
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||||
|
|||||||
@@ -592,7 +592,7 @@ class CLIPSegModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
|
|||||||
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
||||||
|
|
||||||
# convert inputs to Flax
|
# convert inputs to Flax
|
||||||
fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)}
|
fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)}
|
||||||
fx_outputs = fx_model(**fx_inputs).to_tuple()
|
fx_outputs = fx_model(**fx_inputs).to_tuple()
|
||||||
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||||
for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]):
|
for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]):
|
||||||
@@ -650,7 +650,7 @@ class CLIPSegModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
||||||
|
|
||||||
fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)}
|
fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)}
|
||||||
|
|
||||||
fx_outputs = fx_model(**fx_inputs).to_tuple()
|
fx_outputs = fx_model(**fx_inputs).to_tuple()
|
||||||
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||||
|
|||||||
@@ -875,7 +875,7 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
|||||||
}
|
}
|
||||||
|
|
||||||
# convert inputs to Flax
|
# convert inputs to Flax
|
||||||
fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)}
|
fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)}
|
||||||
|
|
||||||
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
|
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
|
||||||
fx_model.params = fx_state
|
fx_model.params = fx_state
|
||||||
@@ -948,7 +948,7 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
|||||||
}
|
}
|
||||||
|
|
||||||
# convert inputs to Flax
|
# convert inputs to Flax
|
||||||
fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)}
|
fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)}
|
||||||
|
|
||||||
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
|
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
|
||||||
|
|
||||||
@@ -1805,7 +1805,7 @@ class WhisperEncoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.
|
|||||||
}
|
}
|
||||||
|
|
||||||
# convert inputs to Flax
|
# convert inputs to Flax
|
||||||
fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)}
|
fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)}
|
||||||
|
|
||||||
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
|
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
|
||||||
fx_model.params = fx_state
|
fx_model.params = fx_state
|
||||||
@@ -1878,7 +1878,7 @@ class WhisperEncoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.
|
|||||||
}
|
}
|
||||||
|
|
||||||
# convert inputs to Flax
|
# convert inputs to Flax
|
||||||
fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)}
|
fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)}
|
||||||
|
|
||||||
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
|
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
|
||||||
|
|
||||||
|
|||||||
@@ -2206,7 +2206,7 @@ class ModelTesterMixin:
|
|||||||
}
|
}
|
||||||
|
|
||||||
# convert inputs to Flax
|
# convert inputs to Flax
|
||||||
fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)}
|
fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)}
|
||||||
|
|
||||||
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
|
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
|
||||||
fx_model.params = fx_state
|
fx_model.params = fx_state
|
||||||
@@ -2278,7 +2278,7 @@ class ModelTesterMixin:
|
|||||||
}
|
}
|
||||||
|
|
||||||
# convert inputs to Flax
|
# convert inputs to Flax
|
||||||
fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)}
|
fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)}
|
||||||
|
|
||||||
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
|
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user