Add tests to Trainer (#6605)

* Add tests to Trainer

* Test if removing long breaks everything

* Remove ugly hack

* Fix distributed test

* Use float for number of epochs
This commit is contained in:
Sylvain Gugger
2020-08-20 11:13:50 -04:00
committed by GitHub
parent 039d8d65fc
commit 573bdb0a5d
6 changed files with 313 additions and 136 deletions

View File

@@ -62,7 +62,7 @@ def default_data_collator(features: List[InputDataClass]) -> Dict[str, torch.Ten
if isinstance(v, torch.Tensor):
batch[k] = torch.stack([f[k] for f in features])
else:
batch[k] = torch.tensor([f[k] for f in features], dtype=torch.long)
batch[k] = torch.tensor([f[k] for f in features])
return batch

View File

@@ -449,6 +449,7 @@ class Trainer:
else:
t_total = int(len(train_dataloader) // self.args.gradient_accumulation_steps * self.args.num_train_epochs)
num_train_epochs = self.args.num_train_epochs
self.args.max_steps = t_total
self.create_optimizer_and_scheduler(num_training_steps=t_total)
@@ -530,7 +531,7 @@ class Trainer:
logging_loss = 0.0
model.zero_grad()
train_iterator = trange(
epochs_trained, int(num_train_epochs), desc="Epoch", disable=not self.is_local_process_zero()
epochs_trained, int(np.ceil(num_train_epochs)), desc="Epoch", disable=not self.is_local_process_zero()
)
for epoch in train_iterator:
if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
@@ -626,10 +627,10 @@ class Trainer:
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
if self.args.max_steps > 0 and self.global_step > self.args.max_steps:
if self.args.max_steps > 0 and self.global_step >= self.args.max_steps:
epoch_iterator.close()
break
if self.args.max_steps > 0 and self.global_step > self.args.max_steps:
if self.args.max_steps > 0 and self.global_step >= self.args.max_steps:
train_iterator.close()
break
if self.args.tpu_metrics_debug or self.args.debug:
@@ -986,10 +987,13 @@ class Trainer:
if self.args.past_index >= 0:
self._past = None
samples_count = 0
for inputs in tqdm(dataloader, desc=description):
loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only)
batch_size = inputs[list(inputs.keys())[0]].shape[0]
samples_count += batch_size
if loss is not None:
eval_losses.append(loss)
eval_losses.append(loss * batch_size)
if logits is not None:
preds = logits if preds is None else torch.cat((preds, logits), dim=0)
if labels is not None:
@@ -1023,7 +1027,7 @@ class Trainer:
else:
metrics = {}
if len(eval_losses) > 0:
metrics["eval_loss"] = np.mean(eval_losses)
metrics["eval_loss"] = np.sum(eval_losses) / samples_count
# Prefix all keys with eval_
for key in list(metrics.keys()):

View File

@@ -69,7 +69,8 @@ class TrainingArguments:
max_grad_norm (:obj:`float`, `optional`, defaults to 1.0):
Maximum gradient norm (for gradient clipping).
num_train_epochs(:obj:`float`, `optional`, defaults to 3.0):
Total number of training epochs to perform.
Total number of training epochs to perform (if not an integer, will perform the decimal part percents of
the last epoch before stopping training).
max_steps (:obj:`int`, `optional`, defaults to -1):
If set to a positive number, the total number of training steps to perform. Overrides
:obj:`num_train_epochs`.