Make gradient_checkpointing a training argument (#13657)

* Make gradient_checkpointing a training argument

* Update src/transformers/modeling_utils.py

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* Update src/transformers/configuration_utils.py

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* Fix tests

* Style

* document Gradient Checkpointing as a performance feature

* Small rename

* PoC for not using the config

* Adapt BC to new PoC

* Forgot to save

* Rollout changes to all other models

* Fix typo

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
Co-authored-by: Stas Bekman <stas@stason.org>
This commit is contained in:
Sylvain Gugger
2021-09-22 07:51:38 -04:00
committed by GitHub
parent 75f6641eaf
commit 27d4639779
96 changed files with 531 additions and 309 deletions

View File

@@ -224,6 +224,27 @@ class BeitModelTest(ModelTesterMixin, unittest.TestCase):
loss = model(**inputs).loss
loss.backward()
def test_training_gradient_checkpointing(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
if not self.model_tester.is_training:
return
config.use_cache = False
config.return_dict = True
for model_class in self.all_model_classes:
if model_class in get_values(MODEL_MAPPING) or not model_class.supports_gradient_checkpointing:
continue
# we don't test BeitForMaskedImageModeling
if model_class.__name__ == "BeitForMaskedImageModeling":
continue
model = model_class(config)
model.to(torch_device)
model.train()
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
loss = model(**inputs).loss
loss.backward()
def test_initialization(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

View File

@@ -370,15 +370,14 @@ class ModelTesterMixin:
def test_training_gradient_checkpointing(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
if not self.model_tester.is_training or not hasattr(config, "gradient_checkpointing"):
if not self.model_tester.is_training:
return
config.gradient_checkpointing = True
config.use_cache = False
config.return_dict = True
for model_class in self.all_model_classes:
if model_class in get_values(MODEL_MAPPING):
if model_class in get_values(MODEL_MAPPING) or not model_class.supports_gradient_checkpointing:
continue
model = model_class(config)
model.to(torch_device)

View File

@@ -20,6 +20,7 @@ import unittest
from transformers import DeiTConfig
from transformers.file_utils import cached_property, is_torch_available, is_vision_available
from transformers.models.auto import get_values
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
from .test_configuration_common import ConfigTester
@@ -340,7 +341,7 @@ class DeiTModelTest(ModelTesterMixin, unittest.TestCase):
for model_class in self.all_model_classes:
# DeiTForImageClassificationWithTeacher supports inference-only
if (
model_class in MODEL_MAPPING.values()
model_class in get_values(MODEL_MAPPING)
or model_class.__name__ == "DeiTForImageClassificationWithTeacher"
):
continue
@@ -351,6 +352,27 @@ class DeiTModelTest(ModelTesterMixin, unittest.TestCase):
loss = model(**inputs).loss
loss.backward()
def test_training_gradient_checkpointing(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
if not self.model_tester.is_training:
return
config.use_cache = False
config.return_dict = True
for model_class in self.all_model_classes:
if model_class in get_values(MODEL_MAPPING) or not model_class.supports_gradient_checkpointing:
continue
# DeiTForImageClassificationWithTeacher supports inference-only
if model_class.__name__ == "DeiTForImageClassificationWithTeacher":
continue
model = model_class(config)
model.to(torch_device)
model.train()
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
loss = model(**inputs).loss
loss.backward()
def test_for_image_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_image_classification(*config_and_inputs)

View File

@@ -82,7 +82,7 @@ class FlaxGPT2ModelTester:
self.eos_token_id = vocab_size - 1
self.pad_token_id = vocab_size - 1
def prepare_config_and_inputs(self, gradient_checkpointing=False):
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_mask = None
@@ -100,7 +100,6 @@ class FlaxGPT2ModelTester:
bos_token_id=self.bos_token_id,
eos_token_id=self.eos_token_id,
pad_token_id=self.pad_token_id,
gradient_checkpointing=gradient_checkpointing,
)
return (config, input_ids, input_mask)

View File

@@ -86,7 +86,7 @@ class FlaxGPTNeoModelTester:
self.eos_token_id = vocab_size - 1
self.pad_token_id = vocab_size - 1
def prepare_config_and_inputs(self, gradient_checkpointing=False):
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_mask = None
@@ -105,7 +105,6 @@ class FlaxGPTNeoModelTester:
pad_token_id=self.pad_token_id,
window_size=self.window_size,
attention_types=self.attention_types,
gradient_checkpointing=gradient_checkpointing,
)
return (config, input_ids, input_mask)

View File

@@ -96,7 +96,7 @@ class GPT2ModelTester:
def get_large_model_config(self):
return GPT2Config.from_pretrained("gpt2")
def prepare_config_and_inputs(self, gradient_checkpointing=False):
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_mask = None
@@ -119,7 +119,7 @@ class GPT2ModelTester:
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
choice_labels = ids_tensor([self.batch_size], self.num_choices)
config = self.get_config(gradient_checkpointing=gradient_checkpointing)
config = self.get_config()
head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
@@ -135,7 +135,7 @@ class GPT2ModelTester:
choice_labels,
)
def get_config(self, gradient_checkpointing=False):
def get_config(self):
return GPT2Config(
vocab_size=self.vocab_size,
n_embd=self.hidden_size,
@@ -149,11 +149,10 @@ class GPT2ModelTester:
n_ctx=self.max_position_embeddings,
type_vocab_size=self.type_vocab_size,
initializer_range=self.initializer_range,
use_cache=not gradient_checkpointing,
use_cache=True,
bos_token_id=self.bos_token_id,
eos_token_id=self.eos_token_id,
pad_token_id=self.pad_token_id,
gradient_checkpointing=gradient_checkpointing,
)
def prepare_config_and_inputs_for_decoder(self):
@@ -322,9 +321,13 @@ class GPT2ModelTester:
self.parent.assertEqual(result.loss.shape, ())
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
def create_and_check_forward_and_backwards(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
def create_and_check_forward_and_backwards(
self, config, input_ids, input_mask, head_mask, token_type_ids, *args, gradient_checkpointing=False
):
model = GPT2LMHeadModel(config)
model.to(torch_device)
if gradient_checkpointing:
model.gradient_checkpointing_enable()
result = model(input_ids, token_type_ids=token_type_ids, labels=input_ids)
self.parent.assertEqual(result.loss.shape, ())
@@ -478,8 +481,8 @@ class GPT2ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
self.model_tester.create_and_check_gpt2_for_token_classification(*config_and_inputs)
def test_gpt2_gradient_checkpointing(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs(gradient_checkpointing=True)
self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs)
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs, gradient_checkpointing=True)
@slow
def test_batch_generation(self):
@@ -612,7 +615,11 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase):
@slow
def test_lm_generate_gpt2(self):
for checkpointing in [True, False]:
model = GPT2LMHeadModel.from_pretrained("gpt2", gradient_checkpointing=checkpointing)
model = GPT2LMHeadModel.from_pretrained("gpt2")
if checkpointing:
model.gradient_checkpointing_enable()
else:
model.gradient_checkpointing_disable()
model.to(torch_device)
input_ids = torch.tensor([[464, 3290]], dtype=torch.long, device=torch_device) # The dog
expected_output_ids = [

View File

@@ -97,7 +97,7 @@ class GPTNeoModelTester:
def get_large_model_config(self):
return GPTNeoConfig.from_pretrained("gpt_neo")
def prepare_config_and_inputs(self, gradient_checkpointing=False):
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_mask = None
@@ -120,7 +120,7 @@ class GPTNeoModelTester:
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
choice_labels = ids_tensor([self.batch_size], self.num_choices)
config = self.get_config(gradient_checkpointing=False)
config = self.get_config()
head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
@@ -136,18 +136,17 @@ class GPTNeoModelTester:
choice_labels,
)
def get_config(self, gradient_checkpointing=False):
def get_config(self):
return GPTNeoConfig(
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
num_layers=self.num_hidden_layers,
num_heads=self.num_attention_heads,
max_position_embeddings=self.max_position_embeddings,
use_cache=not gradient_checkpointing,
use_cache=True,
bos_token_id=self.bos_token_id,
eos_token_id=self.eos_token_id,
pad_token_id=self.pad_token_id,
gradient_checkpointing=gradient_checkpointing,
window_size=self.window_size,
attention_types=self.attention_types,
)
@@ -329,8 +328,12 @@ class GPTNeoModelTester:
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
def create_and_check_forward_and_backwards(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
def create_and_check_forward_and_backwards(
self, config, input_ids, input_mask, head_mask, token_type_ids, *args, gradient_checkpointing=False
):
model = GPTNeoForCausalLM(config)
if gradient_checkpointing:
model.gradient_checkpointing_enable()
model.to(torch_device)
result = model(input_ids, token_type_ids=token_type_ids, labels=input_ids)
@@ -411,8 +414,8 @@ class GPTNeoModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase
self.model_tester.create_and_check_gpt_neo_for_sequence_classification(*config_and_inputs)
def test_gpt_neo_gradient_checkpointing(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs(gradient_checkpointing=True)
self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs)
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs, gradient_checkpointing=True)
def _get_hidden_states(self):
return torch.tensor(
@@ -473,7 +476,10 @@ class GPTNeoModelLanguageGenerationTest(unittest.TestCase):
def test_lm_generate_gpt_neo(self):
for checkpointing in [True, False]:
model = self.model
model.config.gradient_checkpointing = checkpointing
if checkpointing:
model.gradient_checkpointing_enable()
else:
model.gradient_checkpointing_disable()
input_ids = torch.tensor([[464, 3290]], dtype=torch.long, device=torch_device) # The dog
# fmt: off
# The dog-eared copy of the book, which is a collection of essays by the late author,

View File

@@ -92,7 +92,7 @@ class GPTJModelTester:
def get_large_model_config(self):
return GPTJConfig.from_pretrained("EleutherAI/gpt-j-6B")
def prepare_config_and_inputs(self, gradient_checkpointing=False):
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_mask = None
@@ -115,7 +115,7 @@ class GPTJModelTester:
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
choice_labels = ids_tensor([self.batch_size], self.num_choices)
config = self.get_config(gradient_checkpointing=gradient_checkpointing)
config = self.get_config()
head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
@@ -131,7 +131,7 @@ class GPTJModelTester:
choice_labels,
)
def get_config(self, gradient_checkpointing=False):
def get_config(self):
return GPTJConfig(
vocab_size=self.vocab_size,
n_embd=self.hidden_size,
@@ -145,11 +145,10 @@ class GPTJModelTester:
n_ctx=self.max_position_embeddings,
type_vocab_size=self.type_vocab_size,
initializer_range=self.initializer_range,
use_cache=not gradient_checkpointing,
use_cache=True,
bos_token_id=self.bos_token_id,
eos_token_id=self.eos_token_id,
pad_token_id=self.pad_token_id,
gradient_checkpointing=gradient_checkpointing,
)
def prepare_config_and_inputs_for_decoder(self):
@@ -318,8 +317,12 @@ class GPTJModelTester:
self.parent.assertEqual(result.loss.shape, ())
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
def create_and_check_forward_and_backwards(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
def create_and_check_forward_and_backwards(
self, config, input_ids, input_mask, head_mask, token_type_ids, *args, gradient_checkpointing=False
):
model = GPTJForCausalLM(config)
if gradient_checkpointing:
model.gradient_checkpointing_enable()
model.to(torch_device)
result = model(input_ids, token_type_ids=token_type_ids, labels=input_ids)
@@ -390,8 +393,8 @@ class GPTJModelTest(unittest.TestCase):
self.model_tester.create_and_check_lm_head_model(*config_and_inputs)
def test_gptj_gradient_checkpointing(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs(gradient_checkpointing=True)
self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs)
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs, gradient_checkpointing=True)
@slow
def test_batch_generation(self):
@@ -464,7 +467,11 @@ class GPTJModelLanguageGenerationTest(unittest.TestCase):
@slow
def test_lm_generate_gptj(self):
for checkpointing in [True, False]:
model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", gradient_checkpointing=checkpointing)
model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B")
if checkpointing:
model.gradient_checkpointing_enable()
else:
model.gradient_checkpointing_disable()
model.to(torch_device)
input_ids = torch.tensor([[464, 3290]], dtype=torch.long, device=torch_device) # The dog
expected_output_ids = [