From 7bbdfd7b240141dd578610e0bcae2fe62cc13451 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Mon, 22 May 2023 15:39:47 +0200 Subject: [PATCH] Fix accelerate logger bug (#23650) * fix logger bug * Update tests/mixed_int8/test_mixed_int8.py Co-authored-by: Zachary Mueller * import `PartialState` --------- Co-authored-by: Zachary Mueller --- tests/mixed_int8/test_mixed_int8.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/mixed_int8/test_mixed_int8.py b/tests/mixed_int8/test_mixed_int8.py index 1628e08155..7053be30a1 100644 --- a/tests/mixed_int8/test_mixed_int8.py +++ b/tests/mixed_int8/test_mixed_int8.py @@ -29,6 +29,7 @@ from transformers import ( pipeline, ) from transformers.testing_utils import ( + is_accelerate_available, is_torch_available, require_accelerate, require_bitsandbytes, @@ -40,6 +41,13 @@ from transformers.testing_utils import ( from transformers.utils.versions import importlib_metadata +if is_accelerate_available(): + from accelerate import PartialState + from accelerate.logging import get_logger + + logger = get_logger(__name__) + _ = PartialState() + if is_torch_available(): import torch import torch.nn as nn