Allow flax subfolder (#19902)
* add first generation tutorial * [Flax] Add subfolder functionality * [Flax] Add subfolder functionality * up * finish * delete file and re-add test
This commit is contained in:
committed by
GitHub
parent
7a1c68a845
commit
6d023270f6
@@ -636,6 +636,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
|||||||
local_files_only=local_files_only,
|
local_files_only=local_files_only,
|
||||||
use_auth_token=use_auth_token,
|
use_auth_token=use_auth_token,
|
||||||
revision=revision,
|
revision=revision,
|
||||||
|
subfolder=subfolder,
|
||||||
_from_auto=from_auto_class,
|
_from_auto=from_auto_class,
|
||||||
_from_pipeline=from_pipeline,
|
_from_pipeline=from_pipeline,
|
||||||
_commit_hash=commit_hash,
|
_commit_hash=commit_hash,
|
||||||
@@ -659,22 +660,24 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
|||||||
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
||||||
is_local = os.path.isdir(pretrained_model_name_or_path)
|
is_local = os.path.isdir(pretrained_model_name_or_path)
|
||||||
if os.path.isdir(pretrained_model_name_or_path):
|
if os.path.isdir(pretrained_model_name_or_path):
|
||||||
if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
|
if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)):
|
||||||
# Load from a PyTorch checkpoint
|
# Load from a PyTorch checkpoint
|
||||||
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
|
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)
|
||||||
elif from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)):
|
elif from_pt and os.path.isfile(
|
||||||
|
os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_INDEX_NAME)
|
||||||
|
):
|
||||||
# Load from a sharded pytorch checkpoint
|
# Load from a sharded pytorch checkpoint
|
||||||
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)
|
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_INDEX_NAME)
|
||||||
is_sharded = True
|
is_sharded = True
|
||||||
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)):
|
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)):
|
||||||
# Load from a Flax checkpoint
|
# Load from a Flax checkpoint
|
||||||
archive_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)
|
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)
|
||||||
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_INDEX_NAME)):
|
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_INDEX_NAME)):
|
||||||
# Load from a sharded Flax checkpoint
|
# Load from a sharded Flax checkpoint
|
||||||
archive_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_INDEX_NAME)
|
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_INDEX_NAME)
|
||||||
is_sharded = True
|
is_sharded = True
|
||||||
# At this stage we don't have a weight file so we will raise an error.
|
# At this stage we don't have a weight file so we will raise an error.
|
||||||
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
|
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)):
|
||||||
raise EnvironmentError(
|
raise EnvironmentError(
|
||||||
f"Error no file named {FLAX_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} "
|
f"Error no file named {FLAX_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} "
|
||||||
"but there is a file for PyTorch weights. Use `from_pt=True` to load this model from those "
|
"but there is a file for PyTorch weights. Use `from_pt=True` to load this model from those "
|
||||||
@@ -685,7 +688,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
|||||||
f"Error no file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory "
|
f"Error no file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory "
|
||||||
f"{pretrained_model_name_or_path}."
|
f"{pretrained_model_name_or_path}."
|
||||||
)
|
)
|
||||||
elif os.path.isfile(pretrained_model_name_or_path):
|
elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)):
|
||||||
archive_file = pretrained_model_name_or_path
|
archive_file = pretrained_model_name_or_path
|
||||||
is_local = True
|
is_local = True
|
||||||
elif is_remote_url(pretrained_model_name_or_path):
|
elif is_remote_url(pretrained_model_name_or_path):
|
||||||
@@ -786,6 +789,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
|||||||
use_auth_token=use_auth_token,
|
use_auth_token=use_auth_token,
|
||||||
user_agent=user_agent,
|
user_agent=user_agent,
|
||||||
revision=revision,
|
revision=revision,
|
||||||
|
subfolder=subfolder,
|
||||||
_commit_hash=commit_hash,
|
_commit_hash=commit_hash,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1221,3 +1221,68 @@ class FlaxModelPushToHubTester(unittest.TestCase):
|
|||||||
for key in base_params.keys():
|
for key in base_params.keys():
|
||||||
max_diff = (base_params[key] - new_params[key]).sum().item()
|
max_diff = (base_params[key] - new_params[key]).sum().item()
|
||||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||||
|
|
||||||
|
|
||||||
|
def check_models_equal(model1, model2):
|
||||||
|
models_are_equal = True
|
||||||
|
flat_params_1 = flatten_dict(model1.params)
|
||||||
|
flat_params_2 = flatten_dict(model2.params)
|
||||||
|
for key in flat_params_1.keys():
|
||||||
|
if np.sum(np.abs(flat_params_1[key] - flat_params_2[key])) > 1e-4:
|
||||||
|
models_are_equal = False
|
||||||
|
|
||||||
|
return models_are_equal
|
||||||
|
|
||||||
|
|
||||||
|
@require_flax
|
||||||
|
class FlaxModelUtilsTest(unittest.TestCase):
|
||||||
|
def test_model_from_pretrained_subfolder(self):
|
||||||
|
config = BertConfig.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
|
||||||
|
model = FlaxBertModel(config)
|
||||||
|
|
||||||
|
subfolder = "bert"
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
model.save_pretrained(os.path.join(tmp_dir, subfolder))
|
||||||
|
|
||||||
|
with self.assertRaises(OSError):
|
||||||
|
_ = FlaxBertModel.from_pretrained(tmp_dir)
|
||||||
|
|
||||||
|
model_loaded = FlaxBertModel.from_pretrained(tmp_dir, subfolder=subfolder)
|
||||||
|
|
||||||
|
self.assertTrue(check_models_equal(model, model_loaded))
|
||||||
|
|
||||||
|
def test_model_from_pretrained_subfolder_sharded(self):
|
||||||
|
config = BertConfig.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
|
||||||
|
model = FlaxBertModel(config)
|
||||||
|
|
||||||
|
subfolder = "bert"
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
model.save_pretrained(os.path.join(tmp_dir, subfolder), max_shard_size="10KB")
|
||||||
|
|
||||||
|
with self.assertRaises(OSError):
|
||||||
|
_ = FlaxBertModel.from_pretrained(tmp_dir)
|
||||||
|
|
||||||
|
model_loaded = FlaxBertModel.from_pretrained(tmp_dir, subfolder=subfolder)
|
||||||
|
|
||||||
|
self.assertTrue(check_models_equal(model, model_loaded))
|
||||||
|
|
||||||
|
def test_model_from_pretrained_hub_subfolder(self):
|
||||||
|
subfolder = "bert"
|
||||||
|
model_id = "hf-internal-testing/tiny-random-bert-subfolder"
|
||||||
|
|
||||||
|
with self.assertRaises(OSError):
|
||||||
|
_ = FlaxBertModel.from_pretrained(model_id)
|
||||||
|
|
||||||
|
model = FlaxBertModel.from_pretrained(model_id, subfolder=subfolder)
|
||||||
|
|
||||||
|
self.assertIsNotNone(model)
|
||||||
|
|
||||||
|
def test_model_from_pretrained_hub_subfolder_sharded(self):
|
||||||
|
subfolder = "bert"
|
||||||
|
model_id = "hf-internal-testing/tiny-random-bert-sharded-subfolder"
|
||||||
|
with self.assertRaises(OSError):
|
||||||
|
_ = FlaxBertModel.from_pretrained(model_id)
|
||||||
|
|
||||||
|
model = FlaxBertModel.from_pretrained(model_id, subfolder=subfolder)
|
||||||
|
|
||||||
|
self.assertIsNotNone(model)
|
||||||
|
|||||||
Reference in New Issue
Block a user