Fix data parallelism in Trainer (#9566)
* Fix data parallelism in Trainer * Update src/transformers/training_args.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co> Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
@@ -405,7 +405,7 @@ class TrainingArguments:
|
||||
default=0.0, metadata={"help": "The label smoothing epsilon to apply (zero means no label smoothing)."}
|
||||
)
|
||||
adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace Adam by Adafactor."})
|
||||
_n_gpu: int = field(init=False, repr=False, default=0)
|
||||
_n_gpu: int = field(init=False, repr=False, default=-1)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.disable_tqdm is None:
|
||||
@@ -483,6 +483,10 @@ class TrainingArguments:
|
||||
# GPUs available in the environment, so `CUDA_VISIBLE_DEVICES=1,2` with `cuda:0`
|
||||
# will use the first GPU in that env, i.e. GPU#1
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
# Sometimes the line in the postinit has not been run before we end up here, so just checking we're not at
|
||||
# the default value.
|
||||
if self._n_gpu == -1:
|
||||
self._n_gpu = torch.cuda.device_count()
|
||||
n_gpu = self._n_gpu
|
||||
else:
|
||||
# Here, we'll use torch.distributed.
|
||||
|
||||
Reference in New Issue
Block a user