Pruning for GPT and GPT-2

This commit is contained in:
LysandreJik
2019-08-21 20:12:06 -04:00
parent fc1fbae45d
commit 719cb3738d
3 changed files with 24 additions and 5 deletions

View File

@@ -213,13 +213,12 @@ class CommonTestCases:
if not self.test_pruning:
return
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
if "head_mask" in inputs_dict:
del inputs_dict["head_mask"]
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
if "head_mask" in inputs_dict:
del inputs_dict["head_mask"]
config.output_attentions = True
config.output_hidden_states = False
model = model_class(config=config)
@@ -244,6 +243,10 @@ class CommonTestCases:
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
if "head_mask" in inputs_dict:
del inputs_dict["head_mask"]
config.output_attentions = True
config.output_hidden_states = False
model = model_class(config=config)
@@ -274,6 +277,10 @@ class CommonTestCases:
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
if "head_mask" in inputs_dict:
del inputs_dict["head_mask"]
config.output_attentions = True
config.output_hidden_states = False