[Flax] Fix PyTorch import error (#11839)
* fix_torch_device_generate_test * remove @ * change pytorch import to flax import
This commit is contained in:
committed by
GitHub
parent
0cbddfb190
commit
f580604157
@@ -42,7 +42,7 @@ from flax.training import train_state
|
|||||||
from flax.training.common_utils import get_metrics, onehot, shard
|
from flax.training.common_utils import get_metrics, onehot, shard
|
||||||
from transformers import (
|
from transformers import (
|
||||||
CONFIG_MAPPING,
|
CONFIG_MAPPING,
|
||||||
MODEL_FOR_MASKED_LM_MAPPING,
|
FLAX_MODEL_FOR_MASKED_LM_MAPPING,
|
||||||
AutoConfig,
|
AutoConfig,
|
||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
FlaxAutoModelForMaskedLM,
|
FlaxAutoModelForMaskedLM,
|
||||||
@@ -71,7 +71,7 @@ else:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
MODEL_CONFIG_CLASSES = list(MODEL_FOR_MASKED_LM_MAPPING.keys())
|
MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
|
||||||
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user