[Flax] Add other BERT classes (#10977)
* add first code structures * add all bert models * add to init and docs * correct docs * make style
This commit is contained in:
committed by
GitHub
parent
e031162a6b
commit
e87505f3a1
@@ -23,7 +23,15 @@ from .test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_
|
||||
|
||||
|
||||
if is_flax_available():
|
||||
from transformers.models.bert.modeling_flax_bert import FlaxBertForMaskedLM, FlaxBertModel
|
||||
from transformers.models.bert.modeling_flax_bert import (
|
||||
FlaxBertForMaskedLM,
|
||||
FlaxBertForMultipleChoice,
|
||||
FlaxBertForNextSentencePrediction,
|
||||
FlaxBertForPreTraining,
|
||||
FlaxBertForQuestionAnswering,
|
||||
FlaxBertForTokenClassification,
|
||||
FlaxBertModel,
|
||||
)
|
||||
|
||||
|
||||
class FlaxBertModelTester(unittest.TestCase):
|
||||
@@ -48,6 +56,7 @@ class FlaxBertModelTester(unittest.TestCase):
|
||||
type_vocab_size=16,
|
||||
type_sequence_label_size=2,
|
||||
initializer_range=0.02,
|
||||
num_choices=4,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
@@ -68,6 +77,7 @@ class FlaxBertModelTester(unittest.TestCase):
|
||||
self.type_vocab_size = type_vocab_size
|
||||
self.type_sequence_label_size = type_sequence_label_size
|
||||
self.initializer_range = initializer_range
|
||||
self.num_choices = num_choices
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
@@ -107,7 +117,20 @@ class FlaxBertModelTester(unittest.TestCase):
|
||||
@require_flax
|
||||
class FlaxBertModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (FlaxBertModel, FlaxBertForMaskedLM) if is_flax_available() else ()
|
||||
all_model_classes = (
|
||||
(
|
||||
FlaxBertModel,
|
||||
FlaxBertForPreTraining,
|
||||
FlaxBertForMaskedLM,
|
||||
FlaxBertForMultipleChoice,
|
||||
FlaxBertForQuestionAnswering,
|
||||
FlaxBertForNextSentencePrediction,
|
||||
FlaxBertForTokenClassification,
|
||||
FlaxBertForQuestionAnswering,
|
||||
)
|
||||
if is_flax_available()
|
||||
else ()
|
||||
)
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = FlaxBertModelTester(self)
|
||||
|
||||
Reference in New Issue
Block a user