[FlaxRoberta] Add FlaxRobertaModels & adapt run_mlm_flax.py (#11470)
* add flax roberta
* make style
* correct initialization
* modify model to save weights
* fix copied from
* fix copied from
* correct some more code
* add more roberta models
* Apply suggestions from code review
* merge from master
* finish
* finish docs

Co-authored-by: Patrick von Platen <patrick@huggingface.co>
This commit is contained in:
committed by
GitHub
parent
2ce0fb84cc
commit
084a187da3
@@ -23,7 +23,14 @@ from .test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_
|
||||
|
||||
|
||||
if is_flax_available():
|
||||
from transformers.models.roberta.modeling_flax_roberta import FlaxRobertaModel
|
||||
from transformers.models.roberta.modeling_flax_roberta import (
|
||||
FlaxRobertaForMaskedLM,
|
||||
FlaxRobertaForMultipleChoice,
|
||||
FlaxRobertaForQuestionAnswering,
|
||||
FlaxRobertaForSequenceClassification,
|
||||
FlaxRobertaForTokenClassification,
|
||||
FlaxRobertaModel,
|
||||
)
|
||||
|
||||
|
||||
class FlaxRobertaModelTester(unittest.TestCase):
|
||||
@@ -48,6 +55,7 @@ class FlaxRobertaModelTester(unittest.TestCase):
|
||||
type_vocab_size=16,
|
||||
type_sequence_label_size=2,
|
||||
initializer_range=0.02,
|
||||
num_choices=4,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
@@ -68,6 +76,7 @@ class FlaxRobertaModelTester(unittest.TestCase):
|
||||
self.type_vocab_size = type_vocab_size
|
||||
self.type_sequence_label_size = type_sequence_label_size
|
||||
self.initializer_range = initializer_range
|
||||
self.num_choices = num_choices
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
@@ -107,7 +116,18 @@ class FlaxRobertaModelTester(unittest.TestCase):
|
||||
@require_flax
|
||||
class FlaxRobertaModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (FlaxRobertaModel,) if is_flax_available() else ()
|
||||
all_model_classes = (
|
||||
(
|
||||
FlaxRobertaModel,
|
||||
FlaxRobertaForMaskedLM,
|
||||
FlaxRobertaForSequenceClassification,
|
||||
FlaxRobertaForTokenClassification,
|
||||
FlaxRobertaForMultipleChoice,
|
||||
FlaxRobertaForQuestionAnswering,
|
||||
)
|
||||
if is_flax_available()
|
||||
else ()
|
||||
)
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = FlaxRobertaModelTester(self)
|
||||
|
||||
Reference in New Issue
Block a user