Enable users to use their own loss functions + deal with prefetching for grad accum (#34198)

* bookmark

* Bookmark

* Bookmark

* Actually implement

* Pass in kwarg explicitly

* Adjust for if we do or don't have labels

* Bookmark fix for od

* bookmark

* Fin

* closer

* Negate accelerate grad accum div

* Fixup not training long enough

* Add in compute_loss to take full model output

* Document

* compute_loss -> compute_loss_fn

* Add a test

* Refactor

* Refactor

* Uncomment tests

* Update tests/trainer/test_trainer.py

Co-authored-by: Daniel Han <danielhanchen@gmail.com>

---------

Co-authored-by: Daniel Han <danielhanchen@gmail.com>
This commit is contained in:
Zach Mueller
2024-10-17 17:01:56 -04:00
committed by GitHub
parent 7a06d07e14
commit 6ba31a8a94
2 changed files with 326 additions and 125 deletions

View File

@@ -42,6 +42,7 @@ from transformers import (
AutoImageProcessor,
AutoProcessor,
AutoTokenizer,
DataCollatorForLanguageModeling,
IntervalStrategy,
PretrainedConfig,
TrainerCallback,
@@ -49,6 +50,7 @@ from transformers import (
get_polynomial_decay_schedule_with_warmup,
is_torch_available,
logging,
set_seed,
)
from transformers.hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS
from transformers.testing_utils import (
@@ -153,6 +155,19 @@ if is_accelerate_available():
PATH_SAMPLE_TEXT = f"{get_tests_dir()}/fixtures/sample_text.txt"
class StoreLossCallback(TrainerCallback):
"""
Simple callback to store the loss.
"""
def __init__(self):
self.losses = []
def on_log(self, args, state, control, logs=None, **kwargs):
if "loss" in logs:
self.losses.append(logs["loss"])
class MockCudaOOMCallback(TrainerCallback):
"""
Simple callback to simulate CUDA OOM error if
@@ -168,6 +183,26 @@ class MockCudaOOMCallback(TrainerCallback):
raise RuntimeError("CUDA out of memory.")
def ForCausalLMLoss(logits, labels, vocab_size, num_items_in_batch, disable_num_items_in_batch=False):
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
shift_logits = shift_logits.view(-1, vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
if num_items_in_batch is None or disable_num_items_in_batch:
loss = nn.functional.cross_entropy(shift_logits, shift_labels, ignore_index=-100, reduction="mean")
else:
loss = nn.functional.cross_entropy(shift_logits, shift_labels, ignore_index=-100, reduction="sum")
loss = loss / num_items_in_batch
return loss
class RegressionDataset:
def __init__(self, a=2, b=3, length=64, seed=42, label_names=None):
np.random.seed(seed)
@@ -438,6 +473,31 @@ if is_torch_available():
loss = nn.functional.mse_loss(y, labels)
return (loss, y)
class BasicTextGenerationModel(nn.Module):
def __init__(self, vocab_size, hidden_size):
super().__init__()
self.embedding = nn.Embedding(vocab_size, hidden_size)
self.lstm = nn.LSTM(hidden_size, hidden_size, batch_first=True)
self.fc = nn.Linear(hidden_size, vocab_size)
def forward(self, input_ids, **kwargs):
embedded = self.embedding(input_ids)
lstm_out, _ = self.lstm(embedded)
logits = self.fc(lstm_out)
return logits
def create_dummy_dataset_for_text_generation(vocab_size, seq_length, num_samples):
import datasets
import numpy as np
# Create random input sequences
input_ids = np.random.randint(0, vocab_size, (num_samples, seq_length))
# Create a datasets.Dataset
dataset = datasets.Dataset.from_dict({"input_ids": input_ids, "labels": input_ids})
return dataset
class TstLayer(nn.Module):
def __init__(self, hidden_size):
super().__init__()
@@ -676,8 +736,105 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
trainer.train()
self.check_trained_model(trainer.model, alternate_seed=True)
@slow
def test_gradient_accumulation_loss_alignment(self):
set_seed(42)
import datasets
model_name = "distilgpt2"
dataset_name = "wikitext"
dataset_config = "wikitext-2-raw-v1"
dataset = datasets.load_dataset(dataset_name, dataset_config, split="train[:500]")
dataset = dataset.train_test_split(test_size=0.2)
tokenizer = AutoTokenizer.from_pretrained(model_name)
def tokenize_function(examples):
return tokenizer(examples["text"])
tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=dataset["train"].column_names)
tokenizer.pad_token = tokenizer.eos_token
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
model = AutoModelForCausalLM.from_pretrained(model_name)
def compute_loss(logits, labels, vocab_size, num_items_in_batch, disable_num_items_in_batch=False):
return ForCausalLMLoss(
logits["logits"], labels, vocab_size, num_items_in_batch, disable_num_items_in_batch
)
loss_fn = partial(compute_loss, vocab_size=model.config.vocab_size, disable_num_items_in_batch=False)
base_loss_callback = StoreLossCallback()
args_kwargs = {
"report_to": "none",
"logging_steps": 1,
"max_steps": 20,
"learning_rate": 3e-4,
"disable_tqdm": True,
}
args = TrainingArguments(
"./generation",
**args_kwargs,
)
trainer = Trainer(
model,
args,
train_dataset=tokenized_dataset["train"],
callbacks=[base_loss_callback],
compute_loss_func=loss_fn,
data_collator=data_collator,
)
trainer.train()
grad_accum_loss_callback = StoreLossCallback()
args = TrainingArguments(
"./generation",
**args_kwargs,
gradient_accumulation_steps=2,
per_device_train_batch_size=4,
)
set_seed(42)
model = AutoModelForCausalLM.from_pretrained(model_name)
trainer = Trainer(
model,
args,
train_dataset=tokenized_dataset["train"],
callbacks=[grad_accum_loss_callback],
compute_loss_func=loss_fn,
data_collator=data_collator,
)
trainer.train()
set_seed(42)
model = AutoModelForCausalLM.from_pretrained(model_name)
broken_loss_callback = StoreLossCallback()
loss_fn = partial(compute_loss, vocab_size=model.config.vocab_size, disable_num_items_in_batch=True)
trainer = Trainer(
model,
args,
train_dataset=tokenized_dataset["train"],
callbacks=[broken_loss_callback],
compute_loss_func=loss_fn,
data_collator=data_collator,
)
trainer.train()
# Calculate the difference between the base loss and the grad_accum loss
diff_truth = [base - grad for base, grad in zip(base_loss_callback.losses, grad_accum_loss_callback.losses)]
diff_broken = [base - grad for base, grad in zip(base_loss_callback.losses, broken_loss_callback.losses)]
# These should be quite close
for diff in diff_truth:
self.assertLess(abs(diff), 0.1, f"Difference {diff} is not within 0.1")
# These should be very off
for diff in diff_broken:
self.assertGreater(abs(diff), 0.1, f"Difference {diff} is not greater than 0.1")
def test_gradient_accumulation(self):
# Training with half the batch size but accumulation steps as 2 should give the same results.
# Training with half the batch size but accumulation steps as 2 should give the same training losses.
trainer = get_regression_trainer(
gradient_accumulation_steps=2, per_device_train_batch_size=4, learning_rate=0.1
)