torch.distributed group initialization for torch_neuron disabled when optimum-neuron is installed (#22728)
* Make the process group initialization not happen if optimum_neuron is installed * Add warning * Remove list and added warning
This commit is contained in:
@@ -54,8 +54,13 @@ from .utils import (
|
|||||||
logging,
|
logging,
|
||||||
requires_backends,
|
requires_backends,
|
||||||
)
|
)
|
||||||
|
from .utils.import_utils import is_optimum_neuron_available
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
log_levels = logging.get_log_levels_dict().copy()
|
||||||
|
trainer_log_levels = dict(**log_levels, passive=-1)
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
@@ -67,12 +72,23 @@ if is_torch_neuroncore_available(check_device=False):
|
|||||||
# torchrun support
|
# torchrun support
|
||||||
# https://github.com/pytorch/xla/pull/3609
|
# https://github.com/pytorch/xla/pull/3609
|
||||||
if os.environ.get("TORCHELASTIC_RUN_ID"):
|
if os.environ.get("TORCHELASTIC_RUN_ID"):
|
||||||
import torch_xla.distributed.xla_backend as xbn
|
if is_optimum_neuron_available():
|
||||||
|
logger.info(
|
||||||
|
"Make sure that you are performing the training with the TrainiumTrainer from optimum[neuron], this "
|
||||||
|
"will fail otherwise."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"Please use the TrainiumTrainer from optimum[neuron] instead of the Transformers library to perform "
|
||||||
|
"training on AWS Trainium instances. More information here: "
|
||||||
|
"https://github.com/huggingface/optimum-neuron"
|
||||||
|
)
|
||||||
|
import torch_xla.distributed.xla_backend as xbn
|
||||||
|
|
||||||
if not isinstance(torch.distributed.group.WORLD, xbn.ProcessGroupXla):
|
|
||||||
torch.distributed.init_process_group(backend="xla")
|
|
||||||
if not isinstance(torch.distributed.group.WORLD, xbn.ProcessGroupXla):
|
if not isinstance(torch.distributed.group.WORLD, xbn.ProcessGroupXla):
|
||||||
raise AssertionError("Failed to initialize torch.distributed process group using XLA backend.")
|
torch.distributed.init_process_group(backend="xla")
|
||||||
|
if not isinstance(torch.distributed.group.WORLD, xbn.ProcessGroupXla):
|
||||||
|
raise AssertionError("Failed to initialize torch.distributed process group using XLA backend.")
|
||||||
|
|
||||||
|
|
||||||
if is_sagemaker_mp_enabled():
|
if is_sagemaker_mp_enabled():
|
||||||
@@ -81,11 +97,6 @@ if is_sagemaker_mp_enabled():
|
|||||||
smp.init()
|
smp.init()
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
|
||||||
log_levels = logging.get_log_levels_dict().copy()
|
|
||||||
trainer_log_levels = dict(**log_levels, passive=-1)
|
|
||||||
|
|
||||||
|
|
||||||
def default_logdir() -> str:
|
def default_logdir() -> str:
|
||||||
"""
|
"""
|
||||||
Same default as PyTorch
|
Same default as PyTorch
|
||||||
|
|||||||
@@ -583,6 +583,10 @@ def is_optimum_available():
|
|||||||
return importlib.util.find_spec("optimum") is not None
|
return importlib.util.find_spec("optimum") is not None
|
||||||
|
|
||||||
|
|
||||||
|
def is_optimum_neuron_available():
|
||||||
|
return importlib.util.find_spec("optimum.neuron") is not None
|
||||||
|
|
||||||
|
|
||||||
def is_safetensors_available():
|
def is_safetensors_available():
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
if version.parse(_torch_version) >= version.parse("1.10"):
|
if version.parse(_torch_version) >= version.parse("1.10"):
|
||||||
|
|||||||
Reference in New Issue
Block a user