diff --git a/tests/mixed_int8/test_mixed_int8.py b/tests/mixed_int8/test_mixed_int8.py index 1628e08155..7053be30a1 100644 --- a/tests/mixed_int8/test_mixed_int8.py +++ b/tests/mixed_int8/test_mixed_int8.py @@ -29,6 +29,7 @@ from transformers import ( pipeline, ) from transformers.testing_utils import ( + is_accelerate_available, is_torch_available, require_accelerate, require_bitsandbytes, @@ -40,6 +41,13 @@ from transformers.testing_utils import ( from transformers.utils.versions import importlib_metadata +if is_accelerate_available(): + from accelerate import PartialState + from accelerate.logging import get_logger + + logger = get_logger(__name__) + _ = PartialState() + if is_torch_available(): import torch import torch.nn as nn