Allow flax subfolder (#19902)
* add first generation tutorial
* [Flax] Add subfolder functionality
* [Flax] Add subfolder functionality
* up
* finish
* delete file and re-add test
This commit is contained in:
committed by
GitHub
parent
7a1c68a845
commit
6d023270f6
@@ -636,6 +636,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
_from_auto=from_auto_class,
|
||||
_from_pipeline=from_pipeline,
|
||||
_commit_hash=commit_hash,
|
||||
@@ -659,22 +660,24 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
||||
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
||||
is_local = os.path.isdir(pretrained_model_name_or_path)
|
||||
if os.path.isdir(pretrained_model_name_or_path):
|
||||
if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
|
||||
if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)):
|
||||
# Load from a PyTorch checkpoint
|
||||
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
|
||||
elif from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)):
|
||||
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)
|
||||
elif from_pt and os.path.isfile(
|
||||
os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_INDEX_NAME)
|
||||
):
|
||||
# Load from a sharded pytorch checkpoint
|
||||
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)
|
||||
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_INDEX_NAME)
|
||||
is_sharded = True
|
||||
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)):
|
||||
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)):
|
||||
# Load from a Flax checkpoint
|
||||
archive_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)
|
||||
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_INDEX_NAME)):
|
||||
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)
|
||||
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_INDEX_NAME)):
|
||||
# Load from a sharded Flax checkpoint
|
||||
archive_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_INDEX_NAME)
|
||||
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_INDEX_NAME)
|
||||
is_sharded = True
|
||||
# At this stage we don't have a weight file so we will raise an error.
|
||||
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
|
||||
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)):
|
||||
raise EnvironmentError(
|
||||
f"Error no file named {FLAX_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} "
|
||||
"but there is a file for PyTorch weights. Use `from_pt=True` to load this model from those "
|
||||
@@ -685,7 +688,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
||||
f"Error no file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory "
|
||||
f"{pretrained_model_name_or_path}."
|
||||
)
|
||||
elif os.path.isfile(pretrained_model_name_or_path):
|
||||
elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)):
|
||||
archive_file = pretrained_model_name_or_path
|
||||
is_local = True
|
||||
elif is_remote_url(pretrained_model_name_or_path):
|
||||
@@ -786,6 +789,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
||||
use_auth_token=use_auth_token,
|
||||
user_agent=user_agent,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
_commit_hash=commit_hash,
|
||||
)
|
||||
|
||||
|
||||
@@ -1221,3 +1221,68 @@ class FlaxModelPushToHubTester(unittest.TestCase):
|
||||
for key in base_params.keys():
|
||||
max_diff = (base_params[key] - new_params[key]).sum().item()
|
||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||
|
||||
|
||||
def check_models_equal(model1, model2):
    """Return True when two models hold the same parameters within tolerance.

    Both models' ``params`` are flattened with ``flatten_dict`` and compared
    entry by entry. An entry matches when the summed absolute difference
    between the two values does not exceed ``1e-4``.

    Args:
        model1: Object exposing a ``params`` attribute accepted by ``flatten_dict``.
        model2: Object exposing a ``params`` attribute accepted by ``flatten_dict``.

    Returns:
        bool: ``False`` if the flattened parameter names differ or any entry
        differs by more than the tolerance, ``True`` otherwise.
    """
    flat_params_1 = flatten_dict(model1.params)
    flat_params_2 = flatten_dict(model2.params)

    # Models with different parameter names cannot be equal. Checking this
    # first avoids a KeyError for names missing from model2 and stops extra
    # entries in model2 from going unnoticed.
    if set(flat_params_1) != set(flat_params_2):
        return False

    for key, value in flat_params_1.items():
        if np.sum(np.abs(value - flat_params_2[key])) > 1e-4:
            # One mismatch is enough; no need to scan the remaining entries.
            return False

    return True
|
||||
|
||||
|
||||
@require_flax
class FlaxModelUtilsTest(unittest.TestCase):
    """Exercise the ``subfolder`` argument of ``FlaxBertModel.from_pretrained``."""

    def test_model_from_pretrained_subfolder(self):
        """A checkpoint saved in a nested directory loads only when ``subfolder`` is passed."""
        reference = FlaxBertModel(BertConfig.from_pretrained("hf-internal-testing/tiny-bert-flax-only"))
        nested_dir = "bert"

        with tempfile.TemporaryDirectory() as root:
            reference.save_pretrained(os.path.join(root, nested_dir))

            # Nothing was saved directly under the root, so a plain load is expected to fail.
            self.assertRaises(OSError, FlaxBertModel.from_pretrained, root)

            restored = FlaxBertModel.from_pretrained(root, subfolder=nested_dir)

        self.assertTrue(check_models_equal(reference, restored))

    def test_model_from_pretrained_subfolder_sharded(self):
        """Same round trip as above, with the checkpoint saved using ``max_shard_size="10KB"``."""
        reference = FlaxBertModel(BertConfig.from_pretrained("hf-internal-testing/tiny-bert-flax-only"))
        nested_dir = "bert"

        with tempfile.TemporaryDirectory() as root:
            reference.save_pretrained(os.path.join(root, nested_dir), max_shard_size="10KB")

            # Nothing was saved directly under the root, so a plain load is expected to fail.
            self.assertRaises(OSError, FlaxBertModel.from_pretrained, root)

            restored = FlaxBertModel.from_pretrained(root, subfolder=nested_dir)

        self.assertTrue(check_models_equal(reference, restored))

    def test_model_from_pretrained_hub_subfolder(self):
        """Loading the hub repo must fail without ``subfolder`` and succeed with it."""
        repo_id = "hf-internal-testing/tiny-random-bert-subfolder"
        nested_dir = "bert"

        self.assertRaises(OSError, FlaxBertModel.from_pretrained, repo_id)

        loaded = FlaxBertModel.from_pretrained(repo_id, subfolder=nested_dir)
        self.assertIsNotNone(loaded)

    def test_model_from_pretrained_hub_subfolder_sharded(self):
        """Same hub check as above, against the ``tiny-random-bert-sharded-subfolder`` repo."""
        repo_id = "hf-internal-testing/tiny-random-bert-sharded-subfolder"
        nested_dir = "bert"

        self.assertRaises(OSError, FlaxBertModel.from_pretrained, repo_id)

        loaded = FlaxBertModel.from_pretrained(repo_id, subfolder=nested_dir)
        self.assertIsNotNone(loaded)
|
||||
|
||||
Reference in New Issue
Block a user