Allow flax subfolder (#19902)
* add first generation tutorial * [Flax] Add subfolder functionality * [Flax] Add subfolder functionality * up * finish * delete file and re-add test
This commit is contained in:
committed by
GitHub
parent
7a1c68a845
commit
6d023270f6
@@ -1221,3 +1221,68 @@ class FlaxModelPushToHubTester(unittest.TestCase):
|
||||
for key in base_params.keys():
|
||||
max_diff = (base_params[key] - new_params[key]).sum().item()
|
||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||
|
||||
|
||||
def check_models_equal(model1, model2):
|
||||
models_are_equal = True
|
||||
flat_params_1 = flatten_dict(model1.params)
|
||||
flat_params_2 = flatten_dict(model2.params)
|
||||
for key in flat_params_1.keys():
|
||||
if np.sum(np.abs(flat_params_1[key] - flat_params_2[key])) > 1e-4:
|
||||
models_are_equal = False
|
||||
|
||||
return models_are_equal
|
||||
|
||||
|
||||
@require_flax
|
||||
class FlaxModelUtilsTest(unittest.TestCase):
|
||||
def test_model_from_pretrained_subfolder(self):
|
||||
config = BertConfig.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
|
||||
model = FlaxBertModel(config)
|
||||
|
||||
subfolder = "bert"
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.save_pretrained(os.path.join(tmp_dir, subfolder))
|
||||
|
||||
with self.assertRaises(OSError):
|
||||
_ = FlaxBertModel.from_pretrained(tmp_dir)
|
||||
|
||||
model_loaded = FlaxBertModel.from_pretrained(tmp_dir, subfolder=subfolder)
|
||||
|
||||
self.assertTrue(check_models_equal(model, model_loaded))
|
||||
|
||||
def test_model_from_pretrained_subfolder_sharded(self):
|
||||
config = BertConfig.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
|
||||
model = FlaxBertModel(config)
|
||||
|
||||
subfolder = "bert"
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.save_pretrained(os.path.join(tmp_dir, subfolder), max_shard_size="10KB")
|
||||
|
||||
with self.assertRaises(OSError):
|
||||
_ = FlaxBertModel.from_pretrained(tmp_dir)
|
||||
|
||||
model_loaded = FlaxBertModel.from_pretrained(tmp_dir, subfolder=subfolder)
|
||||
|
||||
self.assertTrue(check_models_equal(model, model_loaded))
|
||||
|
||||
def test_model_from_pretrained_hub_subfolder(self):
|
||||
subfolder = "bert"
|
||||
model_id = "hf-internal-testing/tiny-random-bert-subfolder"
|
||||
|
||||
with self.assertRaises(OSError):
|
||||
_ = FlaxBertModel.from_pretrained(model_id)
|
||||
|
||||
model = FlaxBertModel.from_pretrained(model_id, subfolder=subfolder)
|
||||
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
def test_model_from_pretrained_hub_subfolder_sharded(self):
|
||||
subfolder = "bert"
|
||||
model_id = "hf-internal-testing/tiny-random-bert-sharded-subfolder"
|
||||
with self.assertRaises(OSError):
|
||||
_ = FlaxBertModel.from_pretrained(model_id)
|
||||
|
||||
model = FlaxBertModel.from_pretrained(model_id, subfolder=subfolder)
|
||||
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
Reference in New Issue
Block a user