From 7d321b7689c6ddc4d33ff93693a9e6ec2c65c51e Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 7 Jul 2021 14:43:43 +0100 Subject: [PATCH] [Flax] Allow retraining from save checkpoint (#12559) * fix_torch_device_generate_test * remove @ * finish --- examples/flax/language-modeling/run_mlm_flax.py | 9 ++++++++- examples/flax/language-modeling/run_t5_mlm_flax.py | 7 ++++++- .../dataset-streaming/run_mlm_flax_stream.py | 9 ++++++++- 3 files changed, 22 insertions(+), 3 deletions(-) diff --git a/examples/flax/language-modeling/run_mlm_flax.py b/examples/flax/language-modeling/run_mlm_flax.py index 1ae9f05230..34fb948000 100755 --- a/examples/flax/language-modeling/run_mlm_flax.py +++ b/examples/flax/language-modeling/run_mlm_flax.py @@ -478,7 +478,14 @@ if __name__ == "__main__": rng = jax.random.PRNGKey(training_args.seed) dropout_rngs = jax.random.split(rng, jax.local_device_count()) - model = FlaxAutoModelForMaskedLM.from_config(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)) + if model_args.model_name_or_path: + model = FlaxAutoModelForMaskedLM.from_pretrained( + model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype) + ) + else: + model = FlaxAutoModelForMaskedLM.from_config( + config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype) + ) # Store some constant num_epochs = int(training_args.num_train_epochs) diff --git a/examples/flax/language-modeling/run_t5_mlm_flax.py b/examples/flax/language-modeling/run_t5_mlm_flax.py index 09d6753a48..c8b6fa3da6 100755 --- a/examples/flax/language-modeling/run_t5_mlm_flax.py +++ b/examples/flax/language-modeling/run_t5_mlm_flax.py @@ -588,7 +588,12 @@ if __name__ == "__main__": rng = jax.random.PRNGKey(training_args.seed) dropout_rngs = jax.random.split(rng, jax.local_device_count()) - model = FlaxT5ForConditionalGeneration(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)) + if model_args.model_name_or_path: + model = FlaxT5ForConditionalGeneration.from_pretrained( + model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype) + ) + else: + model = FlaxT5ForConditionalGeneration(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)) # Data collator # This one will take care of randomly masking the tokens. diff --git a/examples/research_projects/jax-projects/dataset-streaming/run_mlm_flax_stream.py b/examples/research_projects/jax-projects/dataset-streaming/run_mlm_flax_stream.py index 48de9380df..1e2c29947e 100755 --- a/examples/research_projects/jax-projects/dataset-streaming/run_mlm_flax_stream.py +++ b/examples/research_projects/jax-projects/dataset-streaming/run_mlm_flax_stream.py @@ -427,7 +427,14 @@ if __name__ == "__main__": rng = jax.random.PRNGKey(training_args.seed) dropout_rngs = jax.random.split(rng, jax.local_device_count()) - model = FlaxAutoModelForMaskedLM.from_config(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)) + if model_args.model_name_or_path: + model = FlaxAutoModelForMaskedLM.from_pretrained( + model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype) + ) + else: + model = FlaxAutoModelForMaskedLM.from_config( + config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype) + ) # Store some constant num_epochs = int(training_args.num_train_epochs)