Clean push to hub API (#12187)
* Clean push to hub API * Create working dir if it does not exist * Different tweak * New API + all models + test Flax * Adds the Trainer clean up * Update src/transformers/file_utils.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Address review comments * (nit) output types * No need to set clone_from when folder exists * Update src/transformers/trainer.py Co-authored-by: Julien Chaumond <julien@huggingface.co> * Add generated_from_trainer tag * Update to new version * Fixes Co-authored-by: Lysandre Debut <lysandre@huggingface.co> Co-authored-by: Julien Chaumond <julien@huggingface.co> Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
This commit is contained in:
@@ -16,14 +16,25 @@ import copy
|
||||
import inspect
|
||||
import random
|
||||
import tempfile
|
||||
import unittest
|
||||
from typing import List, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
import transformers
|
||||
from transformers import is_flax_available, is_torch_available
|
||||
from huggingface_hub import HfApi
|
||||
from requests.exceptions import HTTPError
|
||||
from transformers import BertConfig, FlaxBertModel, is_flax_available, is_torch_available
|
||||
from transformers.models.auto import get_values
|
||||
from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow
|
||||
from transformers.testing_utils import (
|
||||
ENDPOINT_STAGING,
|
||||
PASS,
|
||||
USER,
|
||||
is_pt_flax_cross_test,
|
||||
is_staging_test,
|
||||
require_flax,
|
||||
slow,
|
||||
)
|
||||
|
||||
|
||||
if is_flax_available():
|
||||
@@ -504,3 +515,65 @@ class FlaxModelTesterMixin:
|
||||
list(self_attentions[0].shape[-3:]),
|
||||
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
|
||||
)
|
||||
|
||||
|
||||
@require_flax
|
||||
@is_staging_test
|
||||
class FlaxModelPushToHubTester(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls._api = HfApi(endpoint=ENDPOINT_STAGING)
|
||||
cls._token = cls._api.login(username=USER, password=PASS)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
try:
|
||||
cls._api.delete_repo(token=cls._token, name="test-model-flax")
|
||||
except HTTPError:
|
||||
pass
|
||||
|
||||
try:
|
||||
cls._api.delete_repo(token=cls._token, name="test-model-flax-org", organization="valid_org")
|
||||
except HTTPError:
|
||||
pass
|
||||
|
||||
def test_push_to_hub(self):
|
||||
config = BertConfig(
|
||||
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
|
||||
)
|
||||
model = FlaxBertModel(config)
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.save_pretrained(
|
||||
os.path.join(tmp_dir, "test-model-flax"), push_to_hub=True, use_auth_token=self._token
|
||||
)
|
||||
|
||||
new_model = FlaxBertModel.from_pretrained(f"{USER}/test-model-flax")
|
||||
|
||||
base_params = flatten_dict(unfreeze(model.params))
|
||||
new_params = flatten_dict(unfreeze(new_model.params))
|
||||
|
||||
for key in base_params.keys():
|
||||
max_diff = (base_params[key] - new_params[key]).sum().item()
|
||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||
|
||||
def test_push_to_hub_in_organization(self):
|
||||
config = BertConfig(
|
||||
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
|
||||
)
|
||||
model = FlaxBertModel(config)
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.save_pretrained(
|
||||
os.path.join(tmp_dir, "test-model-flax-org"),
|
||||
push_to_hub=True,
|
||||
use_auth_token=self._token,
|
||||
organization="valid_org",
|
||||
)
|
||||
|
||||
new_model = FlaxBertModel.from_pretrained("valid_org/test-model-flax-org")
|
||||
|
||||
base_params = flatten_dict(unfreeze(model.params))
|
||||
new_params = flatten_dict(unfreeze(new_model.params))
|
||||
|
||||
for key in base_params.keys():
|
||||
max_diff = (base_params[key] - new_params[key]).sum().item()
|
||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||
|
||||
Reference in New Issue
Block a user