From 7f20bf0d438eb5a0322251f5dc405faff9dfea18 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 11 Nov 2021 15:34:00 +0000 Subject: [PATCH] Fixing requirements for TF LM models and use correct model mappings (#14372) * Fixing requirements for TF LM models and use correct model mappings * make style --- examples/tensorflow/language-modeling/requirements.txt | 2 ++ examples/tensorflow/language-modeling/run_clm.py | 6 +++--- examples/tensorflow/language-modeling/run_mlm.py | 6 +++--- 3 files changed, 8 insertions(+), 6 deletions(-) create mode 100644 examples/tensorflow/language-modeling/requirements.txt diff --git a/examples/tensorflow/language-modeling/requirements.txt b/examples/tensorflow/language-modeling/requirements.txt new file mode 100644 index 0000000000..c4ae4890d2 --- /dev/null +++ b/examples/tensorflow/language-modeling/requirements.txt @@ -0,0 +1,2 @@ +datasets >= 1.8.0 +sentencepiece != 0.1.92 \ No newline at end of file diff --git a/examples/tensorflow/language-modeling/run_clm.py b/examples/tensorflow/language-modeling/run_clm.py index 05437d37ce..5f1adc5ccf 100755 --- a/examples/tensorflow/language-modeling/run_clm.py +++ b/examples/tensorflow/language-modeling/run_clm.py @@ -43,8 +43,8 @@ import transformers from transformers import ( CONFIG_MAPPING, CONFIG_NAME, - MODEL_FOR_CAUSAL_LM_MAPPING, TF2_WEIGHTS_NAME, + TF_MODEL_FOR_CAUSAL_LM_MAPPING, AutoConfig, AutoTokenizer, HfArgumentParser, @@ -57,8 +57,8 @@ from transformers.utils.versions import require_version logger = logging.getLogger(__name__) -require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") -MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys()) +require_version("datasets>=1.8.0", "To fix: pip install -r examples/tensorflow/language-modeling/requirements.txt") +MODEL_CONFIG_CLASSES = list(TF_MODEL_FOR_CAUSAL_LM_MAPPING.keys()) MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) # endregion diff --git a/examples/tensorflow/language-modeling/run_mlm.py b/examples/tensorflow/language-modeling/run_mlm.py index ebfb165b7e..244a3a9a47 100755 --- a/examples/tensorflow/language-modeling/run_mlm.py +++ b/examples/tensorflow/language-modeling/run_mlm.py @@ -45,8 +45,8 @@ import transformers from transformers import ( CONFIG_MAPPING, CONFIG_NAME, - MODEL_FOR_MASKED_LM_MAPPING, TF2_WEIGHTS_NAME, + TF_MODEL_FOR_MASKED_LM_MAPPING, AutoConfig, AutoTokenizer, HfArgumentParser, @@ -59,8 +59,8 @@ from transformers.utils.versions import require_version logger = logging.getLogger(__name__) -require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") -MODEL_CONFIG_CLASSES = list(MODEL_FOR_MASKED_LM_MAPPING.keys()) +require_version("datasets>=1.8.0", "To fix: pip install -r examples/tensorflow/language-modeling/requirements.txt") +MODEL_CONFIG_CLASSES = list(TF_MODEL_FOR_MASKED_LM_MAPPING.keys()) MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)