Avoid accessing .dataset of a DataLoader in Trainer (#16451)
* Avoid accessing .dataset of a dataloader
* style
* fix
* cleaning up, reverting some misunderstandings
* black
* add train_dataset argument to get_train_dataloader, and fix other instances of length checks
* flake8
* address comments
* fix bug
* cleanup
* add test
* Update tests/trainer/test_trainer.py

  Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* under torch
* merge
* stylistic suggestion

Co-authored-by: Sander Land <sander@chatdesk.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -585,7 +585,7 @@ class Trainer:
|
|||||||
return dataset.remove_columns(ignored_columns)
|
return dataset.remove_columns(ignored_columns)
|
||||||
|
|
||||||
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
||||||
if not has_length(self.train_dataset):
|
if self.train_dataset is None or not has_length(self.train_dataset):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
generator = None
|
generator = None
|
||||||
@@ -661,8 +661,8 @@ class Trainer:
|
|||||||
"""
|
"""
|
||||||
Returns the training [`~torch.utils.data.DataLoader`].
|
Returns the training [`~torch.utils.data.DataLoader`].
|
||||||
|
|
||||||
Will use no sampler if `self.train_dataset` does not implement `__len__`, a random sampler (adapted to
|
Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed
|
||||||
distributed training if necessary) otherwise.
|
training if necessary) otherwise.
|
||||||
|
|
||||||
Subclass and override this method if you want to inject some custom behavior.
|
Subclass and override this method if you want to inject some custom behavior.
|
||||||
"""
|
"""
|
||||||
@@ -937,11 +937,13 @@ class Trainer:
|
|||||||
|
|
||||||
def num_examples(self, dataloader: DataLoader) -> int:
|
def num_examples(self, dataloader: DataLoader) -> int:
|
||||||
"""
|
"""
|
||||||
Helper to get number of samples in a [`~torch.utils.data.DataLoader`] by accessing its dataset.
|
Helper to get number of samples in a [`~torch.utils.data.DataLoader`] by accessing its dataset. When
|
||||||
|
dataloader.dataset does not exist or has no length, estimates as best it can
|
||||||
Will raise an exception if the underlying dataset does not implement method `__len__`
|
|
||||||
"""
|
"""
|
||||||
|
try:
|
||||||
return len(dataloader.dataset)
|
return len(dataloader.dataset)
|
||||||
|
except (NameError, AttributeError, TypeError): # no dataset or length, estimate by length of dataloader
|
||||||
|
return len(dataloader) * self.args.per_device_train_batch_size
|
||||||
|
|
||||||
def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]):
|
def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]):
|
||||||
"""HP search setup code"""
|
"""HP search setup code"""
|
||||||
@@ -1198,9 +1200,6 @@ class Trainer:
|
|||||||
self._move_model_to_device(self.model, args.device)
|
self._move_model_to_device(self.model, args.device)
|
||||||
self.model_wrapped = self.model
|
self.model_wrapped = self.model
|
||||||
|
|
||||||
# Keeping track whether we can can len() on the dataset or not
|
|
||||||
train_dataset_is_sized = has_length(self.train_dataset)
|
|
||||||
|
|
||||||
# Data loader and number of training steps
|
# Data loader and number of training steps
|
||||||
train_dataloader = self.get_train_dataloader()
|
train_dataloader = self.get_train_dataloader()
|
||||||
|
|
||||||
@@ -1209,28 +1208,36 @@ class Trainer:
|
|||||||
# number of training steps per epoch: num_update_steps_per_epoch
|
# number of training steps per epoch: num_update_steps_per_epoch
|
||||||
# total number of training steps to execute: max_steps
|
# total number of training steps to execute: max_steps
|
||||||
total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * args.world_size
|
total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * args.world_size
|
||||||
if train_dataset_is_sized:
|
|
||||||
num_update_steps_per_epoch = len(train_dataloader) // args.gradient_accumulation_steps
|
len_dataloader = None
|
||||||
|
if has_length(train_dataloader):
|
||||||
|
len_dataloader = len(train_dataloader)
|
||||||
|
num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps
|
||||||
num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
|
num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
|
||||||
|
num_examples = self.num_examples(train_dataloader)
|
||||||
if args.max_steps > 0:
|
if args.max_steps > 0:
|
||||||
max_steps = args.max_steps
|
max_steps = args.max_steps
|
||||||
num_train_epochs = args.max_steps // num_update_steps_per_epoch + int(
|
num_train_epochs = args.max_steps // num_update_steps_per_epoch + int(
|
||||||
args.max_steps % num_update_steps_per_epoch > 0
|
args.max_steps % num_update_steps_per_epoch > 0
|
||||||
)
|
)
|
||||||
# May be slightly incorrect if the last batch in the training datalaoder has a smaller size but it's
|
# May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's
|
||||||
# the best we can do.
|
# the best we can do.
|
||||||
num_train_samples = args.max_steps * total_train_batch_size
|
num_train_samples = args.max_steps * total_train_batch_size
|
||||||
else:
|
else:
|
||||||
max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)
|
max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)
|
||||||
num_train_epochs = math.ceil(args.num_train_epochs)
|
num_train_epochs = math.ceil(args.num_train_epochs)
|
||||||
num_train_samples = len(self.train_dataset) * args.num_train_epochs
|
num_train_samples = self.num_examples(train_dataloader) * args.num_train_epochs
|
||||||
else:
|
elif args.max_steps > 0: # Rely on max_steps when dataloader does not have a working size
|
||||||
# see __init__. max_steps is set when the dataset has no __len__
|
|
||||||
max_steps = args.max_steps
|
max_steps = args.max_steps
|
||||||
# Setting a very large number of epochs so we go as many times as necessary over the iterator.
|
# Setting a very large number of epochs so we go as many times as necessary over the iterator.
|
||||||
num_train_epochs = sys.maxsize
|
num_train_epochs = sys.maxsize
|
||||||
num_update_steps_per_epoch = max_steps
|
num_update_steps_per_epoch = max_steps
|
||||||
|
num_examples = total_train_batch_size * args.max_steps
|
||||||
num_train_samples = args.max_steps * total_train_batch_size
|
num_train_samples = args.max_steps * total_train_batch_size
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"args.max_steps must be set to a positive value if dataloader does not have a length, was {args.max_steps}"
|
||||||
|
)
|
||||||
|
|
||||||
if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug:
|
if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug:
|
||||||
if self.args.n_gpu > 1:
|
if self.args.n_gpu > 1:
|
||||||
@@ -1281,10 +1288,6 @@ class Trainer:
|
|||||||
# self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), etc.
|
# self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), etc.
|
||||||
|
|
||||||
# Train!
|
# Train!
|
||||||
num_examples = (
|
|
||||||
self.num_examples(train_dataloader) if train_dataset_is_sized else total_train_batch_size * args.max_steps
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info("***** Running training *****")
|
logger.info("***** Running training *****")
|
||||||
logger.info(f" Num examples = {num_examples}")
|
logger.info(f" Num examples = {num_examples}")
|
||||||
logger.info(f" Num Epochs = {num_train_epochs}")
|
logger.info(f" Num Epochs = {num_train_epochs}")
|
||||||
@@ -1370,7 +1373,7 @@ class Trainer:
|
|||||||
for epoch in range(epochs_trained, num_train_epochs):
|
for epoch in range(epochs_trained, num_train_epochs):
|
||||||
if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
|
if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
|
||||||
train_dataloader.sampler.set_epoch(epoch)
|
train_dataloader.sampler.set_epoch(epoch)
|
||||||
elif isinstance(train_dataloader.dataset, IterableDatasetShard):
|
elif hasattr(train_dataloader, "dataset") and isinstance(train_dataloader.dataset, IterableDatasetShard):
|
||||||
train_dataloader.dataset.set_epoch(epoch)
|
train_dataloader.dataset.set_epoch(epoch)
|
||||||
|
|
||||||
if is_torch_tpu_available():
|
if is_torch_tpu_available():
|
||||||
@@ -1384,7 +1387,9 @@ class Trainer:
|
|||||||
self._past = None
|
self._past = None
|
||||||
|
|
||||||
steps_in_epoch = (
|
steps_in_epoch = (
|
||||||
len(epoch_iterator) if train_dataset_is_sized else args.max_steps * args.gradient_accumulation_steps
|
len(epoch_iterator)
|
||||||
|
if len_dataloader is not None
|
||||||
|
else args.max_steps * args.gradient_accumulation_steps
|
||||||
)
|
)
|
||||||
self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)
|
self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)
|
||||||
|
|
||||||
@@ -2407,10 +2412,10 @@ class Trainer:
|
|||||||
elif args.bf16_full_eval:
|
elif args.bf16_full_eval:
|
||||||
model = model.to(dtype=torch.bfloat16, device=args.device)
|
model = model.to(dtype=torch.bfloat16, device=args.device)
|
||||||
|
|
||||||
batch_size = dataloader.batch_size
|
batch_size = self.args.per_device_eval_batch_size
|
||||||
|
|
||||||
logger.info(f"***** Running {description} *****")
|
logger.info(f"***** Running {description} *****")
|
||||||
if has_length(dataloader.dataset):
|
if has_length(dataloader):
|
||||||
logger.info(f" Num examples = {self.num_examples(dataloader)}")
|
logger.info(f" Num examples = {self.num_examples(dataloader)}")
|
||||||
else:
|
else:
|
||||||
logger.info(" Num examples: Unknown")
|
logger.info(" Num examples: Unknown")
|
||||||
@@ -2420,7 +2425,7 @@ class Trainer:
|
|||||||
|
|
||||||
self.callback_handler.eval_dataloader = dataloader
|
self.callback_handler.eval_dataloader = dataloader
|
||||||
# Do this before wrapping.
|
# Do this before wrapping.
|
||||||
eval_dataset = dataloader.dataset
|
eval_dataset = getattr(dataloader, "dataset", None)
|
||||||
|
|
||||||
if is_torch_tpu_available():
|
if is_torch_tpu_available():
|
||||||
dataloader = pl.ParallelLoader(dataloader, [args.device]).per_device_loader(args.device)
|
dataloader = pl.ParallelLoader(dataloader, [args.device]).per_device_loader(args.device)
|
||||||
@@ -2512,6 +2517,9 @@ class Trainer:
|
|||||||
elif isinstance(eval_dataset, IterableDatasetShard) and hasattr(eval_dataset, "num_examples"):
|
elif isinstance(eval_dataset, IterableDatasetShard) and hasattr(eval_dataset, "num_examples"):
|
||||||
num_samples = eval_dataset.num_examples
|
num_samples = eval_dataset.num_examples
|
||||||
else:
|
else:
|
||||||
|
if has_length(dataloader):
|
||||||
|
num_samples = self.num_examples(dataloader)
|
||||||
|
else: # both len(dataloader.dataset) and len(dataloader) fail
|
||||||
num_samples = observed_num_examples
|
num_samples = observed_num_examples
|
||||||
|
|
||||||
# Number of losses has been rounded to a multiple of batch_size and in a distributed training, the number of
|
# Number of losses has been rounded to a multiple of batch_size and in a distributed training, the number of
|
||||||
@@ -2899,8 +2907,9 @@ class Trainer:
|
|||||||
"""
|
"""
|
||||||
args = self.args
|
args = self.args
|
||||||
|
|
||||||
if not has_length(dataloader.dataset):
|
if not has_length(dataloader):
|
||||||
raise ValueError("dataset must implement __len__")
|
raise ValueError("dataloader must implement a working __len__")
|
||||||
|
|
||||||
prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only
|
prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only
|
||||||
|
|
||||||
# if eval is called w/o train init deepspeed here
|
# if eval is called w/o train init deepspeed here
|
||||||
|
|||||||
@@ -473,7 +473,7 @@ class ProgressCallback(TrainerCallback):
|
|||||||
self.current_step = state.global_step
|
self.current_step = state.global_step
|
||||||
|
|
||||||
def on_prediction_step(self, args, state, control, eval_dataloader=None, **kwargs):
|
def on_prediction_step(self, args, state, control, eval_dataloader=None, **kwargs):
|
||||||
if state.is_local_process_zero and has_length(eval_dataloader.dataset):
|
if state.is_local_process_zero and has_length(eval_dataloader):
|
||||||
if self.prediction_bar is None:
|
if self.prediction_bar is None:
|
||||||
self.prediction_bar = tqdm(total=len(eval_dataloader), leave=self.training_bar is None)
|
self.prediction_bar = tqdm(total=len(eval_dataloader), leave=self.training_bar is None)
|
||||||
self.prediction_bar.update(1)
|
self.prediction_bar.update(1)
|
||||||
|
|||||||
@@ -13,7 +13,6 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import collections
|
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
@@ -21,7 +20,7 @@ from typing import Optional
|
|||||||
import IPython.display as disp
|
import IPython.display as disp
|
||||||
|
|
||||||
from ..trainer_callback import TrainerCallback
|
from ..trainer_callback import TrainerCallback
|
||||||
from ..trainer_utils import IntervalStrategy
|
from ..trainer_utils import IntervalStrategy, has_length
|
||||||
|
|
||||||
|
|
||||||
def format_time(t):
|
def format_time(t):
|
||||||
@@ -294,7 +293,7 @@ class NotebookProgressCallback(TrainerCallback):
|
|||||||
self._force_next_update = False
|
self._force_next_update = False
|
||||||
|
|
||||||
def on_prediction_step(self, args, state, control, eval_dataloader=None, **kwargs):
|
def on_prediction_step(self, args, state, control, eval_dataloader=None, **kwargs):
|
||||||
if not isinstance(eval_dataloader.dataset, collections.abc.Sized):
|
if not has_length(eval_dataloader):
|
||||||
return
|
return
|
||||||
if self.prediction_bar is None:
|
if self.prediction_bar is None:
|
||||||
if self.training_tracker is not None:
|
if self.training_tracker is not None:
|
||||||
|
|||||||
@@ -189,6 +189,26 @@ if is_torch_available():
|
|||||||
yield self.dataset[self.current_sample]
|
yield self.dataset[self.current_sample]
|
||||||
self.current_sample += 1
|
self.current_sample += 1
|
||||||
|
|
||||||
|
class MultiLoader:
|
||||||
|
def __init__(self, loaders):
|
||||||
|
self.loaders = loaders
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return sum(len(loader) for loader in self.loaders)
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
for loader in self.loaders:
|
||||||
|
yield from loader
|
||||||
|
|
||||||
|
class CustomDataloaderTrainer(Trainer):
|
||||||
|
def get_train_dataloader(self):
|
||||||
|
dataloaders = [super().get_train_dataloader(), super().get_train_dataloader()]
|
||||||
|
return MultiLoader(dataloaders)
|
||||||
|
|
||||||
|
def get_eval_dataloader(self, eval_dataset):
|
||||||
|
dataloaders = [super().get_eval_dataloader(eval_dataset), super().get_eval_dataloader(eval_dataset)]
|
||||||
|
return MultiLoader(dataloaders)
|
||||||
|
|
||||||
class RegressionModel(nn.Module):
|
class RegressionModel(nn.Module):
|
||||||
def __init__(self, a=0, b=0, double_output=False):
|
def __init__(self, a=0, b=0, double_output=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -647,6 +667,15 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
new_eval_dataset = RegressionDataset(length=128)
|
new_eval_dataset = RegressionDataset(length=128)
|
||||||
self.assertEqual(len(trainer.get_eval_dataloader(new_eval_dataset)), 128 // (32 * n_gpu))
|
self.assertEqual(len(trainer.get_eval_dataloader(new_eval_dataset)), 128 // (32 * n_gpu))
|
||||||
|
|
||||||
|
# tests that we do not require dataloader to have a .dataset attribute
|
||||||
|
def test_dataloader_without_dataset(self):
|
||||||
|
train_dataset = RegressionDataset(length=128)
|
||||||
|
trainer = CustomDataloaderTrainer(
|
||||||
|
model=RegressionModel(), train_dataset=train_dataset, eval_dataset=train_dataset
|
||||||
|
)
|
||||||
|
trainer.train()
|
||||||
|
trainer.evaluate()
|
||||||
|
|
||||||
def test_sampler_seed(self):
|
def test_sampler_seed(self):
|
||||||
# nb: we don't want to inherit from IterableDataset to hit the right code path
|
# nb: we don't want to inherit from IterableDataset to hit the right code path
|
||||||
class DummyDataset(torch.utils.data.Dataset):
|
class DummyDataset(torch.utils.data.Dataset):
|
||||||
|
|||||||
Reference in New Issue
Block a user