Add Flax dummy objects (#7918)
This commit is contained in:
@@ -841,6 +841,11 @@ else:
|
|||||||
if is_flax_available():
|
if is_flax_available():
|
||||||
from .modeling_flax_bert import FlaxBertModel
|
from .modeling_flax_bert import FlaxBertModel
|
||||||
from .modeling_flax_roberta import FlaxRobertaModel
|
from .modeling_flax_roberta import FlaxRobertaModel
|
||||||
|
else:
|
||||||
|
# Import the same objects as dummies to get them in the namespace.
|
||||||
|
# They will raise an import error if the user tries to instantiate / use them.
|
||||||
|
from .utils.dummy_flax_objects import *
|
||||||
|
|
||||||
|
|
||||||
if not is_tf_available() and not is_torch_available():
|
if not is_tf_available() and not is_torch_available():
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
|||||||
@@ -356,6 +356,12 @@ installation page: https://www.tensorflow.org/install and follow the ones that m
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
FLAX_IMPORT_ERROR = """
|
||||||
|
{0} requires the FLAX library but it was not found in your enviromnent. Checkout the instructions on the
|
||||||
|
installation page: https://github.com/google/flax and follow the ones that match your enviromnent.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
def requires_datasets(obj):
|
def requires_datasets(obj):
|
||||||
name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
|
name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
|
||||||
if not is_datasets_available():
|
if not is_datasets_available():
|
||||||
@@ -386,6 +392,12 @@ def requires_tf(obj):
|
|||||||
raise ImportError(TENSORFLOW_IMPORT_ERROR.format(name))
|
raise ImportError(TENSORFLOW_IMPORT_ERROR.format(name))
|
||||||
|
|
||||||
|
|
||||||
|
def requires_flax(obj):
|
||||||
|
name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
|
||||||
|
if not is_flax_available():
|
||||||
|
raise ImportError(FLAX_IMPORT_ERROR.format(name))
|
||||||
|
|
||||||
|
|
||||||
def requires_tokenizers(obj):
|
def requires_tokenizers(obj):
|
||||||
name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
|
name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
|
||||||
if not is_tokenizers_available():
|
if not is_tokenizers_available():
|
||||||
|
|||||||
20
src/transformers/utils/dummy_flax_objects.py
Normal file
20
src/transformers/utils/dummy_flax_objects.py
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
# This file is autogenerated by the command `make fix-copies`, do not edit.
|
||||||
|
from ..file_utils import requires_flax
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxBertModel:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_flax(self)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(self, *args, **kwargs):
|
||||||
|
requires_flax(self)
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxRobertaModel:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_flax(self)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(self, *args, **kwargs):
|
||||||
|
requires_flax(self)
|
||||||
@@ -72,6 +72,28 @@ def {0}(*args, **kwargs):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
DUMMY_FLAX_PRETRAINED_CLASS = """
|
||||||
|
class {0}:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_flax(self)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(self, *args, **kwargs):
|
||||||
|
requires_flax(self)
|
||||||
|
"""
|
||||||
|
|
||||||
|
DUMMY_FLAX_CLASS = """
|
||||||
|
class {0}:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_flax(self)
|
||||||
|
"""
|
||||||
|
|
||||||
|
DUMMY_FLAX_FUNCTION = """
|
||||||
|
def {0}(*args, **kwargs):
|
||||||
|
requires_flax({0})
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
DUMMY_SENTENCEPIECE_PRETRAINED_CLASS = """
|
DUMMY_SENTENCEPIECE_PRETRAINED_CLASS = """
|
||||||
class {0}:
|
class {0}:
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
@@ -120,6 +142,7 @@ def {0}(*args, **kwargs):
|
|||||||
DUMMY_PRETRAINED_CLASS = {
|
DUMMY_PRETRAINED_CLASS = {
|
||||||
"pt": DUMMY_PT_PRETRAINED_CLASS,
|
"pt": DUMMY_PT_PRETRAINED_CLASS,
|
||||||
"tf": DUMMY_TF_PRETRAINED_CLASS,
|
"tf": DUMMY_TF_PRETRAINED_CLASS,
|
||||||
|
"flax": DUMMY_FLAX_PRETRAINED_CLASS,
|
||||||
"sentencepiece": DUMMY_SENTENCEPIECE_PRETRAINED_CLASS,
|
"sentencepiece": DUMMY_SENTENCEPIECE_PRETRAINED_CLASS,
|
||||||
"tokenizers": DUMMY_TOKENIZERS_PRETRAINED_CLASS,
|
"tokenizers": DUMMY_TOKENIZERS_PRETRAINED_CLASS,
|
||||||
}
|
}
|
||||||
@@ -127,6 +150,7 @@ DUMMY_PRETRAINED_CLASS = {
|
|||||||
DUMMY_CLASS = {
|
DUMMY_CLASS = {
|
||||||
"pt": DUMMY_PT_CLASS,
|
"pt": DUMMY_PT_CLASS,
|
||||||
"tf": DUMMY_TF_CLASS,
|
"tf": DUMMY_TF_CLASS,
|
||||||
|
"flax": DUMMY_FLAX_CLASS,
|
||||||
"sentencepiece": DUMMY_SENTENCEPIECE_CLASS,
|
"sentencepiece": DUMMY_SENTENCEPIECE_CLASS,
|
||||||
"tokenizers": DUMMY_TOKENIZERS_CLASS,
|
"tokenizers": DUMMY_TOKENIZERS_CLASS,
|
||||||
}
|
}
|
||||||
@@ -134,6 +158,7 @@ DUMMY_CLASS = {
|
|||||||
DUMMY_FUNCTION = {
|
DUMMY_FUNCTION = {
|
||||||
"pt": DUMMY_PT_FUNCTION,
|
"pt": DUMMY_PT_FUNCTION,
|
||||||
"tf": DUMMY_TF_FUNCTION,
|
"tf": DUMMY_TF_FUNCTION,
|
||||||
|
"flax": DUMMY_FLAX_FUNCTION,
|
||||||
"sentencepiece": DUMMY_SENTENCEPIECE_FUNCTION,
|
"sentencepiece": DUMMY_SENTENCEPIECE_FUNCTION,
|
||||||
"tokenizers": DUMMY_TOKENIZERS_FUNCTION,
|
"tokenizers": DUMMY_TOKENIZERS_FUNCTION,
|
||||||
}
|
}
|
||||||
@@ -208,7 +233,24 @@ def read_init():
|
|||||||
elif line.startswith(" "):
|
elif line.startswith(" "):
|
||||||
tf_objects.append(line[8:-2])
|
tf_objects.append(line[8:-2])
|
||||||
line_index += 1
|
line_index += 1
|
||||||
return sentencepiece_objects, tokenizers_objects, pt_objects, tf_objects
|
|
||||||
|
# Find where the FLAX imports begin
|
||||||
|
flax_objects = []
|
||||||
|
while not lines[line_index].startswith("if is_flax_available():"):
|
||||||
|
line_index += 1
|
||||||
|
line_index += 1
|
||||||
|
|
||||||
|
# Until we unindent, add PyTorch objects to the list
|
||||||
|
while len(lines[line_index]) <= 1 or lines[line_index].startswith(" "):
|
||||||
|
line = lines[line_index]
|
||||||
|
search = _re_single_line_import.search(line)
|
||||||
|
if search is not None:
|
||||||
|
flax_objects += search.groups()[0].split(", ")
|
||||||
|
elif line.startswith(" "):
|
||||||
|
flax_objects.append(line[8:-2])
|
||||||
|
line_index += 1
|
||||||
|
|
||||||
|
return sentencepiece_objects, tokenizers_objects, pt_objects, tf_objects, flax_objects
|
||||||
|
|
||||||
|
|
||||||
def create_dummy_object(name, type="pt"):
|
def create_dummy_object(name, type="pt"):
|
||||||
@@ -224,7 +266,7 @@ def create_dummy_object(name, type="pt"):
|
|||||||
"Model",
|
"Model",
|
||||||
"Tokenizer",
|
"Tokenizer",
|
||||||
]
|
]
|
||||||
assert type in ["pt", "tf", "sentencepiece", "tokenizers"]
|
assert type in ["pt", "tf", "sentencepiece", "tokenizers", "flax"]
|
||||||
if name.isupper():
|
if name.isupper():
|
||||||
return DUMMY_CONSTANT.format(name)
|
return DUMMY_CONSTANT.format(name)
|
||||||
elif name.islower():
|
elif name.islower():
|
||||||
@@ -244,7 +286,7 @@ def create_dummy_object(name, type="pt"):
|
|||||||
|
|
||||||
def create_dummy_files():
|
def create_dummy_files():
|
||||||
""" Create the content of the dummy files. """
|
""" Create the content of the dummy files. """
|
||||||
sentencepiece_objects, tokenizers_objects, pt_objects, tf_objects = read_init()
|
sentencepiece_objects, tokenizers_objects, pt_objects, tf_objects, flax_objects = read_init()
|
||||||
|
|
||||||
sentencepiece_dummies = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n"
|
sentencepiece_dummies = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n"
|
||||||
sentencepiece_dummies += "from ..file_utils import requires_sentencepiece\n\n"
|
sentencepiece_dummies += "from ..file_utils import requires_sentencepiece\n\n"
|
||||||
@@ -262,17 +304,22 @@ def create_dummy_files():
|
|||||||
tf_dummies += "from ..file_utils import requires_tf\n\n"
|
tf_dummies += "from ..file_utils import requires_tf\n\n"
|
||||||
tf_dummies += "\n".join([create_dummy_object(o, type="tf") for o in tf_objects])
|
tf_dummies += "\n".join([create_dummy_object(o, type="tf") for o in tf_objects])
|
||||||
|
|
||||||
return sentencepiece_dummies, tokenizers_dummies, pt_dummies, tf_dummies
|
flax_dummies = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n"
|
||||||
|
flax_dummies += "from ..file_utils import requires_flax\n\n"
|
||||||
|
flax_dummies += "\n".join([create_dummy_object(o, type="flax") for o in flax_objects])
|
||||||
|
|
||||||
|
return sentencepiece_dummies, tokenizers_dummies, pt_dummies, tf_dummies, flax_dummies
|
||||||
|
|
||||||
|
|
||||||
def check_dummies(overwrite=False):
|
def check_dummies(overwrite=False):
|
||||||
""" Check if the dummy files are up to date and maybe `overwrite` with the right content. """
|
""" Check if the dummy files are up to date and maybe `overwrite` with the right content. """
|
||||||
sentencepiece_dummies, tokenizers_dummies, pt_dummies, tf_dummies = create_dummy_files()
|
sentencepiece_dummies, tokenizers_dummies, pt_dummies, tf_dummies, flax_dummies = create_dummy_files()
|
||||||
path = os.path.join(PATH_TO_TRANSFORMERS, "utils")
|
path = os.path.join(PATH_TO_TRANSFORMERS, "utils")
|
||||||
sentencepiece_file = os.path.join(path, "dummy_sentencepiece_objects.py")
|
sentencepiece_file = os.path.join(path, "dummy_sentencepiece_objects.py")
|
||||||
tokenizers_file = os.path.join(path, "dummy_tokenizers_objects.py")
|
tokenizers_file = os.path.join(path, "dummy_tokenizers_objects.py")
|
||||||
pt_file = os.path.join(path, "dummy_pt_objects.py")
|
pt_file = os.path.join(path, "dummy_pt_objects.py")
|
||||||
tf_file = os.path.join(path, "dummy_tf_objects.py")
|
tf_file = os.path.join(path, "dummy_tf_objects.py")
|
||||||
|
flax_file = os.path.join(path, "dummy_flax_objects.py")
|
||||||
|
|
||||||
with open(sentencepiece_file, "r", encoding="utf-8") as f:
|
with open(sentencepiece_file, "r", encoding="utf-8") as f:
|
||||||
actual_sentencepiece_dummies = f.read()
|
actual_sentencepiece_dummies = f.read()
|
||||||
@@ -282,6 +329,8 @@ def check_dummies(overwrite=False):
|
|||||||
actual_pt_dummies = f.read()
|
actual_pt_dummies = f.read()
|
||||||
with open(tf_file, "r", encoding="utf-8") as f:
|
with open(tf_file, "r", encoding="utf-8") as f:
|
||||||
actual_tf_dummies = f.read()
|
actual_tf_dummies = f.read()
|
||||||
|
with open(flax_file, "r", encoding="utf-8") as f:
|
||||||
|
actual_flax_dummies = f.read()
|
||||||
|
|
||||||
if sentencepiece_dummies != actual_sentencepiece_dummies:
|
if sentencepiece_dummies != actual_sentencepiece_dummies:
|
||||||
if overwrite:
|
if overwrite:
|
||||||
@@ -327,6 +376,17 @@ def check_dummies(overwrite=False):
|
|||||||
"Run `make fix-copies` to fix this.",
|
"Run `make fix-copies` to fix this.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if flax_dummies != actual_flax_dummies:
|
||||||
|
if overwrite:
|
||||||
|
print("Updating transformers.utils.dummy_flax_objects.py as the main __init__ has new objects.")
|
||||||
|
with open(flax_file, "w", encoding="utf-8") as f:
|
||||||
|
f.write(flax_dummies)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"The main __init__ has objects that are not present in transformers.utils.dummy_flax_objects.py.",
|
||||||
|
"Run `make fix-copies` to fix this.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|||||||
Reference in New Issue
Block a user