diff --git a/.github/workflows/github-torch-hub.yml b/.github/workflows/github-torch-hub.yml index 858f7ebb0a..ace9e02963 100644 --- a/.github/workflows/github-torch-hub.yml +++ b/.github/workflows/github-torch-hub.yml @@ -21,7 +21,7 @@ jobs: - name: Install dependencies run: | pip install torch - pip install numpy tokenizers filelock requests tqdm regex sentencepiece sacremoses + pip install numpy tokenizers filelock requests tqdm regex sentencepiece sacremoses packaging - name: Torch hub list run: | diff --git a/examples/language-modeling/run_language_modeling.py b/examples/language-modeling/run_language_modeling.py index d2afa56acb..d9465c3763 100644 --- a/examples/language-modeling/run_language_modeling.py +++ b/examples/language-modeling/run_language_modeling.py @@ -251,7 +251,7 @@ def main(): # Evaluation results = {} - if training_args.do_eval and training_args.local_rank in [-1, 0]: + if training_args.do_eval: logger.info("*** Evaluate ***") eval_output = trainer.evaluate() @@ -260,11 +260,12 @@ def main(): result = {"perplexity": perplexity} output_eval_file = os.path.join(training_args.output_dir, "eval_results_lm.txt") - with open(output_eval_file, "w") as writer: - logger.info("***** Eval results *****") - for key in sorted(result.keys()): - logger.info(" %s = %s", key, str(result[key])) - writer.write("%s = %s\n" % (key, str(result[key]))) + if trainer.is_world_master(): + with open(output_eval_file, "w") as writer: + logger.info("***** Eval results *****") + for key in sorted(result.keys()): + logger.info(" %s = %s", key, str(result[key])) + writer.write("%s = %s\n" % (key, str(result[key]))) results.update(result) diff --git a/examples/multiple-choice/run_multiple_choice.py b/examples/multiple-choice/run_multiple_choice.py index 9f95a27da1..f2147c44f0 100644 --- a/examples/multiple-choice/run_multiple_choice.py +++ b/examples/multiple-choice/run_multiple_choice.py @@ -202,19 +202,20 @@ def main(): # Evaluation results = {} - if training_args.do_eval and training_args.local_rank in [-1, 0]: + if training_args.do_eval: logger.info("*** Evaluate ***") result = trainer.evaluate() output_eval_file = os.path.join(training_args.output_dir, "eval_results.txt") - with open(output_eval_file, "w") as writer: - logger.info("***** Eval results *****") - for key, value in result.items(): - logger.info(" %s = %s", key, value) - writer.write("%s = %s\n" % (key, value)) + if trainer.is_world_master(): + with open(output_eval_file, "w") as writer: + logger.info("***** Eval results *****") + for key, value in result.items(): + logger.info(" %s = %s", key, value) + writer.write("%s = %s\n" % (key, value)) - results.update(result) + results.update(result) return results diff --git a/examples/text-classification/run_glue.py b/examples/text-classification/run_glue.py index 9bfe6aa288..080c648938 100644 --- a/examples/text-classification/run_glue.py +++ b/examples/text-classification/run_glue.py @@ -166,7 +166,7 @@ def main(): # Evaluation results = {} - if training_args.do_eval and training_args.local_rank in [-1, 0]: + if training_args.do_eval: logger.info("*** Evaluate ***") # Loop to handle MNLI double evaluation (matched, mis-matched) @@ -181,11 +181,12 @@ def main(): output_eval_file = os.path.join( training_args.output_dir, f"eval_results_{eval_dataset.args.task_name}.txt" ) - with open(output_eval_file, "w") as writer: - logger.info("***** Eval results {} *****".format(eval_dataset.args.task_name)) - for key, value in result.items(): - logger.info(" %s = %s", key, value) - writer.write("%s = %s\n" % (key, value)) + if trainer.is_world_master(): + with open(output_eval_file, "w") as writer: + logger.info("***** Eval results {} *****".format(eval_dataset.args.task_name)) + for key, value in result.items(): + logger.info(" %s = %s", key, value) + writer.write("%s = %s\n" % (key, value)) results.update(result) diff --git a/examples/token-classification/run_ner.py b/examples/token-classification/run_ner.py index bb99a08b8e..7f79fb6d4e 100644 --- a/examples/token-classification/run_ner.py +++ b/examples/token-classification/run_ner.py @@ -235,22 +235,23 @@ def main(): # Evaluation results = {} - if training_args.do_eval and training_args.local_rank in [-1, 0]: + if training_args.do_eval: logger.info("*** Evaluate ***") result = trainer.evaluate() output_eval_file = os.path.join(training_args.output_dir, "eval_results.txt") - with open(output_eval_file, "w") as writer: - logger.info("***** Eval results *****") - for key, value in result.items(): - logger.info(" %s = %s", key, value) - writer.write("%s = %s\n" % (key, value)) + if trainer.is_world_master(): + with open(output_eval_file, "w") as writer: + logger.info("***** Eval results *****") + for key, value in result.items(): + logger.info(" %s = %s", key, value) + writer.write("%s = %s\n" % (key, value)) results.update(result) # Predict - if training_args.do_predict and training_args.local_rank in [-1, 0]: + if training_args.do_predict: test_dataset = NerDataset( data_dir=data_args.data_dir, tokenizer=tokenizer, @@ -265,26 +266,30 @@ def main(): preds_list, _ = align_predictions(predictions, label_ids) output_test_results_file = os.path.join(training_args.output_dir, "test_results.txt") - with open(output_test_results_file, "w") as writer: - for key, value in metrics.items(): - logger.info(" %s = %s", key, value) - writer.write("%s = %s\n" % (key, value)) + if trainer.is_world_master(): + with open(output_test_results_file, "w") as writer: + for key, value in metrics.items(): + logger.info(" %s = %s", key, value) + writer.write("%s = %s\n" % (key, value)) # Save predictions output_test_predictions_file = os.path.join(training_args.output_dir, "test_predictions.txt") - with open(output_test_predictions_file, "w") as writer: - with open(os.path.join(data_args.data_dir, "test.txt"), "r") as f: - example_id = 0 - for line in f: - if line.startswith("-DOCSTART-") or line == "" or line == "\n": - writer.write(line) - if not preds_list[example_id]: - example_id += 1 - elif preds_list[example_id]: - output_line = line.split()[0] + " " + preds_list[example_id].pop(0) + "\n" - writer.write(output_line) - else: - logger.warning("Maximum sequence length exceeded: No prediction for '%s'.", line.split()[0]) + if trainer.is_world_master(): + with open(output_test_predictions_file, "w") as writer: + with open(os.path.join(data_args.data_dir, "test.txt"), "r") as f: + example_id = 0 + for line in f: + if line.startswith("-DOCSTART-") or line == "" or line == "\n": + writer.write(line) + if not preds_list[example_id]: + example_id += 1 + elif preds_list[example_id]: + output_line = line.split()[0] + " " + preds_list[example_id].pop(0) + "\n" + writer.write(output_line) + else: + logger.warning( + "Maximum sequence length exceeded: No prediction for '%s'.", line.split()[0] + ) return results diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index f836987c28..8db4eb0b81 100644 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1,5 +1,6 @@ import json import logging +import math import os import random import re @@ -15,7 +16,7 @@ from torch import nn from torch.utils.data.dataloader import DataLoader from torch.utils.data.dataset import Dataset from torch.utils.data.distributed import DistributedSampler -from torch.utils.data.sampler import RandomSampler +from torch.utils.data.sampler import RandomSampler, Sampler, SequentialSampler from tqdm.auto import tqdm, trange from .data.data_collator import DataCollator, DefaultDataCollator @@ -90,7 +91,7 @@ def set_seed(seed: int): @contextmanager def torch_distributed_zero_first(local_rank: int): """ - Decorator to make all processes in distributed training wait for the first one (locally) to do something. + Decorator to make all processes in distributed training wait for each local_master to do something. """ if local_rank not in [-1, 0]: torch.distributed.barrier() @@ -99,6 +100,50 @@ def torch_distributed_zero_first(local_rank: int): torch.distributed.barrier() +class SequentialDistributedSampler(Sampler): + """ + Distributed Sampler that subsamples indicies sequentially, + making it easier to collate all results at the end. + + Even though we only use this sampler for eval and predict (no training), + which means that the model params won't have to be synced (i.e. will not hang + for synchronization even if varied number of forward passes), we still add extra + samples to the sampler to make it evenly divisible (like in `DistributedSampler`) + to make it easy to `gather` or `reduce` resulting tensors at the end of the loop. + """ + + def __init__(self, dataset, num_replicas=None, rank=None): + if num_replicas is None: + if not torch.distributed.is_available(): + raise RuntimeError("Requires distributed package to be available") + num_replicas = torch.distributed.get_world_size() + if rank is None: + if not torch.distributed.is_available(): + raise RuntimeError("Requires distributed package to be available") + rank = torch.distributed.get_rank() + self.dataset = dataset + self.num_replicas = num_replicas + self.rank = rank + self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) + self.total_size = self.num_samples * self.num_replicas + + def __iter__(self): + indices = list(range(len(self.dataset))) + + # add extra samples to make it evenly divisible + indices += indices[: (self.total_size - len(indices))] + assert len(indices) == self.total_size + + # subsample + indices = indices[self.rank * self.num_samples : (self.rank + 1) * self.num_samples] + assert len(indices) == self.num_samples + + return iter(indices) + + def __len__(self): + return self.num_samples + + def get_tpu_sampler(dataset: Dataset): if xm.xrt_world_size() <= 1: return RandomSampler(dataset) @@ -156,7 +201,7 @@ class Trainer: self.optimizers = optimizers if tb_writer is not None: self.tb_writer = tb_writer - elif is_tensorboard_available() and self.args.local_rank in [-1, 0]: + elif is_tensorboard_available() and self.is_world_master(): self.tb_writer = SummaryWriter(log_dir=self.args.logging_dir) if not is_tensorboard_available(): logger.warning( @@ -171,7 +216,7 @@ class Trainer: ) set_seed(self.args.seed) # Create output directory if needed - if self.is_local_master(): + if self.is_world_master(): os.makedirs(self.args.output_dir, exist_ok=True) if is_tpu_available(): # Set an xla_device flag on the model's config. @@ -208,13 +253,19 @@ class Trainer: eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset - sampler = get_tpu_sampler(eval_dataset) if is_tpu_available() else None + if is_tpu_available(): + sampler = SequentialDistributedSampler( + eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal() + ) + elif self.args.local_rank != -1: + sampler = SequentialDistributedSampler(eval_dataset) + else: + sampler = SequentialSampler(eval_dataset) data_loader = DataLoader( eval_dataset, sampler=sampler, batch_size=self.args.eval_batch_size, - shuffle=False, collate_fn=self.data_collator.collate_batch, ) @@ -225,13 +276,19 @@ class Trainer: def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader: # We use the same batch_size as for eval. - sampler = get_tpu_sampler(test_dataset) if is_tpu_available() else None + if is_tpu_available(): + sampler = SequentialDistributedSampler( + test_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal() + ) + elif self.args.local_rank != -1: + sampler = SequentialDistributedSampler(test_dataset) + else: + sampler = SequentialSampler(test_dataset) data_loader = DataLoader( test_dataset, sampler=sampler, batch_size=self.args.eval_batch_size, - shuffle=False, collate_fn=self.data_collator.collate_batch, ) @@ -405,6 +462,9 @@ class Trainer: epochs_trained, int(num_train_epochs), desc="Epoch", disable=not self.is_local_master() ) for epoch in train_iterator: + if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler): + train_dataloader.sampler.set_epoch(epoch) + epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=not self.is_local_master()) for step, inputs in enumerate(epoch_iterator): @@ -435,27 +495,25 @@ class Trainer: self.global_step += 1 self.epoch = epoch + (step + 1) / len(epoch_iterator) - if self.is_local_master(): - if (self.args.logging_steps > 0 and self.global_step % self.args.logging_steps == 0) or ( - self.global_step == 1 and self.args.logging_first_step - ): - logs: Dict[str, float] = {} - logs["loss"] = (tr_loss - logging_loss) / self.args.logging_steps - # maintaining backward compatibility. - # could use "scheduler.get_last_lr()[0]" instead for pytorch >= 1.4.0 - logs["learning_rate"] = ( - scheduler.get_last_lr()[0] - if version.parse(torch.__version__) >= version.parse("1.4") - else scheduler.get_lr()[0] - ) + if (self.args.logging_steps > 0 and self.global_step % self.args.logging_steps == 0) or ( + self.global_step == 1 and self.args.logging_first_step + ): + logs: Dict[str, float] = {} + logs["loss"] = (tr_loss - logging_loss) / self.args.logging_steps + # backward compatibility for pytorch schedulers + logs["learning_rate"] = ( + scheduler.get_last_lr()[0] + if version.parse(torch.__version__) >= version.parse("1.4") + else scheduler.get_lr()[0] + ) + logging_loss = tr_loss - logging_loss = tr_loss + self._log(logs) - self._log(logs) - - if self.args.evaluate_during_training: - self.evaluate() + if self.args.evaluate_during_training: + self.evaluate() + if self.is_world_master(): if self.args.save_steps > 0 and self.global_step % self.args.save_steps == 0: # In all cases (even distributed/parallel), self.model is always a reference # to the model we want to save. @@ -548,7 +606,7 @@ class Trainer: Saving best-practices: if you use default names for the model, you can reload it using from_pretrained(). - Will only save from the master process. + Will only save from the world_master process (unless in TPUs). """ if is_tpu_available(): @@ -667,12 +725,15 @@ class Trainer: prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else self.prediction_loss_only + model = self.model + model.to(self.args.device) # multi-gpu eval - if self.args.n_gpu > 1 and not isinstance(self.model, torch.nn.DataParallel): - model = torch.nn.DataParallel(self.model) + if self.args.n_gpu > 1: + model = torch.nn.DataParallel(model) else: model = self.model - model.to(self.args.device) + # Note: in torch.distributed mode, there's no point in wrapping the model + # inside a DistributedDataParallel as we'll be under `no_grad` anyways. if is_tpu_available(): batch_size = dataloader._loader._loader.batch_size @@ -682,8 +743,8 @@ class Trainer: logger.info(" Num examples = %d", self.num_examples(dataloader)) logger.info(" Batch size = %d", batch_size) eval_losses: List[float] = [] - preds: np.ndarray = None - label_ids: np.ndarray = None + preds: torch.Tensor = None + label_ids: torch.Tensor = None model.eval() for inputs in tqdm(dataloader, desc=description): @@ -702,19 +763,33 @@ class Trainer: if not prediction_loss_only: if preds is None: - preds = logits.detach().cpu().numpy() + preds = logits.detach() else: - preds = np.append(preds, logits.detach().cpu().numpy(), axis=0) + preds = torch.cat((preds, logits.detach()), dim=0) if inputs.get("labels") is not None: if label_ids is None: - label_ids = inputs["labels"].detach().cpu().numpy() + label_ids = inputs["labels"].detach() else: - label_ids = np.append(label_ids, inputs["labels"].detach().cpu().numpy(), axis=0) + label_ids = torch.cat((label_ids, inputs["labels"].detach()), dim=0) - if is_tpu_available() and preds is not None and label_ids is not None: + if self.args.local_rank != -1: + # In distributed mode, concatenate all results from all nodes: + if preds is not None: + preds = self.distributed_concat(preds, num_total_examples=self.num_examples(dataloader)) + if label_ids is not None: + label_ids = self.distributed_concat(label_ids, num_total_examples=self.num_examples(dataloader)) + elif is_tpu_available(): # tpu-comment: Get all predictions and labels from all worker shards of eval dataset - preds = xm.mesh_reduce("eval_preds", preds, np.concatenate) - label_ids = xm.mesh_reduce("eval_out_label_ids", label_ids, np.concatenate) + if preds is not None: + preds = xm.mesh_reduce("eval_preds", preds, torch.cat) + if label_ids is not None: + label_ids = xm.mesh_reduce("eval_label_ids", label_ids, torch.cat) + + # Finally, turn the aggregated tensors into numpy arrays. + if preds is not None: + preds = preds.cpu().numpy() + if label_ids is not None: + label_ids = label_ids.cpu().numpy() if self.compute_metrics is not None and preds is not None and label_ids is not None: metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids)) @@ -729,3 +804,15 @@ class Trainer: metrics[f"eval_{key}"] = metrics.pop(key) return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics) + + def distributed_concat(self, tensor: torch.Tensor, num_total_examples: int) -> torch.Tensor: + assert self.args.local_rank != -1 + + output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())] + torch.distributed.all_gather(output_tensors, tensor) + + concat = torch.cat(output_tensors, dim=0) + + # truncate the dummy elements added by SequentialDistributedSampler + output = concat[:num_total_examples] + return output diff --git a/tests/test_trainer_distributed.py b/tests/test_trainer_distributed.py new file mode 100644 index 0000000000..78bc144312 --- /dev/null +++ b/tests/test_trainer_distributed.py @@ -0,0 +1,102 @@ +# This test is meant to be run in torch.distributed, +# on a machine with multiple GPUs, in the following way: +# +# python -m torch.distributed.launch --nproc_per_node 2 ./tests/test_trainer_distributed.py +# +# Replace 2 with the number of GPUs you have. +# +# You can also run it as a standalone file to test identical behavior in nn.DataParallel: +# python ./tests/test_trainer_distributed.py +# and in single-GPU mode: +# CUDA_VISIBLE_DEVICES=0 python ./tests/test_trainer_distributed.py +# + + +import logging +import sys +from typing import Dict + +from transformers import EvalPrediction, HfArgumentParser, TrainingArguments, is_torch_available + + +logger = logging.getLogger(__name__) + + +if is_torch_available(): + import torch + from torch import nn + from torch.utils.data.dataset import Dataset + + from transformers import DataCollator, Trainer + + class DummyDataset(Dataset): + def __init__(self, length: int = 101): + self.length = length + + def __len__(self): + return self.length + + def __getitem__(self, i) -> int: + return i + + class DummyDataCollator(DataCollator): + def collate_batch(self, features): + return {"input_ids": torch.tensor(features), "labels": torch.tensor(features)} + + class DummyModel(nn.Module): + def __init__(self): + super().__init__() + # Add some (unused) params otherwise DDP will complain. + self.fc = nn.Linear(120, 80) + + def forward(self, input_ids, labels=None): + if labels is not None: + return torch.tensor(0.0, device=input_ids.device), input_ids + else: + return input_ids + + +if __name__ == "__main__": + parser = HfArgumentParser((TrainingArguments,)) + training_args = parser.parse_args_into_dataclasses(sys.argv + ["--output_dir", "./examples"])[0] + + logging.basicConfig(level=logging.INFO) + logger.warning( + "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s", + training_args.local_rank, + training_args.device, + training_args.n_gpu, + training_args.local_rank != -1, + ) + + # Essentially, what we want to verify in the distributed case is + # that we get all samples back, in the right order. + # (this is crucial for prediction for instance) + for dataset_length in [101, 40, 7]: + dataset = DummyDataset(dataset_length) + + def compute_metrics(p: EvalPrediction) -> Dict: + sequential = list(range(len(dataset))) + success = p.predictions.tolist() == sequential and p.label_ids.tolist() == sequential + return {"success": success} + + trainer = Trainer( + model=DummyModel(), + args=training_args, + data_collator=DummyDataCollator(), + eval_dataset=dataset, + compute_metrics=compute_metrics, + ) + metrics = trainer.evaluate() + logger.info(metrics) + if metrics["eval_success"] is not True: + logger.error(metrics) + exit(1) + + p = trainer.predict(dataset) + logger.info(p.metrics) + if p.metrics["eval_success"] is not True: + logger.error(p.metrics) + exit(1) + + logger.info("🔥 All distributed tests successful")