Add AMP for Albert (#10141)

This commit is contained in:
Julien Plu
2021-02-15 17:18:33 +01:00
committed by GitHub
parent 6fc940ed09
commit 31b0560ab4
8 changed files with 415 additions and 345 deletions

View File

@@ -26,6 +26,7 @@ from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
if is_tf_available():
import tensorflow as tf
from transformers import TF_MODEL_FOR_PRETRAINING_MAPPING
from transformers.models.albert.modeling_tf_albert import (
TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFAlbertForMaskedLM,
@@ -243,6 +244,16 @@ class TFAlbertModelTest(TFModelTesterMixin, unittest.TestCase):
test_head_masking = False
test_onnx = False
# special case for ForPreTraining model
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
if return_labels:
if model_class in TF_MODEL_FOR_PRETRAINING_MAPPING.values():
inputs_dict["sentence_order_label"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
return inputs_dict
def setUp(self):
self.model_tester = TFAlbertModelTester(self)
self.config_tester = ConfigTester(self, config_class=AlbertConfig, hidden_size=37)
@@ -295,10 +306,6 @@ class TFAlbertModelTest(TFModelTesterMixin, unittest.TestCase):
name = model.get_bias()
assert name is None
def test_mixed_precision(self):
# TODO JP: Make ALBERT float16 compliant
pass
@slow
def test_model_from_pretrained(self):
for model_name in TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: