[Flax] Align FlaxBertForMaskedLM with BertForMaskedLM, implement from_pretrained, init (#9054)

* save intermediate

* save intermediate

* save intermediate

* correct flax bert model file

* new module / model naming

* make style

* almost finish BERT

* finish roberta

* make fix-copies

* delete keys file

* last refactor

* fixes in run_mlm_flax.py

* remove pooled from run_mlm_flax.py`

* fix gelu | gelu_new

* remove Module from inits

* splits

* dirty print

* preventing warmup_steps == 0

* smaller splits

* make fix-copies

* dirty print

* dirty print

* initial_evaluation argument

* declaration order fix

* proper model initialization/loading

* proper initialization

* run_mlm_flax improvements: improper model inputs bugfix + automatic dataset splitting + tokenizers parallelism warning + avoiding warmup_steps=0 bug

* removed tokenizers warning hack, fixed model re-initialization

* reverted training_args.py changes

* fix flax from pretrained

* improve test in flax

* apply sylvains tips

* update init

* make 0.3.0 compatible

* revert tevens changes

* revert tevens changes 2

* finalize revert

* fix bug

* add docs

* add pretrained to init

* Update src/transformers/modeling_flax_utils.py

* fix copies

* final improvements

Co-authored-by: TevenLeScao <teven.lescao@gmail.com>
This commit is contained in:
Patrick von Platen
2020-12-16 13:03:32 +01:00
committed by GitHub
parent 51adb97cd6
commit 640e6fe190
14 changed files with 700 additions and 359 deletions

View File

@@ -385,7 +385,7 @@ def training_step(optimizer, batch, dropout_rng):
# Hide away tokens which doesn't participate in the optimization
token_mask = jnp.where(targets > 0, 1.0, 0.0)
pooled, logits = model(**batch, params=params, dropout_rng=dropout_rng, train=True)
logits = model(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
loss, weight_sum = cross_entropy(logits, targets, token_mask)
return loss / weight_sum
@@ -407,7 +407,7 @@ def eval_step(params, batch):
# Hide away tokens which doesn't participate in the optimization
token_mask = jnp.where(targets > 0, 1.0, 0.0)
_, logits = model(**batch, params=params, train=False)
logits = model(**batch, params=params, train=False)[0]
return compute_metrics(logits, targets, token_mask)
@@ -572,8 +572,13 @@ if __name__ == "__main__":
rng = jax.random.PRNGKey(training_args.seed)
dropout_rngs = jax.random.split(rng, jax.local_device_count())
model = FlaxBertForMaskedLM.from_pretrained("bert-base-cased", dtype=jnp.float32, dropout_rate=0.1)
model.init(jax.random.PRNGKey(training_args.seed), (training_args.train_batch_size, model.config.max_length))
model = FlaxBertForMaskedLM.from_pretrained(
"bert-base-cased",
dtype=jnp.float32,
input_shape=(training_args.train_batch_size, config.max_position_embeddings),
seed=training_args.seed,
dropout_rate=0.1,
)
# Setup optimizer
optimizer = Adam(