[Trainer] Add nan/inf logging filter (#13619)
* finish
* add test
* push
* remove unnecessary code
* up
* correct test
* Update src/transformers/training_args.py
This commit is contained in:
committed by
GitHub
parent
eae7a96b7d
commit
1f9dcfc1ef
@@ -15,6 +15,7 @@
|
||||
|
||||
import dataclasses
|
||||
import gc
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
@@ -528,6 +529,31 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
train_output = trainer.train()
|
||||
self.assertEqual(train_output.global_step, 10)
|
||||
|
||||
def test_logging_inf_nan_filter(self):
    """Check that ``logging_nan_inf_filter`` keeps nan/inf losses out of the log history.

    A tiny GPT-2 model is trained twice on the same repeated sample with
    ``learning_rate=1e9`` (presumably chosen so the loss diverges): first with
    the filter disabled, then with it enabled. The unfiltered run must log at
    least one nan/inf ``loss``; the filtered run must log none.

    NOTE(review): both runs train the *same* model instance, so the second
    run starts from whatever weights the first run left behind. This matches
    the original test; confirm whether a fresh model per run is preferable.
    """
    # Local import keeps this test self-contained regardless of the module's import block.
    import tempfile

    config = GPT2Config(vocab_size=100, n_positions=128, n_ctx=128, n_embd=32, n_layer=3, n_head=4)
    tiny_gpt2 = GPT2LMHeadModel(config)
    x = torch.randint(0, 100, (128,))
    train_dataset = RepeatDataset(x)

    def train_and_get_log_history(nan_inf_filter):
        # Use a throwaway output directory per run so the test neither litters
        # the current working directory nor shares state with other tests.
        with tempfile.TemporaryDirectory() as tmp_dir:
            args = TrainingArguments(
                tmp_dir, learning_rate=1e9, logging_steps=5, logging_nan_inf_filter=nan_inf_filter
            )
            trainer = Trainer(tiny_gpt2, args, train_dataset=train_dataset)
            trainer.train()
            return trainer.state.log_history

    def is_any_loss_nan_or_inf(log_history):
        # Only inspect entries that actually carry a "loss" value; other
        # entries (e.g. the last one, which the original test sliced off)
        # are ignored instead of relying on their position in the list.
        losses = [entry["loss"] for entry in log_history if "loss" in entry]
        return any(not math.isfinite(loss) for loss in losses)

    # Trainer without inf/nan filter
    log_history_no_filter = train_and_get_log_history(nan_inf_filter=False)

    # Trainer with inf/nan filter
    log_history_filter = train_and_get_log_history(nan_inf_filter=True)

    self.assertTrue(is_any_loss_nan_or_inf(log_history_no_filter))
    self.assertFalse(is_any_loss_nan_or_inf(log_history_filter))
|
||||
|
||||
def test_train_and_eval_dataloaders(self):
|
||||
n_gpu = max(1, torch.cuda.device_count())
|
||||
trainer = get_regression_trainer(learning_rate=0.1, per_device_train_batch_size=16)
|
||||
|
||||
Reference in New Issue
Block a user