[Flax] Align FlaxBertForMaskedLM with BertForMaskedLM, implement from_pretrained, init (#9054)
* save intermediate * save intermediate * save intermediate * correct flax bert model file * new module / model naming * make style * almost finish BERT * finish roberta * make fix-copies * delete keys file * last refactor * fixes in run_mlm_flax.py * remove pooled from run_mlm_flax.py` * fix gelu | gelu_new * remove Module from inits * splits * dirty print * preventing warmup_steps == 0 * smaller splits * make fix-copies * dirty print * dirty print * initial_evaluation argument * declaration order fix * proper model initialization/loading * proper initialization * run_mlm_flax improvements: improper model inputs bugfix + automatic dataset splitting + tokenizers parallelism warning + avoiding warmup_steps=0 bug * removed tokenizers warning hack, fixed model re-initialization * reverted training_args.py changes * fix flax from pretrained * improve test in flax * apply sylvains tips * update init * make 0.3.0 compatible * revert tevens changes * revert tevens changes 2 * finalize revert * fix bug * add docs * add pretrained to init * Update src/transformers/modeling_flax_utils.py * fix copies * final improvements Co-authored-by: TevenLeScao <teven.lescao@gmail.com>
This commit is contained in:
committed by
GitHub
parent
51adb97cd6
commit
640e6fe190
@@ -385,7 +385,7 @@ def training_step(optimizer, batch, dropout_rng):
|
||||
# Hide away tokens which doesn't participate in the optimization
|
||||
token_mask = jnp.where(targets > 0, 1.0, 0.0)
|
||||
|
||||
pooled, logits = model(**batch, params=params, dropout_rng=dropout_rng, train=True)
|
||||
logits = model(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
|
||||
loss, weight_sum = cross_entropy(logits, targets, token_mask)
|
||||
return loss / weight_sum
|
||||
|
||||
@@ -407,7 +407,7 @@ def eval_step(params, batch):
|
||||
|
||||
# Hide away tokens which doesn't participate in the optimization
|
||||
token_mask = jnp.where(targets > 0, 1.0, 0.0)
|
||||
_, logits = model(**batch, params=params, train=False)
|
||||
logits = model(**batch, params=params, train=False)[0]
|
||||
|
||||
return compute_metrics(logits, targets, token_mask)
|
||||
|
||||
@@ -572,8 +572,13 @@ if __name__ == "__main__":
|
||||
rng = jax.random.PRNGKey(training_args.seed)
|
||||
dropout_rngs = jax.random.split(rng, jax.local_device_count())
|
||||
|
||||
model = FlaxBertForMaskedLM.from_pretrained("bert-base-cased", dtype=jnp.float32, dropout_rate=0.1)
|
||||
model.init(jax.random.PRNGKey(training_args.seed), (training_args.train_batch_size, model.config.max_length))
|
||||
model = FlaxBertForMaskedLM.from_pretrained(
|
||||
"bert-base-cased",
|
||||
dtype=jnp.float32,
|
||||
input_shape=(training_args.train_batch_size, config.max_position_embeddings),
|
||||
seed=training_args.seed,
|
||||
dropout_rate=0.1,
|
||||
)
|
||||
|
||||
# Setup optimizer
|
||||
optimizer = Adam(
|
||||
|
||||
Reference in New Issue
Block a user