From 30ca5299022f424d8a797293fb55169dd4b1f02c Mon Sep 17 00:00:00 2001 From: Lysandre Debut Date: Mon, 2 May 2022 13:47:47 -0300 Subject: [PATCH] Make the sacremoses dependency optional (#17049) * Make sacremoses optional * Pickle --- setup.py | 3 +- src/transformers/dependency_versions_check.py | 2 +- .../models/fsmt/tokenization_fsmt.py | 36 ++++++++++++++++--- .../models/xlm/tokenization_xlm.py | 34 +++++++++++++++--- 4 files changed, 63 insertions(+), 12 deletions(-) diff --git a/setup.py b/setup.py index c8f959e43e..cc451b190d 100644 --- a/setup.py +++ b/setup.py @@ -288,6 +288,7 @@ extras["testing"] = ( "nltk", "GitPython", "hf-doc-builder", + 'sacremoses' ) + extras["retrieval"] + extras["modelcreation"] @@ -365,7 +366,6 @@ extras["torchhub"] = deps_list( "protobuf", "regex", "requests", - "sacremoses", "sentencepiece", "torch", "tokenizers", @@ -383,7 +383,6 @@ install_requires = [ deps["pyyaml"], # used for the model cards metadata deps["regex"], # for OpenAI GPT deps["requests"], # for downloading models over HTTPS - deps["sacremoses"], # for XLM deps["tokenizers"], deps["tqdm"], # progress bars in model download and training scripts ] diff --git a/src/transformers/dependency_versions_check.py b/src/transformers/dependency_versions_check.py index bd3b65f724..bbf863222a 100644 --- a/src/transformers/dependency_versions_check.py +++ b/src/transformers/dependency_versions_check.py @@ -23,7 +23,7 @@ from .utils.versions import require_version, require_version_core # order specific notes: # - tqdm must be checked before tokenizers -pkgs_to_check_at_runtime = "python tqdm regex sacremoses requests packaging filelock numpy tokenizers".split() +pkgs_to_check_at_runtime = "python tqdm regex requests packaging filelock numpy tokenizers".split() if sys.version_info < (3, 7): pkgs_to_check_at_runtime.append("dataclasses") if sys.version_info < (3, 8): diff --git a/src/transformers/models/fsmt/tokenization_fsmt.py b/src/transformers/models/fsmt/tokenization_fsmt.py index 2d11361439..ea5131f2b4 100644 --- a/src/transformers/models/fsmt/tokenization_fsmt.py +++ b/src/transformers/models/fsmt/tokenization_fsmt.py @@ -21,8 +21,6 @@ import re import unicodedata from typing import Dict, List, Optional, Tuple -import sacremoses as sm - from ...tokenization_utils import PreTrainedTokenizer from ...utils import logging @@ -212,6 +210,16 @@ class FSMTTokenizer(PreTrainedTokenizer): **kwargs, ) + try: + import sacremoses + except ImportError: + raise ImportError( + "You need to install sacremoses to use XLMTokenizer. " + "See https://pypi.org/project/sacremoses/ for installation." + ) + + self.sm = sacremoses + self.src_vocab_file = src_vocab_file self.tgt_vocab_file = tgt_vocab_file self.merges_file = merges_file @@ -254,13 +262,13 @@ class FSMTTokenizer(PreTrainedTokenizer): def moses_punct_norm(self, text, lang): if lang not in self.cache_moses_punct_normalizer: - punct_normalizer = sm.MosesPunctNormalizer(lang=lang) + punct_normalizer = self.sm.MosesPunctNormalizer(lang=lang) self.cache_moses_punct_normalizer[lang] = punct_normalizer return self.cache_moses_punct_normalizer[lang].normalize(text) def moses_tokenize(self, text, lang): if lang not in self.cache_moses_tokenizer: - moses_tokenizer = sm.MosesTokenizer(lang=lang) + moses_tokenizer = self.sm.MosesTokenizer(lang=lang) self.cache_moses_tokenizer[lang] = moses_tokenizer return self.cache_moses_tokenizer[lang].tokenize( text, aggressive_dash_splits=True, return_str=False, escape=True @@ -268,7 +276,7 @@ class FSMTTokenizer(PreTrainedTokenizer): def moses_detokenize(self, tokens, lang): if lang not in self.cache_moses_tokenizer: - moses_detokenizer = sm.MosesDetokenizer(lang=self.tgt_lang) + moses_detokenizer = self.sm.MosesDetokenizer(lang=self.tgt_lang) self.cache_moses_detokenizer[lang] = moses_detokenizer return self.cache_moses_detokenizer[lang].detokenize(tokens) @@ -516,3 +524,21 @@ class FSMTTokenizer(PreTrainedTokenizer): index += 1 return src_vocab_file, tgt_vocab_file, merges_file + + def __getstate__(self): + state = self.__dict__.copy() + state["sm"] = None + return state + + def __setstate__(self, d): + self.__dict__ = d + + try: + import sacremoses + except ImportError: + raise ImportError( + "You need to install sacremoses to use XLMTokenizer. " + "See https://pypi.org/project/sacremoses/ for installation." + ) + + self.sm = sacremoses diff --git a/src/transformers/models/xlm/tokenization_xlm.py b/src/transformers/models/xlm/tokenization_xlm.py index 7519a514c9..f6c94f11ae 100644 --- a/src/transformers/models/xlm/tokenization_xlm.py +++ b/src/transformers/models/xlm/tokenization_xlm.py @@ -22,8 +22,6 @@ import sys import unicodedata from typing import List, Optional, Tuple -import sacremoses as sm - from ...tokenization_utils import PreTrainedTokenizer from ...utils import logging @@ -629,6 +627,16 @@ class XLMTokenizer(PreTrainedTokenizer): **kwargs, ) + try: + import sacremoses + except ImportError: + raise ImportError( + "You need to install sacremoses to use XLMTokenizer. " + "See https://pypi.org/project/sacremoses/ for installation." + ) + + self.sm = sacremoses + # cache of sm.MosesPunctNormalizer instance self.cache_moses_punct_normalizer = dict() # cache of sm.MosesTokenizer instance @@ -659,7 +667,7 @@ class XLMTokenizer(PreTrainedTokenizer): def moses_punct_norm(self, text, lang): if lang not in self.cache_moses_punct_normalizer: - punct_normalizer = sm.MosesPunctNormalizer(lang=lang) + punct_normalizer = self.sm.MosesPunctNormalizer(lang=lang) self.cache_moses_punct_normalizer[lang] = punct_normalizer else: punct_normalizer = self.cache_moses_punct_normalizer[lang] @@ -667,7 +675,7 @@ class XLMTokenizer(PreTrainedTokenizer): def moses_tokenize(self, text, lang): if lang not in self.cache_moses_tokenizer: - moses_tokenizer = sm.MosesTokenizer(lang=lang) + moses_tokenizer = self.sm.MosesTokenizer(lang=lang) self.cache_moses_tokenizer[lang] = moses_tokenizer else: moses_tokenizer = self.cache_moses_tokenizer[lang] @@ -970,3 +978,21 @@ class XLMTokenizer(PreTrainedTokenizer): index += 1 return vocab_file, merge_file + + def __getstate__(self): + state = self.__dict__.copy() + state["sm"] = None + return state + + def __setstate__(self, d): + self.__dict__ = d + + try: + import sacremoses + except ImportError: + raise ImportError( + "You need to install sacremoses to use XLMTokenizer. " + "See https://pypi.org/project/sacremoses/ for installation." + ) + + self.sm = sacremoses