[Flax] Fix PyTorch import error (#11839)

* fix_torch_device_generate_test

* remove @

* change pytorch import to flax import
This commit is contained in:
Patrick von Platen
2021-05-24 10:41:10 +01:00
committed by GitHub
parent 0cbddfb190
commit f580604157

View File

@@ -42,7 +42,7 @@ from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard
from transformers import (
CONFIG_MAPPING,
MODEL_FOR_MASKED_LM_MAPPING,
FLAX_MODEL_FOR_MASKED_LM_MAPPING,
AutoConfig,
AutoTokenizer,
FlaxAutoModelForMaskedLM,
@@ -71,7 +71,7 @@ else:
)
MODEL_CONFIG_CLASSES = list(MODEL_FOR_MASKED_LM_MAPPING.keys())
MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)