From f58060415761e24a31506bde929edea36b890dd4 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 24 May 2021 10:41:10 +0100 Subject: [PATCH] [Flax] Fix PyTorch import error (#11839) * fix_torch_device_generate_test * remove @ * change pytorch import to flax import --- examples/flax/language-modeling/run_mlm_flax.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/flax/language-modeling/run_mlm_flax.py b/examples/flax/language-modeling/run_mlm_flax.py index 09885524d2..6be1f7ed18 100755 --- a/examples/flax/language-modeling/run_mlm_flax.py +++ b/examples/flax/language-modeling/run_mlm_flax.py @@ -42,7 +42,7 @@ from flax.training import train_state from flax.training.common_utils import get_metrics, onehot, shard from transformers import ( CONFIG_MAPPING, - MODEL_FOR_MASKED_LM_MAPPING, + FLAX_MODEL_FOR_MASKED_LM_MAPPING, AutoConfig, AutoTokenizer, FlaxAutoModelForMaskedLM, @@ -71,7 +71,7 @@ else: ) -MODEL_CONFIG_CLASSES = list(MODEL_FOR_MASKED_LM_MAPPING.keys()) +MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys()) MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)