[Flax] Align FlaxBertForMaskedLM with BertForMaskedLM, implement from_pretrained, init (#9054)
* save intermediate * save intermediate * save intermediate * correct flax bert model file * new module / model naming * make style * almost finish BERT * finish roberta * make fix-copies * delete keys file * last refactor * fixes in run_mlm_flax.py * remove pooled from run_mlm_flax.py` * fix gelu | gelu_new * remove Module from inits * splits * dirty print * preventing warmup_steps == 0 * smaller splits * make fix-copies * dirty print * dirty print * initial_evaluation argument * declaration order fix * proper model initialization/loading * proper initialization * run_mlm_flax improvements: improper model inputs bugfix + automatic dataset splitting + tokenizers parallelism warning + avoiding warmup_steps=0 bug * removed tokenizers warning hack, fixed model re-initialization * reverted training_args.py changes * fix flax from pretrained * improve test in flax * apply sylvains tips * update init * make 0.3.0 compatible * revert tevens changes * revert tevens changes 2 * finalize revert * fix bug * add docs * add pretrained to init * Update src/transformers/modeling_flax_utils.py * fix copies * final improvements Co-authored-by: TevenLeScao <teven.lescao@gmail.com>
This commit is contained in:
committed by
GitHub
parent
51adb97cd6
commit
640e6fe190
@@ -14,14 +14,16 @@
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers import BertConfig, is_flax_available
|
||||
from transformers.testing_utils import require_flax
|
||||
from transformers.testing_utils import require_flax, slow
|
||||
|
||||
from .test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask
|
||||
|
||||
|
||||
if is_flax_available():
|
||||
from transformers.models.bert.modeling_flax_bert import FlaxBertModel
|
||||
from transformers.models.bert.modeling_flax_bert import FlaxBertForMaskedLM, FlaxBertModel
|
||||
|
||||
|
||||
class FlaxBertModelTester(unittest.TestCase):
|
||||
@@ -105,7 +107,14 @@ class FlaxBertModelTester(unittest.TestCase):
|
||||
@require_flax
|
||||
class FlaxBertModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (FlaxBertModel,) if is_flax_available() else ()
|
||||
all_model_classes = (FlaxBertModel, FlaxBertForMaskedLM) if is_flax_available() else ()
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = FlaxBertModelTester(self)
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_class_name in self.all_model_classes:
|
||||
model = model_class_name.from_pretrained("bert-base-cased")
|
||||
outputs = model(np.ones((1, 1)))
|
||||
self.assertIsNotNone(outputs)
|
||||
|
||||
Reference in New Issue
Block a user