Enable users to use their own loss functions + deal with prefetching for grad accum (#34198)
* bookmark
* Bookmark
* Bookmark
* Actually implement
* Pass in kwarg explicitly
* Adjust for if we do or don't have labels
* Bookmark fix for od
* bookmark
* Fin
* closer
* Negate accelerate grad accum div
* Fixup not training long enough
* Add in compute_loss to take full model output
* Document
* compute_loss -> compute_loss_fn
* Add a test
* Refactor
* Refactor
* Uncomment tests
* Update tests/trainer/test_trainer.py

Co-authored-by: Daniel Han <danielhanchen@gmail.com>

---------

Co-authored-by: Daniel Han <danielhanchen@gmail.com>
This commit is contained in:
@@ -42,6 +42,7 @@ from transformers import (
|
||||
AutoImageProcessor,
|
||||
AutoProcessor,
|
||||
AutoTokenizer,
|
||||
DataCollatorForLanguageModeling,
|
||||
IntervalStrategy,
|
||||
PretrainedConfig,
|
||||
TrainerCallback,
|
||||
@@ -49,6 +50,7 @@ from transformers import (
|
||||
get_polynomial_decay_schedule_with_warmup,
|
||||
is_torch_available,
|
||||
logging,
|
||||
set_seed,
|
||||
)
|
||||
from transformers.hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS
|
||||
from transformers.testing_utils import (
|
||||
@@ -153,6 +155,19 @@ if is_accelerate_available():
|
||||
PATH_SAMPLE_TEXT = f"{get_tests_dir()}/fixtures/sample_text.txt"
|
||||
|
||||
|
||||
class StoreLossCallback(TrainerCallback):
    """
    Simple callback to store the loss.

    Each `"loss"` value handed to `on_log` is appended to `self.losses`, in the order it was received.
    """

    def __init__(self):
        # Loss values collected from log events, oldest first.
        self.losses = []

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Append `logs["loss"]` to `self.losses` when the log event carries a loss."""
        # `logs` defaults to None, so guard before the membership test instead of raising TypeError.
        if logs is not None and "loss" in logs:
            self.losses.append(logs["loss"])
|
||||
|
||||
|
||||
class MockCudaOOMCallback(TrainerCallback):
|
||||
"""
|
||||
Simple callback to simulate CUDA OOM error if
|
||||
@@ -168,6 +183,26 @@ class MockCudaOOMCallback(TrainerCallback):
|
||||
raise RuntimeError("CUDA out of memory.")
|
||||
|
||||
|
||||
def ForCausalLMLoss(logits, labels, vocab_size, num_items_in_batch, disable_num_items_in_batch=False):
    """
    Next-token cross-entropy loss for causal language modeling.

    Args:
        logits: Prediction scores whose last dimension has size `vocab_size` and whose second-to-last
            dimension is the sequence position.
        labels: Target token ids, one per sequence position. Positions equal to `-100` are ignored.
        vocab_size: Size of the last dimension of `logits`, used to flatten them.
        num_items_in_batch: Divisor applied to the *summed* loss. When `None`, the loss is the mean over
            the non-ignored tokens instead.
        disable_num_items_in_batch: When `True`, `num_items_in_batch` is ignored and the mean is returned.

    Returns:
        A scalar loss tensor.
    """
    # Upcast to float if we need to compute the loss to avoid potential precision issues
    logits = logits.float()

    # Shift so that tokens < n predict n
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()

    # Flatten the tokens
    shift_logits = shift_logits.view(-1, vocab_size)
    # Enable model parallelism
    shift_labels = shift_labels.view(-1).to(shift_logits.device)

    if num_items_in_batch is None or disable_num_items_in_batch:
        return nn.functional.cross_entropy(shift_logits, shift_labels, ignore_index=-100, reduction="mean")

    loss = nn.functional.cross_entropy(shift_logits, shift_labels, ignore_index=-100, reduction="sum")
    # A tensor divisor may live on another device than the loss; move it there, as is done for the labels.
    if hasattr(num_items_in_batch, "to"):
        num_items_in_batch = num_items_in_batch.to(loss.device)
    return loss / num_items_in_batch
|
||||
|
||||
|
||||
class RegressionDataset:
|
||||
def __init__(self, a=2, b=3, length=64, seed=42, label_names=None):
|
||||
np.random.seed(seed)
|
||||
@@ -438,6 +473,31 @@ if is_torch_available():
|
||||
loss = nn.functional.mse_loss(y, labels)
|
||||
return (loss, y)
|
||||
|
||||
class BasicTextGenerationModel(nn.Module):
    """
    Minimal text-generation model: embedding -> single LSTM -> linear projection to the vocabulary.

    Args:
        vocab_size: Number of embedding rows and size of the output projection.
        hidden_size: Embedding width and LSTM hidden size.
    """

    def __init__(self, vocab_size, hidden_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        # batch_first=True: the LSTM takes and returns tensors with the batch dimension first.
        self.lstm = nn.LSTM(hidden_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, input_ids, **kwargs):
        """Return the logits for `input_ids`; any extra keyword arguments are accepted and ignored."""
        embedded = self.embedding(input_ids)
        # The LSTM's final hidden/cell state is not needed, only the per-position outputs.
        lstm_out, _ = self.lstm(embedded)
        return self.fc(lstm_out)
|
||||
|
||||
def create_dummy_dataset_for_text_generation(vocab_size, seq_length, num_samples):
    """
    Build a `datasets.Dataset` of random token ids for text-generation tests.

    Args:
        vocab_size: Exclusive upper bound for the sampled token ids (ids are drawn from `[0, vocab_size)`).
        seq_length: Number of tokens per sample.
        num_samples: Number of samples to generate.

    Returns:
        A dataset with `"input_ids"` and `"labels"` columns built from the same random array.
    """
    import datasets
    import numpy as np

    # Create random input sequences
    input_ids = np.random.randint(0, vocab_size, (num_samples, seq_length))

    # The labels are the inputs themselves; both columns come from the same array.
    return datasets.Dataset.from_dict({"input_ids": input_ids, "labels": input_ids})
|
||||
|
||||
class TstLayer(nn.Module):
|
||||
def __init__(self, hidden_size):
|
||||
super().__init__()
|
||||
@@ -676,8 +736,105 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
trainer.train()
|
||||
self.check_trained_model(trainer.model, alternate_seed=True)
|
||||
|
||||
@slow
def test_gradient_accumulation_loss_alignment(self):
    """
    Check that logged losses with gradient accumulation track the run without it.

    Three runs share the same seed, data, and `max_steps`:
      1. a baseline run with no gradient accumulation,
      2. a run with `gradient_accumulation_steps=2` and `per_device_train_batch_size=4` whose loss
         function honours `num_items_in_batch`,
      3. the same accumulation setup with `disable_num_items_in_batch=True`.

    Run 2 must stay within 0.1 of the baseline at every logged step; run 3 must differ by more than 0.1.
    """
    set_seed(42)
    import datasets

    model_name = "distilgpt2"
    dataset_name = "wikitext"
    dataset_config = "wikitext-2-raw-v1"
    dataset = datasets.load_dataset(dataset_name, dataset_config, split="train[:500]")
    dataset = dataset.train_test_split(test_size=0.2)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    def tokenize_function(examples):
        return tokenizer(examples["text"])

    tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=dataset["train"].column_names)

    tokenizer.pad_token = tokenizer.eos_token
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    model = AutoModelForCausalLM.from_pretrained(model_name)

    def compute_loss(logits, labels, vocab_size, num_items_in_batch, disable_num_items_in_batch=False):
        # The first argument is indexed with "logits" before being handed to the loss.
        return ForCausalLMLoss(
            logits["logits"], labels, vocab_size, num_items_in_batch, disable_num_items_in_batch
        )

    loss_fn = partial(compute_loss, vocab_size=model.config.vocab_size, disable_num_items_in_batch=False)

    base_loss_callback = StoreLossCallback()

    args_kwargs = {
        "report_to": "none",
        "logging_steps": 1,
        "max_steps": 20,
        "learning_rate": 3e-4,
        "disable_tqdm": True,
    }

    # Run 1: baseline, batch size and accumulation left at the TrainingArguments defaults.
    # NOTE(review): "./generation" is created in the current working directory — consider a temp dir.
    args = TrainingArguments(
        "./generation",
        **args_kwargs,
    )
    trainer = Trainer(
        model,
        args,
        train_dataset=tokenized_dataset["train"],
        callbacks=[base_loss_callback],
        compute_loss_func=loss_fn,
        data_collator=data_collator,
    )
    trainer.train()

    # Run 2: gradient accumulation with the loss normalised by `num_items_in_batch`.
    grad_accum_loss_callback = StoreLossCallback()
    args = TrainingArguments(
        "./generation",
        **args_kwargs,
        gradient_accumulation_steps=2,
        per_device_train_batch_size=4,
    )
    set_seed(42)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    trainer = Trainer(
        model,
        args,
        train_dataset=tokenized_dataset["train"],
        callbacks=[grad_accum_loss_callback],
        compute_loss_func=loss_fn,
        data_collator=data_collator,
    )
    trainer.train()

    # Run 3: same accumulation arguments, but the loss ignores `num_items_in_batch`.
    set_seed(42)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    broken_loss_callback = StoreLossCallback()
    loss_fn = partial(compute_loss, vocab_size=model.config.vocab_size, disable_num_items_in_batch=True)
    trainer = Trainer(
        model,
        args,
        train_dataset=tokenized_dataset["train"],
        callbacks=[broken_loss_callback],
        compute_loss_func=loss_fn,
        data_collator=data_collator,
    )
    trainer.train()

    base_losses = base_loss_callback.losses
    grad_accum_losses = grad_accum_loss_callback.losses
    broken_losses = broken_loss_callback.losses

    # zip() truncates silently: without these checks an empty or short loss list would make the
    # comparison loops below pass without comparing anything.
    self.assertGreater(len(base_losses), 0, "No losses were logged for the baseline run")
    self.assertEqual(len(base_losses), len(grad_accum_losses), "Runs logged a different number of losses")
    self.assertEqual(len(base_losses), len(broken_losses), "Runs logged a different number of losses")

    # Calculate the difference between the base loss and the grad_accum loss
    diff_truth = [base - grad for base, grad in zip(base_losses, grad_accum_losses)]
    diff_broken = [base - grad for base, grad in zip(base_losses, broken_losses)]

    # These should be quite close
    for diff in diff_truth:
        self.assertLess(abs(diff), 0.1, f"Difference {diff} is not within 0.1")

    # These should be very off
    for diff in diff_broken:
        self.assertGreater(abs(diff), 0.1, f"Difference {diff} is not greater than 0.1")
|
||||
|
||||
def test_gradient_accumulation(self):
|
||||
# Training with half the batch size but accumulation steps as 2 should give the same results.
|
||||
# Training with half the batch size but accumulation steps as 2 should give the same training losses.
|
||||
trainer = get_regression_trainer(
|
||||
gradient_accumulation_steps=2, per_device_train_batch_size=4, learning_rate=0.1
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user