[Deepspeed Wav2vec2] integration (#11638)
* wip * wip - but working with https://github.com/microsoft/DeepSpeed/pull/1044 * cleanup * workaround * working 5/8 modes * solve fp32 distributed zero3 * style * sync * sync * rework * deprecation * cleanup * https://github.com/microsoft/DeepSpeed/pull/1044 pr was merged * clean up * add a guide * more prose * more prose * fix * more prose * sub_group_size was too big * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * refactor * bug fix * make the true check explicit * new deepspeed release Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -28,6 +28,7 @@ from typing import Iterator, Union
|
||||
|
||||
from transformers import logging as transformers_logging
|
||||
|
||||
from .deepspeed import is_deepspeed_available
|
||||
from .file_utils import (
|
||||
is_datasets_available,
|
||||
is_faiss_available,
|
||||
@@ -454,6 +455,16 @@ def require_soundfile(test_case):
|
||||
return test_case
|
||||
|
||||
|
||||
def require_deepspeed(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires deepspeed
|
||||
"""
|
||||
if not is_deepspeed_available():
|
||||
return unittest.skip("test requires deepspeed")(test_case)
|
||||
else:
|
||||
return test_case
|
||||
|
||||
|
||||
def get_gpu_count():
|
||||
"""
|
||||
Return the number of available gpus (regardless of whether torch or tf is used)
|
||||
|
||||
Reference in New Issue
Block a user