Make gradient_checkpointing a training argument (#13657)
* Make gradient_checkpointing a training argument
* Update src/transformers/modeling_utils.py
  Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
* Update src/transformers/configuration_utils.py
  Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
* Fix tests
* Style
* Document gradient checkpointing as a performance feature
* Small rename
* PoC for not using the config
* Adapt BC to new PoC
* Forgot to save
* Rollout changes to all other models
* Fix typo

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
Co-authored-by: Stas Bekman <stas@stason.org>
This commit is contained in:
@@ -96,7 +96,7 @@ class GPT2ModelTester:
|
||||
def get_large_model_config(self):
|
||||
return GPT2Config.from_pretrained("gpt2")
|
||||
|
||||
def prepare_config_and_inputs(self, gradient_checkpointing=False):
|
||||
def prepare_config_and_inputs(self):
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
|
||||
input_mask = None
|
||||
@@ -119,7 +119,7 @@ class GPT2ModelTester:
|
||||
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
|
||||
choice_labels = ids_tensor([self.batch_size], self.num_choices)
|
||||
|
||||
config = self.get_config(gradient_checkpointing=gradient_checkpointing)
|
||||
config = self.get_config()
|
||||
|
||||
head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
|
||||
|
||||
@@ -135,7 +135,7 @@ class GPT2ModelTester:
|
||||
choice_labels,
|
||||
)
|
||||
|
||||
def get_config(self, gradient_checkpointing=False):
|
||||
def get_config(self):
|
||||
return GPT2Config(
|
||||
vocab_size=self.vocab_size,
|
||||
n_embd=self.hidden_size,
|
||||
@@ -149,11 +149,10 @@ class GPT2ModelTester:
|
||||
n_ctx=self.max_position_embeddings,
|
||||
type_vocab_size=self.type_vocab_size,
|
||||
initializer_range=self.initializer_range,
|
||||
use_cache=not gradient_checkpointing,
|
||||
use_cache=True,
|
||||
bos_token_id=self.bos_token_id,
|
||||
eos_token_id=self.eos_token_id,
|
||||
pad_token_id=self.pad_token_id,
|
||||
gradient_checkpointing=gradient_checkpointing,
|
||||
)
|
||||
|
||||
def prepare_config_and_inputs_for_decoder(self):
|
||||
@@ -322,9 +321,13 @@ class GPT2ModelTester:
|
||||
self.parent.assertEqual(result.loss.shape, ())
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||
|
||||
def create_and_check_forward_and_backwards(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
|
||||
def create_and_check_forward_and_backwards(
|
||||
self, config, input_ids, input_mask, head_mask, token_type_ids, *args, gradient_checkpointing=False
|
||||
):
|
||||
model = GPT2LMHeadModel(config)
|
||||
model.to(torch_device)
|
||||
if gradient_checkpointing:
|
||||
model.gradient_checkpointing_enable()
|
||||
|
||||
result = model(input_ids, token_type_ids=token_type_ids, labels=input_ids)
|
||||
self.parent.assertEqual(result.loss.shape, ())
|
||||
@@ -478,8 +481,8 @@ class GPT2ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
self.model_tester.create_and_check_gpt2_for_token_classification(*config_and_inputs)
|
||||
|
||||
def test_gpt2_gradient_checkpointing(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs(gradient_checkpointing=True)
|
||||
self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs)
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs, gradient_checkpointing=True)
|
||||
|
||||
@slow
|
||||
def test_batch_generation(self):
|
||||
@@ -612,7 +615,11 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase):
|
||||
@slow
|
||||
def test_lm_generate_gpt2(self):
|
||||
for checkpointing in [True, False]:
|
||||
model = GPT2LMHeadModel.from_pretrained("gpt2", gradient_checkpointing=checkpointing)
|
||||
model = GPT2LMHeadModel.from_pretrained("gpt2")
|
||||
if checkpointing:
|
||||
model.gradient_checkpointing_enable()
|
||||
else:
|
||||
model.gradient_checkpointing_disable()
|
||||
model.to(torch_device)
|
||||
input_ids = torch.tensor([[464, 3290]], dtype=torch.long, device=torch_device) # The dog
|
||||
expected_output_ids = [
|
||||
|
||||
Reference in New Issue
Block a user