From 5803a2a7ac87dc9b767454ffc53c17422d2b9f24 Mon Sep 17 00:00:00 2001 From: Nick Doiron Date: Tue, 13 Jul 2021 08:39:57 -0400 Subject: [PATCH] Add ByT5 option to example run_t5_mlm_flax.py (#12634) * Allow ByT5 type in Flax T5 script * use T5TokenizerFast * change up tokenizer config * model_args * reorder imports * Update run_t5_mlm_flax.py --- examples/flax/language-modeling/run_t5_mlm_flax.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/flax/language-modeling/run_t5_mlm_flax.py b/examples/flax/language-modeling/run_t5_mlm_flax.py index 001bea329a..c206d76bec 100755 --- a/examples/flax/language-modeling/run_t5_mlm_flax.py +++ b/examples/flax/language-modeling/run_t5_mlm_flax.py @@ -42,12 +42,12 @@ from flax.training.common_utils import get_metrics, onehot, shard from transformers import ( CONFIG_MAPPING, FLAX_MODEL_FOR_MASKED_LM_MAPPING, + AutoTokenizer, BatchEncoding, FlaxT5ForConditionalGeneration, HfArgumentParser, PreTrainedTokenizerBase, T5Config, - T5TokenizerFast, TrainingArguments, is_tensorboard_available, set_seed, @@ -477,11 +477,11 @@ if __name__ == "__main__": # Load pretrained model and tokenizer if model_args.tokenizer_name: - tokenizer = T5TokenizerFast.from_pretrained( + tokenizer = AutoTokenizer.from_pretrained( model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer ) elif model_args.model_name_or_path: - tokenizer = T5TokenizerFast.from_pretrained( + tokenizer = AutoTokenizer.from_pretrained( model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer ) else: