[Flax] Align FlaxBertForMaskedLM with BertForMaskedLM, implement from_pretrained, init (#9054)

* save intermediate

* save intermediate

* save intermediate

* correct flax bert model file

* new module / model naming

* make style

* almost finish BERT

* finish roberta

* make fix-copies

* delete keys file

* last refactor

* fixes in run_mlm_flax.py

* remove pooled from run_mlm_flax.py

* fix gelu | gelu_new

* remove Module from inits

* splits

* dirty print

* preventing warmup_steps == 0

* smaller splits

* make fix-copies

* dirty print

* dirty print

* initial_evaluation argument

* declaration order fix

* proper model initialization/loading

* proper initialization

* run_mlm_flax improvements: improper model inputs bugfix + automatic dataset splitting + tokenizers parallelism warning + avoiding warmup_steps=0 bug

* removed tokenizers warning hack, fixed model re-initialization

* reverted training_args.py changes

* fix flax from pretrained

* improve test in flax

* apply Sylvain's tips

* update init

* make 0.3.0 compatible

* revert Teven's changes

* revert Teven's changes 2

* finalize revert

* fix bug

* add docs

* add pretrained to init

* Update src/transformers/modeling_flax_utils.py

* fix copies

* final improvements

Co-authored-by: TevenLeScao <teven.lescao@gmail.com>
This commit is contained in:
Patrick von Platen
2020-12-16 13:03:32 +01:00
committed by GitHub
parent 51adb97cd6
commit 640e6fe190
14 changed files with 700 additions and 359 deletions

View File

@@ -14,14 +14,16 @@
import unittest
import numpy as np
from transformers import BertConfig, is_flax_available
from transformers.testing_utils import require_flax
from transformers.testing_utils import require_flax, slow
from .test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask
if is_flax_available():
from transformers.models.bert.modeling_flax_bert import FlaxBertModel
from transformers.models.bert.modeling_flax_bert import FlaxBertForMaskedLM, FlaxBertModel
class FlaxBertModelTester(unittest.TestCase):
@@ -105,7 +107,14 @@ class FlaxBertModelTester(unittest.TestCase):
@require_flax
class FlaxBertModelTest(FlaxModelTesterMixin, unittest.TestCase):
all_model_classes = (FlaxBertModel,) if is_flax_available() else ()
all_model_classes = (FlaxBertModel, FlaxBertForMaskedLM) if is_flax_available() else ()
def setUp(self):
    """Build the model tester fixture and store it on ``self.model_tester``."""
    tester = FlaxBertModelTester(self)
    self.model_tester = tester
@slow
def test_model_from_pretrained(self):
    """Load "bert-base-cased" into every model class and run one forward call.

    For each class in ``self.all_model_classes`` the pretrained checkpoint is
    loaded via ``from_pretrained`` and the model is called on a (1, 1) array
    of ones; the test only checks that the call returns something non-None.
    """
    for model_class in self.all_model_classes:
        pretrained = model_class.from_pretrained("bert-base-cased")
        # Minimal dummy input: a single batch row holding a single value.
        dummy_input = np.ones((1, 1))
        result = pretrained(dummy_input)
        self.assertIsNotNone(result)