enable misc cases on XPU & use device agnostic APIs for cases in tests (#38192)
* use device agnostic APIs in tests Signed-off-by: Matrix Yao <matrix.yao@intel.com> * more Signed-off-by: Matrix Yao <matrix.yao@intel.com> * fix style Signed-off-by: Matrix Yao <matrix.yao@intel.com> * add reset_peak_memory_stats API Signed-off-by: YAO Matrix <matrix.yao@intel.com> * update --------- Signed-off-by: Matrix Yao <matrix.yao@intel.com> Signed-off-by: YAO Matrix <matrix.yao@intel.com> Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -21,9 +21,11 @@ from transformers import is_torch_available
|
||||
from transformers.integrations.tensor_parallel import get_packed_weights, repack_weights
|
||||
from transformers.testing_utils import (
|
||||
TestCasePlus,
|
||||
backend_device_count,
|
||||
get_torch_dist_unique_port,
|
||||
require_huggingface_hub_greater_or_equal,
|
||||
require_torch_multi_gpu,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
|
||||
@@ -168,4 +170,4 @@ class TestTensorParallel(TestCasePlus):
|
||||
|
||||
@require_torch_multi_gpu
|
||||
class TestTensorParallelCuda(TestTensorParallel):
|
||||
nproc_per_node = torch.cuda.device_count()
|
||||
nproc_per_node = backend_device_count(torch_device)
|
||||
|
||||
Reference in New Issue
Block a user