From 6d023270f6b1f96053e13432db191425e2174a8d Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 26 Oct 2022 18:33:23 +0200 Subject: [PATCH] Allow flax subfolder (#19902) * add first generation tutorial * [Flax] Add subfolder functionality * [Flax] Add subfolder functionality * up * finish * delete file and re-add test --- src/transformers/modeling_flax_utils.py | 24 +++++---- tests/test_modeling_flax_common.py | 65 +++++++++++++++++++++++++ 2 files changed, 79 insertions(+), 10 deletions(-) diff --git a/src/transformers/modeling_flax_utils.py b/src/transformers/modeling_flax_utils.py index 088856cd19..92eda7ee1b 100644 --- a/src/transformers/modeling_flax_utils.py +++ b/src/transformers/modeling_flax_utils.py @@ -636,6 +636,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin): local_files_only=local_files_only, use_auth_token=use_auth_token, revision=revision, + subfolder=subfolder, _from_auto=from_auto_class, _from_pipeline=from_pipeline, _commit_hash=commit_hash, @@ -659,22 +660,24 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin): pretrained_model_name_or_path = str(pretrained_model_name_or_path) is_local = os.path.isdir(pretrained_model_name_or_path) if os.path.isdir(pretrained_model_name_or_path): - if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)): + if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)): # Load from a PyTorch checkpoint - archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME) - elif from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)): + archive_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME) + elif from_pt and os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_INDEX_NAME) + ): # Load from a sharded pytorch checkpoint - archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME) + archive_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_INDEX_NAME) is_sharded = True - elif os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)): + elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)): # Load from a Flax checkpoint - archive_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME) - elif os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_INDEX_NAME)): + archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME) + elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_INDEX_NAME)): # Load from a sharded Flax checkpoint - archive_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_INDEX_NAME) + archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_INDEX_NAME) is_sharded = True # At this stage we don't have a weight file so we will raise an error. - elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)): + elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)): raise EnvironmentError( f"Error no file named {FLAX_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} " "but there is a file for PyTorch weights. Use `from_pt=True` to load this model from those " @@ -685,7 +688,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin): f"Error no file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory " f"{pretrained_model_name_or_path}." ) - elif os.path.isfile(pretrained_model_name_or_path): + elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)): archive_file = pretrained_model_name_or_path is_local = True elif is_remote_url(pretrained_model_name_or_path): @@ -786,6 +789,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin): use_auth_token=use_auth_token, user_agent=user_agent, revision=revision, + subfolder=subfolder, _commit_hash=commit_hash, ) diff --git a/tests/test_modeling_flax_common.py b/tests/test_modeling_flax_common.py index 37171e2138..81ae330746 100644 --- a/tests/test_modeling_flax_common.py +++ b/tests/test_modeling_flax_common.py @@ -1221,3 +1221,68 @@ class FlaxModelPushToHubTester(unittest.TestCase): for key in base_params.keys(): max_diff = (base_params[key] - new_params[key]).sum().item() self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") + + +def check_models_equal(model1, model2): + models_are_equal = True + flat_params_1 = flatten_dict(model1.params) + flat_params_2 = flatten_dict(model2.params) + for key in flat_params_1.keys(): + if np.sum(np.abs(flat_params_1[key] - flat_params_2[key])) > 1e-4: + models_are_equal = False + + return models_are_equal + + +@require_flax +class FlaxModelUtilsTest(unittest.TestCase): + def test_model_from_pretrained_subfolder(self): + config = BertConfig.from_pretrained("hf-internal-testing/tiny-bert-flax-only") + model = FlaxBertModel(config) + + subfolder = "bert" + with tempfile.TemporaryDirectory() as tmp_dir: + model.save_pretrained(os.path.join(tmp_dir, subfolder)) + + with self.assertRaises(OSError): + _ = FlaxBertModel.from_pretrained(tmp_dir) + + model_loaded = FlaxBertModel.from_pretrained(tmp_dir, subfolder=subfolder) + + self.assertTrue(check_models_equal(model, model_loaded)) + + def test_model_from_pretrained_subfolder_sharded(self): + config = BertConfig.from_pretrained("hf-internal-testing/tiny-bert-flax-only") + model = FlaxBertModel(config) + + subfolder = "bert" + with tempfile.TemporaryDirectory() as tmp_dir: + model.save_pretrained(os.path.join(tmp_dir, subfolder), max_shard_size="10KB") + + with self.assertRaises(OSError): + _ = FlaxBertModel.from_pretrained(tmp_dir) + + model_loaded = FlaxBertModel.from_pretrained(tmp_dir, subfolder=subfolder) + + self.assertTrue(check_models_equal(model, model_loaded)) + + def test_model_from_pretrained_hub_subfolder(self): + subfolder = "bert" + model_id = "hf-internal-testing/tiny-random-bert-subfolder" + + with self.assertRaises(OSError): + _ = FlaxBertModel.from_pretrained(model_id) + + model = FlaxBertModel.from_pretrained(model_id, subfolder=subfolder) + + self.assertIsNotNone(model) + + def test_model_from_pretrained_hub_subfolder_sharded(self): + subfolder = "bert" + model_id = "hf-internal-testing/tiny-random-bert-sharded-subfolder" + with self.assertRaises(OSError): + _ = FlaxBertModel.from_pretrained(model_id) + + model = FlaxBertModel.from_pretrained(model_id, subfolder=subfolder) + + self.assertIsNotNone(model)