Add ByT5 option to example run_t5_mlm_flax.py (#12634)
* Allow ByT5 type in Flax T5 script * use T5TokenizerFast * change up tokenizer config * model_args * reorder imports * Update run_t5_mlm_flax.py
This commit is contained in:
@@ -42,12 +42,12 @@ from flax.training.common_utils import get_metrics, onehot, shard
|
|||||||
from transformers import (
|
from transformers import (
|
||||||
CONFIG_MAPPING,
|
CONFIG_MAPPING,
|
||||||
FLAX_MODEL_FOR_MASKED_LM_MAPPING,
|
FLAX_MODEL_FOR_MASKED_LM_MAPPING,
|
||||||
|
AutoTokenizer,
|
||||||
BatchEncoding,
|
BatchEncoding,
|
||||||
FlaxT5ForConditionalGeneration,
|
FlaxT5ForConditionalGeneration,
|
||||||
HfArgumentParser,
|
HfArgumentParser,
|
||||||
PreTrainedTokenizerBase,
|
PreTrainedTokenizerBase,
|
||||||
T5Config,
|
T5Config,
|
||||||
T5TokenizerFast,
|
|
||||||
TrainingArguments,
|
TrainingArguments,
|
||||||
is_tensorboard_available,
|
is_tensorboard_available,
|
||||||
set_seed,
|
set_seed,
|
||||||
@@ -477,11 +477,11 @@ if __name__ == "__main__":
|
|||||||
# Load pretrained model and tokenizer
|
# Load pretrained model and tokenizer
|
||||||
|
|
||||||
if model_args.tokenizer_name:
|
if model_args.tokenizer_name:
|
||||||
tokenizer = T5TokenizerFast.from_pretrained(
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
|
model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
|
||||||
)
|
)
|
||||||
elif model_args.model_name_or_path:
|
elif model_args.model_name_or_path:
|
||||||
tokenizer = T5TokenizerFast.from_pretrained(
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
|
model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
|||||||
Reference in New Issue
Block a user