[Flax] Align FlaxBertForMaskedLM with BertForMaskedLM, implement from_pretrained, init (#9054)
* save intermediate * save intermediate * save intermediate * correct flax bert model file * new module / model naming * make style * almost finish BERT * finish roberta * make fix-copies * delete keys file * last refactor * fixes in run_mlm_flax.py * remove pooled from run_mlm_flax.py` * fix gelu | gelu_new * remove Module from inits * splits * dirty print * preventing warmup_steps == 0 * smaller splits * make fix-copies * dirty print * dirty print * initial_evaluation argument * declaration order fix * proper model initialization/loading * proper initialization * run_mlm_flax improvements: improper model inputs bugfix + automatic dataset splitting + tokenizers parallelism warning + avoiding warmup_steps=0 bug * removed tokenizers warning hack, fixed model re-initialization * reverted training_args.py changes * fix flax from pretrained * improve test in flax * apply sylvains tips * update init * make 0.3.0 compatible * revert tevens changes * revert tevens changes 2 * finalize revert * fix bug * add docs * add pretrained to init * Update src/transformers/modeling_flax_utils.py * fix copies * final improvements Co-authored-by: TevenLeScao <teven.lescao@gmail.com>
This commit is contained in:
committed by
GitHub
parent
51adb97cd6
commit
640e6fe190
@@ -10,7 +10,7 @@ deps = {
|
||||
"fastapi": "fastapi",
|
||||
"filelock": "filelock",
|
||||
"flake8": "flake8>=3.8.3",
|
||||
"flax": "flax==0.2.2",
|
||||
"flax": "flax>=0.2.2",
|
||||
"fugashi": "fugashi>=1.0",
|
||||
"ipadic": "ipadic>=1.0.0,<2.0",
|
||||
"isort": "isort>=5.5.4",
|
||||
@@ -40,8 +40,8 @@ deps = {
|
||||
"sphinx-rtd-theme": "sphinx-rtd-theme==0.4.3",
|
||||
"sphinx": "sphinx==3.2.1",
|
||||
"starlette": "starlette",
|
||||
"tensorflow-cpu": "tensorflow-cpu>=2.0",
|
||||
"tensorflow": "tensorflow>=2.0",
|
||||
"tensorflow-cpu": "tensorflow-cpu>=2.0,<2.4",
|
||||
"tensorflow": "tensorflow>=2.0,<2.4",
|
||||
"timeout-decorator": "timeout-decorator",
|
||||
"tokenizers": "tokenizers==0.9.4",
|
||||
"torch": "torch>=1.0",
|
||||
|
||||
Reference in New Issue
Block a user