[Flax] Add other BERT classes (#10977)

* add first code structures

* add all bert models

* add to init and docs

* correct docs

* make style
This commit is contained in:
Patrick von Platen
2021-03-31 09:45:58 +03:00
committed by GitHub
parent e031162a6b
commit e87505f3a1
7 changed files with 627 additions and 24 deletions

View File

@@ -23,7 +23,15 @@ from .test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_
if is_flax_available():
from transformers.models.bert.modeling_flax_bert import FlaxBertForMaskedLM, FlaxBertModel
from transformers.models.bert.modeling_flax_bert import (
FlaxBertForMaskedLM,
FlaxBertForMultipleChoice,
FlaxBertForNextSentencePrediction,
FlaxBertForPreTraining,
FlaxBertForQuestionAnswering,
FlaxBertForTokenClassification,
FlaxBertModel,
)
class FlaxBertModelTester(unittest.TestCase):
@@ -48,6 +56,7 @@ class FlaxBertModelTester(unittest.TestCase):
type_vocab_size=16,
type_sequence_label_size=2,
initializer_range=0.02,
num_choices=4,
):
self.parent = parent
self.batch_size = batch_size
@@ -68,6 +77,7 @@ class FlaxBertModelTester(unittest.TestCase):
self.type_vocab_size = type_vocab_size
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
self.num_choices = num_choices
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
@@ -107,7 +117,20 @@ class FlaxBertModelTester(unittest.TestCase):
@require_flax
class FlaxBertModelTest(FlaxModelTesterMixin, unittest.TestCase):
all_model_classes = (FlaxBertModel, FlaxBertForMaskedLM) if is_flax_available() else ()
all_model_classes = (
(
FlaxBertModel,
FlaxBertForPreTraining,
FlaxBertForMaskedLM,
FlaxBertForMultipleChoice,
FlaxBertForQuestionAnswering,
FlaxBertForNextSentencePrediction,
FlaxBertForTokenClassification,
FlaxBertForQuestionAnswering,
)
if is_flax_available()
else ()
)
def setUp(self):
self.model_tester = FlaxBertModelTester(self)