Add tests to Trainer (#6605)
* Add tests to Trainer * Test if removing long breaks everything * Remove ugly hack * Fix distributed test * Use float for number of epochs
This commit is contained in:
@@ -62,7 +62,7 @@ def default_data_collator(features: List[InputDataClass]) -> Dict[str, torch.Ten
|
||||
if isinstance(v, torch.Tensor):
|
||||
batch[k] = torch.stack([f[k] for f in features])
|
||||
else:
|
||||
batch[k] = torch.tensor([f[k] for f in features], dtype=torch.long)
|
||||
batch[k] = torch.tensor([f[k] for f in features])
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
@@ -449,6 +449,7 @@ class Trainer:
|
||||
else:
|
||||
t_total = int(len(train_dataloader) // self.args.gradient_accumulation_steps * self.args.num_train_epochs)
|
||||
num_train_epochs = self.args.num_train_epochs
|
||||
self.args.max_steps = t_total
|
||||
|
||||
self.create_optimizer_and_scheduler(num_training_steps=t_total)
|
||||
|
||||
@@ -530,7 +531,7 @@ class Trainer:
|
||||
logging_loss = 0.0
|
||||
model.zero_grad()
|
||||
train_iterator = trange(
|
||||
epochs_trained, int(num_train_epochs), desc="Epoch", disable=not self.is_local_process_zero()
|
||||
epochs_trained, int(np.ceil(num_train_epochs)), desc="Epoch", disable=not self.is_local_process_zero()
|
||||
)
|
||||
for epoch in train_iterator:
|
||||
if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
|
||||
@@ -626,10 +627,10 @@ class Trainer:
|
||||
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
|
||||
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
|
||||
|
||||
if self.args.max_steps > 0 and self.global_step > self.args.max_steps:
|
||||
if self.args.max_steps > 0 and self.global_step >= self.args.max_steps:
|
||||
epoch_iterator.close()
|
||||
break
|
||||
if self.args.max_steps > 0 and self.global_step > self.args.max_steps:
|
||||
if self.args.max_steps > 0 and self.global_step >= self.args.max_steps:
|
||||
train_iterator.close()
|
||||
break
|
||||
if self.args.tpu_metrics_debug or self.args.debug:
|
||||
@@ -986,10 +987,13 @@ class Trainer:
|
||||
if self.args.past_index >= 0:
|
||||
self._past = None
|
||||
|
||||
samples_count = 0
|
||||
for inputs in tqdm(dataloader, desc=description):
|
||||
loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only)
|
||||
batch_size = inputs[list(inputs.keys())[0]].shape[0]
|
||||
samples_count += batch_size
|
||||
if loss is not None:
|
||||
eval_losses.append(loss)
|
||||
eval_losses.append(loss * batch_size)
|
||||
if logits is not None:
|
||||
preds = logits if preds is None else torch.cat((preds, logits), dim=0)
|
||||
if labels is not None:
|
||||
@@ -1023,7 +1027,7 @@ class Trainer:
|
||||
else:
|
||||
metrics = {}
|
||||
if len(eval_losses) > 0:
|
||||
metrics["eval_loss"] = np.mean(eval_losses)
|
||||
metrics["eval_loss"] = np.sum(eval_losses) / samples_count
|
||||
|
||||
# Prefix all keys with eval_
|
||||
for key in list(metrics.keys()):
|
||||
|
||||
@@ -69,7 +69,8 @@ class TrainingArguments:
|
||||
max_grad_norm (:obj:`float`, `optional`, defaults to 1.0):
|
||||
Maximum gradient norm (for gradient clipping).
|
||||
num_train_epochs(:obj:`float`, `optional`, defaults to 3.0):
|
||||
Total number of training epochs to perform.
|
||||
Total number of training epochs to perform (if not an integer, will perform the decimal part percents of
|
||||
the last epoch before stopping training).
|
||||
max_steps (:obj:`int`, `optional`, defaults to -1):
|
||||
If set to a positive number, the total number of training steps to perform. Overrides
|
||||
:obj:`num_train_epochs`.
|
||||
|
||||
Reference in New Issue
Block a user