[Flax] Align FlaxBertForMaskedLM with BertForMaskedLM, implement from_pretrained, init (#9054)
* save intermediate * save intermediate * save intermediate * correct flax bert model file * new module / model naming * make style * almost finish BERT * finish roberta * make fix-copies * delete keys file * last refactor * fixes in run_mlm_flax.py * remove pooled from run_mlm_flax.py` * fix gelu | gelu_new * remove Module from inits * splits * dirty print * preventing warmup_steps == 0 * smaller splits * make fix-copies * dirty print * dirty print * initial_evaluation argument * declaration order fix * proper model initialization/loading * proper initialization * run_mlm_flax improvements: improper model inputs bugfix + automatic dataset splitting + tokenizers parallelism warning + avoiding warmup_steps=0 bug * removed tokenizers warning hack, fixed model re-initialization * reverted training_args.py changes * fix flax from pretrained * improve test in flax * apply sylvains tips * update init * make 0.3.0 compatible * revert tevens changes * revert tevens changes 2 * finalize revert * fix bug * add docs * add pretrained to init * Update src/transformers/modeling_flax_utils.py * fix copies * final improvements Co-authored-by: TevenLeScao <teven.lescao@gmail.com>
This commit is contained in:
committed by
GitHub
parent
51adb97cd6
commit
640e6fe190
@@ -270,6 +270,7 @@ TRANSFORMERS_CACHE = os.getenv("TRANSFORMERS_CACHE", PYTORCH_TRANSFORMERS_CACHE)
|
||||
WEIGHTS_NAME = "pytorch_model.bin"
|
||||
TF2_WEIGHTS_NAME = "tf_model.h5"
|
||||
TF_WEIGHTS_NAME = "model.ckpt"
|
||||
FLAX_WEIGHTS_NAME = "flax_model.msgpack"
|
||||
CONFIG_NAME = "config.json"
|
||||
MODEL_CARD_NAME = "modelcard.json"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user