[Tests] Add Common Test for Training + Fix a couple of bugs (#8415)

* add training tests

* correct longformer

* fix docs

* fix some tests

* fix some more train tests

* remove ipdb

* fix multiple edge case model training

* fix funnel and prophetnet

* clean gpt models

* undo renaming of albert
This commit is contained in:
Patrick von Platen
2020-11-09 18:24:41 +01:00
committed by GitHub
parent 52040517b8
commit 9c83b96e62
30 changed files with 445 additions and 34 deletions

View File

@@ -35,10 +35,12 @@ if is_torch_available():
MODEL_FOR_CAUSAL_LM_MAPPING,
MODEL_FOR_MASKED_LM_MAPPING,
MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
MODEL_MAPPING,
AdaptiveEmbedding,
BertConfig,
BertModel,
@@ -88,7 +90,10 @@ class ModelTesterMixin:
inputs_dict["end_positions"] = torch.zeros(
self.model_tester.batch_size, dtype=torch.long, device=torch_device
)
elif model_class in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values():
elif model_class in [
*MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values(),
*MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.values(),
]:
inputs_dict["labels"] = torch.zeros(
self.model_tester.batch_size, dtype=torch.long, device=torch_device
)
@@ -204,6 +209,41 @@ class ModelTesterMixin:
expected_arg_names = ["input_ids"]
self.assertListEqual(arg_names[:1], expected_arg_names)
def test_training(self):
if not self.model_tester.is_training:
return
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True
for model_class in self.all_model_classes:
if model_class in MODEL_MAPPING.values():
continue
model = model_class(config)
model.to(torch_device)
model.train()
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
loss = model(**inputs).loss
loss.backward()
def test_training_gradient_checkpointing(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
if not self.model_tester.is_training or not hasattr(config, "gradient_checkpointing"):
return
config.gradient_checkpointing = True
config.return_dict = True
for model_class in self.all_model_classes:
if model_class in MODEL_MAPPING.values():
continue
model = model_class(config)
model.to(torch_device)
model.train()
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
loss = model(**inputs).loss
loss.backward()
def test_attention_outputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True