🚨🚨🚨 [Refactor] Move third-party related utility files into integrations/ folder 🚨🚨🚨 (#25599)
* move deepspeed to `lib_integrations.deepspeed` * more refactor * oops * fix slow tests * Fix docs * fix docs * addess feedback * address feedback * final modifs for PEFT * fixup * ok now * trigger CI * trigger CI again * Update docs/source/en/main_classes/deepspeed.md Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * import from `integrations` * address feedback * revert removal of `deepspeed` module * revert removal of `deepspeed` module * fix conflicts * ooops * oops * add deprecation warning * place it on the top * put `FutureWarning` * fix conflicts with not_doctested.txt * add back `bitsandbytes` module with a depr warning * fix * fix * fixup * oops * fix doctests --------- Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -27,7 +27,11 @@ from parameterized import parameterized
|
||||
import tests.trainer.test_trainer
|
||||
from tests.trainer.test_trainer import TrainerIntegrationCommon # noqa
|
||||
from transformers import AutoModel, TrainingArguments, is_torch_available, logging
|
||||
from transformers.deepspeed import HfDeepSpeedConfig, is_deepspeed_available, unset_hf_deepspeed_config
|
||||
from transformers.integrations.deepspeed import (
|
||||
HfDeepSpeedConfig,
|
||||
is_deepspeed_available,
|
||||
unset_hf_deepspeed_config,
|
||||
)
|
||||
from transformers.testing_utils import (
|
||||
CaptureLogger,
|
||||
CaptureStd,
|
||||
@@ -113,7 +117,7 @@ def require_deepspeed_aio(test_case):
|
||||
if is_deepspeed_available():
|
||||
from deepspeed.utils import logger as deepspeed_logger # noqa
|
||||
from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
|
||||
from transformers.deepspeed import deepspeed_config, is_deepspeed_zero3_enabled # noqa
|
||||
from transformers.integrations.deepspeed import deepspeed_config, is_deepspeed_zero3_enabled # noqa
|
||||
|
||||
|
||||
def get_launcher(distributed=False):
|
||||
|
||||
@@ -131,7 +131,7 @@ class MixedInt8Test(BaseMixedInt8Test):
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
from transformers import AutoModelForMaskedLM, Blip2ForConditionalGeneration, MptForCausalLM, OPTForCausalLM
|
||||
from transformers.utils.bitsandbytes import get_keys_to_not_convert
|
||||
from transformers.integrations.bitsandbytes import get_keys_to_not_convert
|
||||
|
||||
model_id = "mosaicml/mpt-7b"
|
||||
config = AutoConfig.from_pretrained(
|
||||
|
||||
Reference in New Issue
Block a user