From d7c8ce57d43499f0042e63d5be2592c641354bad Mon Sep 17 00:00:00 2001 From: Sander Land <48946947+sanderland@users.noreply.github.com> Date: Tue, 29 Mar 2022 21:00:18 +0200 Subject: [PATCH] Avoid accessing .dataset of a DataLoader in Trainer (#16451) * Avoid accessing .dataset of a dataloader * style * fix * cleaning up, reverting some misunderstandings * black * add train_dataset argument to get_train_dataloader, and fix other instances of length checks * flake8 * address comments * fix bug * cleanup * add test * Update tests/trainer/test_trainer.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * under torch * merge * stylistic suggestion Co-authored-by: Sander Land Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- src/transformers/trainer.py | 65 ++++++++++++++++------------ src/transformers/trainer_callback.py | 2 +- src/transformers/utils/notebook.py | 5 +-- tests/trainer/test_trainer.py | 29 +++++++++++++ 4 files changed, 69 insertions(+), 32 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 1bf6fde9fc..993e524863 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -585,7 +585,7 @@ class Trainer: return dataset.remove_columns(ignored_columns) def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: - if not has_length(self.train_dataset): + if self.train_dataset is None or not has_length(self.train_dataset): return None generator = None @@ -661,8 +661,8 @@ class Trainer: """ Returns the training [`~torch.utils.data.DataLoader`]. - Will use no sampler if `self.train_dataset` does not implement `__len__`, a random sampler (adapted to - distributed training if necessary) otherwise. + Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed + training if necessary) otherwise. Subclass and override this method if you want to inject some custom behavior. """ @@ -937,11 +937,13 @@ class Trainer: def num_examples(self, dataloader: DataLoader) -> int: """ - Helper to get number of samples in a [`~torch.utils.data.DataLoader`] by accessing its dataset. - - Will raise an exception if the underlying dataset does not implement method `__len__` + Helper to get number of samples in a [`~torch.utils.data.DataLoader`] by accessing its dataset. When + dataloader.dataset does not exist or has no length, estimates as best it can """ - return len(dataloader.dataset) + try: + return len(dataloader.dataset) + except (NameError, AttributeError, TypeError): # no dataset or length, estimate by length of dataloader + return len(dataloader) * self.args.per_device_train_batch_size def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]): """HP search setup code""" @@ -1198,9 +1200,6 @@ class Trainer: self._move_model_to_device(self.model, args.device) self.model_wrapped = self.model - # Keeping track whether we can can len() on the dataset or not - train_dataset_is_sized = has_length(self.train_dataset) - # Data loader and number of training steps train_dataloader = self.get_train_dataloader() @@ -1209,28 +1208,36 @@ class Trainer: # number of training steps per epoch: num_update_steps_per_epoch # total number of training steps to execute: max_steps total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * args.world_size - if train_dataset_is_sized: - num_update_steps_per_epoch = len(train_dataloader) // args.gradient_accumulation_steps + + len_dataloader = None + if has_length(train_dataloader): + len_dataloader = len(train_dataloader) + num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1) + num_examples = self.num_examples(train_dataloader) if args.max_steps > 0: max_steps = args.max_steps num_train_epochs = args.max_steps // num_update_steps_per_epoch + int( args.max_steps % num_update_steps_per_epoch > 0 ) - # May be slightly incorrect if the last batch in the training datalaoder has a smaller size but it's + # May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's # the best we can do. num_train_samples = args.max_steps * total_train_batch_size else: max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch) num_train_epochs = math.ceil(args.num_train_epochs) - num_train_samples = len(self.train_dataset) * args.num_train_epochs - else: - # see __init__. max_steps is set when the dataset has no __len__ + num_train_samples = self.num_examples(train_dataloader) * args.num_train_epochs + elif args.max_steps > 0: # Rely on max_steps when dataloader does not have a working size max_steps = args.max_steps # Setting a very large number of epochs so we go as many times as necessary over the iterator. num_train_epochs = sys.maxsize num_update_steps_per_epoch = max_steps + num_examples = total_train_batch_size * args.max_steps num_train_samples = args.max_steps * total_train_batch_size + else: + raise ValueError( + f"args.max_steps must be set to a positive value if dataloader does not have a length, was {args.max_steps}" + ) if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug: if self.args.n_gpu > 1: @@ -1281,10 +1288,6 @@ class Trainer: # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), etc. # Train! - num_examples = ( - self.num_examples(train_dataloader) if train_dataset_is_sized else total_train_batch_size * args.max_steps - ) - logger.info("***** Running training *****") logger.info(f" Num examples = {num_examples}") logger.info(f" Num Epochs = {num_train_epochs}") @@ -1370,7 +1373,7 @@ class Trainer: for epoch in range(epochs_trained, num_train_epochs): if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler): train_dataloader.sampler.set_epoch(epoch) - elif isinstance(train_dataloader.dataset, IterableDatasetShard): + elif hasattr(train_dataloader, "dataset") and isinstance(train_dataloader.dataset, IterableDatasetShard): train_dataloader.dataset.set_epoch(epoch) if is_torch_tpu_available(): @@ -1384,7 +1387,9 @@ class Trainer: self._past = None steps_in_epoch = ( - len(epoch_iterator) if train_dataset_is_sized else args.max_steps * args.gradient_accumulation_steps + len(epoch_iterator) + if len_dataloader is not None + else args.max_steps * args.gradient_accumulation_steps ) self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control) @@ -2407,10 +2412,10 @@ class Trainer: elif args.bf16_full_eval: model = model.to(dtype=torch.bfloat16, device=args.device) - batch_size = dataloader.batch_size + batch_size = self.args.per_device_eval_batch_size logger.info(f"***** Running {description} *****") - if has_length(dataloader.dataset): + if has_length(dataloader): logger.info(f" Num examples = {self.num_examples(dataloader)}") else: logger.info(" Num examples: Unknown") @@ -2420,7 +2425,7 @@ class Trainer: self.callback_handler.eval_dataloader = dataloader # Do this before wrapping. - eval_dataset = dataloader.dataset + eval_dataset = getattr(dataloader, "dataset", None) if is_torch_tpu_available(): dataloader = pl.ParallelLoader(dataloader, [args.device]).per_device_loader(args.device) @@ -2512,7 +2517,10 @@ class Trainer: elif isinstance(eval_dataset, IterableDatasetShard) and hasattr(eval_dataset, "num_examples"): num_samples = eval_dataset.num_examples else: - num_samples = observed_num_examples + if has_length(dataloader): + num_samples = self.num_examples(dataloader) + else: # both len(dataloader.dataset) and len(dataloader) fail + num_samples = observed_num_examples # Number of losses has been rounded to a multiple of batch_size and in a distributed training, the number of # samplers has been rounded to a multiple of batch_size, so we truncate. @@ -2899,8 +2907,9 @@ class Trainer: """ args = self.args - if not has_length(dataloader.dataset): - raise ValueError("dataset must implement __len__") + if not has_length(dataloader): + raise ValueError("dataloader must implement a working __len__") + prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only # if eval is called w/o train init deepspeed here diff --git a/src/transformers/trainer_callback.py b/src/transformers/trainer_callback.py index ec344341bc..92abe1ed50 100644 --- a/src/transformers/trainer_callback.py +++ b/src/transformers/trainer_callback.py @@ -473,7 +473,7 @@ class ProgressCallback(TrainerCallback): self.current_step = state.global_step def on_prediction_step(self, args, state, control, eval_dataloader=None, **kwargs): - if state.is_local_process_zero and has_length(eval_dataloader.dataset): + if state.is_local_process_zero and has_length(eval_dataloader): if self.prediction_bar is None: self.prediction_bar = tqdm(total=len(eval_dataloader), leave=self.training_bar is None) self.prediction_bar.update(1) diff --git a/src/transformers/utils/notebook.py b/src/transformers/utils/notebook.py index 779446f5f1..0ffbdc8dee 100644 --- a/src/transformers/utils/notebook.py +++ b/src/transformers/utils/notebook.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import collections import re import time from typing import Optional @@ -21,7 +20,7 @@ from typing import Optional import IPython.display as disp from ..trainer_callback import TrainerCallback -from ..trainer_utils import IntervalStrategy +from ..trainer_utils import IntervalStrategy, has_length def format_time(t): @@ -294,7 +293,7 @@ class NotebookProgressCallback(TrainerCallback): self._force_next_update = False def on_prediction_step(self, args, state, control, eval_dataloader=None, **kwargs): - if not isinstance(eval_dataloader.dataset, collections.abc.Sized): + if not has_length(eval_dataloader): return if self.prediction_bar is None: if self.training_tracker is not None: diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index ec044ea1da..afe97701d2 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -189,6 +189,26 @@ if is_torch_available(): yield self.dataset[self.current_sample] self.current_sample += 1 + class MultiLoader: + def __init__(self, loaders): + self.loaders = loaders + + def __len__(self): + return sum(len(loader) for loader in self.loaders) + + def __iter__(self): + for loader in self.loaders: + yield from loader + + class CustomDataloaderTrainer(Trainer): + def get_train_dataloader(self): + dataloaders = [super().get_train_dataloader(), super().get_train_dataloader()] + return MultiLoader(dataloaders) + + def get_eval_dataloader(self, eval_dataset): + dataloaders = [super().get_eval_dataloader(eval_dataset), super().get_eval_dataloader(eval_dataset)] + return MultiLoader(dataloaders) + class RegressionModel(nn.Module): def __init__(self, a=0, b=0, double_output=False): super().__init__() @@ -647,6 +667,15 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): new_eval_dataset = RegressionDataset(length=128) self.assertEqual(len(trainer.get_eval_dataloader(new_eval_dataset)), 128 // (32 * n_gpu)) + # tests that we do not require dataloader to have a .dataset attribute + def test_dataloader_without_dataset(self): + train_dataset = RegressionDataset(length=128) + trainer = CustomDataloaderTrainer( + model=RegressionModel(), train_dataset=train_dataset, eval_dataset=train_dataset + ) + trainer.train() + trainer.evaluate() + def test_sampler_seed(self): # nb: we don't want to inherit from IterableDataset to hit the right code path class DummyDataset(torch.utils.data.Dataset):