[From pretrained] Allow download from subfolder inside model repo (#18184)
* add first generation tutorial * [from_pretrained] Allow loading models from subfolders * remove gen file * add doc strings * allow download from subfolder * add tests * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * apply comments * correct doc string Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
ce0152819d
commit
3bb6356d4d
@@ -157,6 +157,17 @@ class ConfigTester(object):
|
||||
|
||||
self.parent.assertEqual(config_second.to_dict(), config_first.to_dict())
|
||||
|
||||
def create_and_test_config_from_and_save_pretrained_subfolder(self):
    """Round-trip a config through a nested directory using the `subfolder` argument.

    The config is written to `<tmp>/test` and read back by pointing
    `from_pretrained` at `<tmp>` together with `subfolder="test"`. Both configs
    must serialize to the same dictionary.
    """
    subfolder = "test"
    original_config = self.config_class(**self.inputs_dict)

    with tempfile.TemporaryDirectory() as root_dir:
        # Save one level below the directory that is later handed to `from_pretrained`.
        original_config.save_pretrained(os.path.join(root_dir, subfolder))
        reloaded_config = self.config_class.from_pretrained(root_dir, subfolder=subfolder)

    self.parent.assertEqual(reloaded_config.to_dict(), original_config.to_dict())
|
||||
|
||||
def create_and_test_config_with_num_labels(self):
|
||||
config = self.config_class(**self.inputs_dict, num_labels=5)
|
||||
self.parent.assertEqual(len(config.id2label), 5)
|
||||
@@ -197,6 +208,7 @@ class ConfigTester(object):
|
||||
self.create_and_test_config_to_json_string()
|
||||
self.create_and_test_config_to_json_file()
|
||||
self.create_and_test_config_from_and_save_pretrained()
|
||||
self.create_and_test_config_from_and_save_pretrained_subfolder()
|
||||
self.create_and_test_config_with_num_labels()
|
||||
self.check_config_can_be_init_without_params()
|
||||
self.check_config_arguments_init()
|
||||
@@ -308,6 +320,15 @@ class ConfigTestUtils(unittest.TestCase):
|
||||
f" {', '.join(keys_with_defaults)}."
|
||||
)
|
||||
|
||||
def test_from_pretrained_subfolder(self):
    """A config stored in a repo subfolder loads only when `subfolder` is passed."""
    repo_id = "hf-internal-testing/tiny-random-bert-subfolder"

    # config is in subfolder, so loading from the repo root is expected to raise
    with self.assertRaises(OSError):
        BertConfig.from_pretrained(repo_id)

    loaded_config = BertConfig.from_pretrained(repo_id, subfolder="bert")
    self.assertIsNotNone(loaded_config)
|
||||
|
||||
def test_cached_files_are_used_when_internet_is_down(self):
|
||||
# A mock response for an HTTP head request to emulate server down
|
||||
response_mock = mock.Mock()
|
||||
|
||||
@@ -2503,6 +2503,15 @@ def floats_tensor(shape, scale=1.0, rng=None, name=None):
|
||||
return torch.tensor(data=values, dtype=torch.float, device=torch_device).view(shape).contiguous()
|
||||
|
||||
|
||||
def check_models_equal(model1, model2):
    """Return True when both models hold identical parameters.

    Parameters are compared pairwise in the order yielded by `parameters()`.
    The models are reported as different when they expose a different number
    of parameters, when a pair differs in shape, or when any element differs.

    Args:
        model1: Object exposing `parameters()` (e.g. a `torch.nn.Module`).
        model2: Object exposing `parameters()` to compare against `model1`.

    Returns:
        bool: `True` if every parameter matches, `False` otherwise.
    """
    params1 = list(model1.parameters())
    params2 = list(model2.parameters())

    # `zip` stops at the shorter input, so a count mismatch has to be caught explicitly;
    # otherwise a model with extra parameters could be reported as equal.
    if len(params1) != len(params2):
        return False

    for param1, param2 in zip(params1, params2):
        # Check shapes first: `ne` would broadcast or raise on mismatched shapes.
        if param1.shape != param2.shape:
            return False
        # Stop at the first differing element instead of scanning the remaining parameters.
        if param1.data.ne(param2.data).any():
            return False

    return True
|
||||
|
||||
|
||||
@require_torch
|
||||
class ModelUtilsTest(TestCasePlus):
|
||||
@slow
|
||||
@@ -2531,6 +2540,56 @@ class ModelUtilsTest(TestCasePlus):
|
||||
self.assertEqual(model.config.output_hidden_states, True)
|
||||
self.assertEqual(model.config, config)
|
||||
|
||||
def test_model_from_pretrained_subfolder(self):
    """A model saved into a nested directory loads only when `subfolder` is passed."""
    subfolder = "bert"
    reference_model = BertModel(BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert"))

    with tempfile.TemporaryDirectory() as root_dir:
        reference_model.save_pretrained(os.path.join(root_dir, subfolder))

        # Nothing is stored at the top level, so loading without `subfolder` is expected to raise.
        with self.assertRaises(OSError):
            BertModel.from_pretrained(root_dir)

        reloaded_model = BertModel.from_pretrained(root_dir, subfolder=subfolder)

    self.assertTrue(check_models_equal(reference_model, reloaded_model))
|
||||
|
||||
def test_model_from_pretrained_subfolder_sharded(self):
    """Same as the subfolder round-trip, but saving with `max_shard_size="10KB"`."""
    subfolder = "bert"
    reference_model = BertModel(BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert"))

    with tempfile.TemporaryDirectory() as root_dir:
        save_dir = os.path.join(root_dir, subfolder)
        reference_model.save_pretrained(save_dir, max_shard_size="10KB")

        # Nothing is stored at the top level, so loading without `subfolder` is expected to raise.
        with self.assertRaises(OSError):
            BertModel.from_pretrained(root_dir)

        reloaded_model = BertModel.from_pretrained(root_dir, subfolder=subfolder)

    self.assertTrue(check_models_equal(reference_model, reloaded_model))
|
||||
|
||||
def test_model_from_pretrained_hub_subfolder(self):
    """A model stored in a Hub repo subfolder loads only when `subfolder` is passed."""
    model_id = "hf-internal-testing/tiny-random-bert-subfolder"
    subfolder = "bert"

    # Loading from the repo root is expected to raise.
    with self.assertRaises(OSError):
        BertModel.from_pretrained(model_id)

    loaded_model = BertModel.from_pretrained(model_id, subfolder=subfolder)
    self.assertIsNotNone(loaded_model)
|
||||
|
||||
def test_model_from_pretrained_hub_subfolder_sharded(self):
    """Hub subfolder loading against the `tiny-random-bert-sharded-subfolder` repo."""
    model_id = "hf-internal-testing/tiny-random-bert-sharded-subfolder"
    subfolder = "bert"

    # Loading from the repo root is expected to raise.
    with self.assertRaises(OSError):
        BertModel.from_pretrained(model_id)

    loaded_model = BertModel.from_pretrained(model_id, subfolder=subfolder)
    self.assertIsNotNone(loaded_model)
|
||||
|
||||
def test_model_from_pretrained_with_different_pretrained_model_name(self):
|
||||
model = T5ForConditionalGeneration.from_pretrained(TINY_T5)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
Reference in New Issue
Block a user