Fixing requirements for TF LM models and use correct model mappings (#14372)
* Fixing requirements for TF LM models and use correct model mappings * make style
This commit is contained in:
2
examples/tensorflow/language-modeling/requirements.txt
Normal file
2
examples/tensorflow/language-modeling/requirements.txt
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
datasets >= 1.8.0
|
||||||
|
sentencepiece != 0.1.92
|
||||||
@@ -43,8 +43,8 @@ import transformers
|
|||||||
from transformers import (
|
from transformers import (
|
||||||
CONFIG_MAPPING,
|
CONFIG_MAPPING,
|
||||||
CONFIG_NAME,
|
CONFIG_NAME,
|
||||||
MODEL_FOR_CAUSAL_LM_MAPPING,
|
|
||||||
TF2_WEIGHTS_NAME,
|
TF2_WEIGHTS_NAME,
|
||||||
|
TF_MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||||
AutoConfig,
|
AutoConfig,
|
||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
HfArgumentParser,
|
HfArgumentParser,
|
||||||
@@ -57,8 +57,8 @@ from transformers.utils.versions import require_version
|
|||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
|
require_version("datasets>=1.8.0", "To fix: pip install -r examples/tensorflow/language-modeling/requirements.txt")
|
||||||
MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys())
|
MODEL_CONFIG_CLASSES = list(TF_MODEL_FOR_CAUSAL_LM_MAPPING.keys())
|
||||||
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
|
|||||||
@@ -45,8 +45,8 @@ import transformers
|
|||||||
from transformers import (
|
from transformers import (
|
||||||
CONFIG_MAPPING,
|
CONFIG_MAPPING,
|
||||||
CONFIG_NAME,
|
CONFIG_NAME,
|
||||||
MODEL_FOR_MASKED_LM_MAPPING,
|
|
||||||
TF2_WEIGHTS_NAME,
|
TF2_WEIGHTS_NAME,
|
||||||
|
TF_MODEL_FOR_MASKED_LM_MAPPING,
|
||||||
AutoConfig,
|
AutoConfig,
|
||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
HfArgumentParser,
|
HfArgumentParser,
|
||||||
@@ -59,8 +59,8 @@ from transformers.utils.versions import require_version
|
|||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
|
require_version("datasets>=1.8.0", "To fix: pip install -r examples/tensorflow/language-modeling/requirements.txt")
|
||||||
MODEL_CONFIG_CLASSES = list(MODEL_FOR_MASKED_LM_MAPPING.keys())
|
MODEL_CONFIG_CLASSES = list(TF_MODEL_FOR_MASKED_LM_MAPPING.keys())
|
||||||
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user