[FlaxRoberta] Add FlaxRobertaModels & adapt run_mlm_flax.py (#11470)

* add flax roberta

* make style

* correct initialiazation

* modify model to save weights

* fix copied from

* fix copied from

* correct some more code

* add more roberta models

* Apply suggestions from code review

* merge from master

* finish

* finish docs

Co-authored-by: Patrick von Platen <patrick@huggingface.co>
This commit is contained in:
Patrick von Platen
2021-05-04 19:57:59 +02:00
committed by GitHub
parent 2ce0fb84cc
commit 084a187da3
11 changed files with 695 additions and 47 deletions

View File

@@ -150,7 +150,7 @@ class FlaxModelTesterMixin:
fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-3)
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
@@ -161,7 +161,7 @@ class FlaxModelTesterMixin:
len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch"
)
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 1e-3)
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 4e-2)
@is_pt_flax_cross_test
def test_equivalence_flax_to_pt(self):
@@ -191,7 +191,7 @@ class FlaxModelTesterMixin:
fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-3)
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
with tempfile.TemporaryDirectory() as tmpdirname:
fx_model.save_pretrained(tmpdirname)
@@ -204,7 +204,7 @@ class FlaxModelTesterMixin:
len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch"
)
for fx_output, pt_output in zip(fx_outputs, pt_outputs_loaded):
self.assert_almost_equals(fx_output, pt_output.numpy(), 5e-3)
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
def test_from_pretrained_save_pretrained(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -219,6 +219,7 @@ class FlaxModelTesterMixin:
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
outputs = model(**prepared_inputs_dict).to_tuple()
# verify that normal save_pretrained works as expected
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_loaded = model_class.from_pretrained(tmpdirname)
@@ -227,6 +228,16 @@ class FlaxModelTesterMixin:
for output_loaded, output in zip(outputs_loaded, outputs):
self.assert_almost_equals(output_loaded, output, 1e-3)
# verify that save_pretrained for distributed training
# with `params=params` works as expected
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname, params=model.params)
model_loaded = model_class.from_pretrained(tmpdirname)
outputs_loaded = model_loaded(**prepared_inputs_dict).to_tuple()
for output_loaded, output in zip(outputs_loaded, outputs):
self.assert_almost_equals(output_loaded, output, 1e-3)
def test_jit_compilation(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

View File

@@ -23,7 +23,14 @@ from .test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_
if is_flax_available():
from transformers.models.roberta.modeling_flax_roberta import FlaxRobertaModel
from transformers.models.roberta.modeling_flax_roberta import (
FlaxRobertaForMaskedLM,
FlaxRobertaForMultipleChoice,
FlaxRobertaForQuestionAnswering,
FlaxRobertaForSequenceClassification,
FlaxRobertaForTokenClassification,
FlaxRobertaModel,
)
class FlaxRobertaModelTester(unittest.TestCase):
@@ -48,6 +55,7 @@ class FlaxRobertaModelTester(unittest.TestCase):
type_vocab_size=16,
type_sequence_label_size=2,
initializer_range=0.02,
num_choices=4,
):
self.parent = parent
self.batch_size = batch_size
@@ -68,6 +76,7 @@ class FlaxRobertaModelTester(unittest.TestCase):
self.type_vocab_size = type_vocab_size
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
self.num_choices = num_choices
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
@@ -107,7 +116,18 @@ class FlaxRobertaModelTester(unittest.TestCase):
@require_flax
class FlaxRobertaModelTest(FlaxModelTesterMixin, unittest.TestCase):
all_model_classes = (FlaxRobertaModel,) if is_flax_available() else ()
all_model_classes = (
(
FlaxRobertaModel,
FlaxRobertaForMaskedLM,
FlaxRobertaForSequenceClassification,
FlaxRobertaForTokenClassification,
FlaxRobertaForMultipleChoice,
FlaxRobertaForQuestionAnswering,
)
if is_flax_available()
else ()
)
def setUp(self):
self.model_tester = FlaxRobertaModelTester(self)