From 4864d08d3e267e0914ac71606fb5e6cb36134c30 Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Thu, 26 Oct 2023 12:37:09 +0200 Subject: [PATCH] Add-support for commit description (#26704) * fix * update * revert * add dosctring * good to go * update * add a test --- src/transformers/utils/hub.py | 6 ++++++ tests/test_modeling_utils.py | 17 +++++++++++++++++ 2 files changed, 23 insertions(+) diff --git a/src/transformers/utils/hub.py b/src/transformers/utils/hub.py index a3ed744f46..047e00bc31 100644 --- a/src/transformers/utils/hub.py +++ b/src/transformers/utils/hub.py @@ -724,6 +724,7 @@ class PushToHubMixin: token: Optional[Union[bool, str]] = None, create_pr: bool = False, revision: str = None, + commit_description: str = None, ): """ Uploads all modified files in `working_dir` to `repo_id`, based on `files_timestamps`. @@ -778,6 +779,7 @@ class PushToHubMixin: repo_id=repo_id, operations=operations, commit_message=commit_message, + commit_description=commit_description, token=token, create_pr=create_pr, revision=revision, @@ -794,6 +796,7 @@ class PushToHubMixin: create_pr: bool = False, safe_serialization: bool = False, revision: str = None, + commit_description: str = None, **deprecated_kwargs, ) -> str: """ @@ -825,6 +828,8 @@ class PushToHubMixin: Whether or not to convert the model weights in safetensors format for safer serialization. revision (`str`, *optional*): Branch to push the uploaded files to. + commit_description (`str`, *optional*): + The description of the commit that will be created Examples: @@ -901,6 +906,7 @@ class PushToHubMixin: token=token, create_pr=create_pr, revision=revision, + commit_description=commit_description, ) diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index bccde5af50..2a6246f870 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -1119,6 +1119,23 @@ class ModelPushToHubTester(unittest.TestCase): for p1, p2 in zip(model.parameters(), new_model.parameters()): self.assertTrue(torch.equal(p1, p2)) + def test_push_to_hub_with_description(self): + config = BertConfig( + vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37 + ) + model = BertModel(config) + COMMIT_DESCRIPTION = """ +The commit description supports markdown synthax see: +```python +>>> form transformers import AutoConfig +>>> config = AutoConfig.from_pretrained("bert-base-uncased") +``` +""" + commit_details = model.push_to_hub( + "test-model", use_auth_token=self._token, create_pr=True, commit_description=COMMIT_DESCRIPTION + ) + self.assertEqual(commit_details.commit_description, COMMIT_DESCRIPTION) + @unittest.skip("This test is flaky") def test_push_to_hub_in_organization(self): config = BertConfig(