Adapt repository creation to latest hf_hub (#21158)

* Adapt repository creation to latest hf_hub

* Update all examples

* Fix other tests, add Flax examples

* Address review comments
This commit is contained in:
Sylvain Gugger
2023-01-18 17:14:00 +01:00
committed by GitHub
parent 32525428e1
commit 05e72aa0c4
30 changed files with 83 additions and 73 deletions

View File

@@ -45,7 +45,7 @@ from flax import jax_utils, traverse_util
from flax.jax_utils import unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from huggingface_hub import Repository
from huggingface_hub import Repository, create_repo
from transformers import (
AutoFeatureExtractor,
AutoTokenizer,
@@ -430,7 +430,8 @@ def main():
)
else:
repo_name = training_args.hub_model_id
repo = Repository(training_args.output_dir, clone_from=repo_name)
create_repo(repo_name, exist_ok=True, token=training_args.hub_token)
repo = Repository(training_args.output_dir, clone_from=repo_name, token=training_args.hub_token)
# Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/

View File

@@ -45,7 +45,7 @@ from flax import jax_utils, traverse_util
from flax.jax_utils import pad_shard_unpad
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard
from huggingface_hub import Repository
from huggingface_hub import Repository, create_repo
from transformers import (
CONFIG_MAPPING,
FLAX_MODEL_FOR_MASKED_LM_MAPPING,
@@ -502,7 +502,8 @@ def main():
)
else:
repo_name = training_args.hub_model_id
repo = Repository(training_args.output_dir, clone_from=repo_name)
create_repo(repo_name, exist_ok=True, token=training_args.hub_token)
repo = Repository(training_args.output_dir, clone_from=repo_name, token=training_args.hub_token)
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/

View File

@@ -46,7 +46,7 @@ from flax import jax_utils, traverse_util
from flax.jax_utils import pad_shard_unpad, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from huggingface_hub import Repository
from huggingface_hub import Repository, create_repo
from transformers import (
CONFIG_MAPPING,
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
@@ -376,7 +376,8 @@ def main():
)
else:
repo_name = training_args.hub_model_id
repo = Repository(training_args.output_dir, clone_from=repo_name)
create_repo(repo_name, exist_ok=True, token=training_args.hub_token)
repo = Repository(training_args.output_dir, clone_from=repo_name, token=training_args.hub_token)
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/

View File

@@ -46,7 +46,7 @@ from flax import jax_utils, traverse_util
from flax.jax_utils import pad_shard_unpad
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard
from huggingface_hub import Repository
from huggingface_hub import Repository, create_repo
from transformers import (
CONFIG_MAPPING,
FLAX_MODEL_FOR_MASKED_LM_MAPPING,
@@ -416,7 +416,8 @@ def main():
)
else:
repo_name = training_args.hub_model_id
repo = Repository(training_args.output_dir, clone_from=repo_name)
create_repo(repo_name, exist_ok=True, token=training_args.hub_token)
repo = Repository(training_args.output_dir, clone_from=repo_name, token=training_args.hub_token)
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/

View File

@@ -45,7 +45,7 @@ from flax import jax_utils, traverse_util
from flax.jax_utils import pad_shard_unpad
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard
from huggingface_hub import Repository
from huggingface_hub import Repository, create_repo
from transformers import (
CONFIG_MAPPING,
FLAX_MODEL_FOR_MASKED_LM_MAPPING,
@@ -542,7 +542,8 @@ def main():
)
else:
repo_name = training_args.hub_model_id
repo = Repository(training_args.output_dir, clone_from=repo_name)
create_repo(repo_name, exist_ok=True, token=training_args.hub_token)
repo = Repository(training_args.output_dir, clone_from=repo_name, token=training_args.hub_token)
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/

View File

@@ -44,7 +44,7 @@ from flax import struct, traverse_util
from flax.jax_utils import pad_shard_unpad, replicate, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard
from huggingface_hub import Repository
from huggingface_hub import Repository, create_repo
from transformers import (
AutoConfig,
AutoTokenizer,
@@ -467,7 +467,8 @@ def main():
)
else:
repo_name = training_args.hub_model_id
repo = Repository(training_args.output_dir, clone_from=repo_name)
create_repo(repo_name, exist_ok=True, token=training_args.hub_token)
repo = Repository(training_args.output_dir, clone_from=repo_name, token=training_args.hub_token)
# region Load Data
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)

View File

@@ -46,7 +46,7 @@ from flax import jax_utils, traverse_util
from flax.jax_utils import pad_shard_unpad, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from huggingface_hub import Repository
from huggingface_hub import Repository, create_repo
from transformers import (
CONFIG_MAPPING,
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
@@ -450,7 +450,8 @@ def main():
)
else:
repo_name = training_args.hub_model_id
repo = Repository(training_args.output_dir, clone_from=repo_name)
create_repo(repo_name, exist_ok=True, token=training_args.hub_token)
repo = Repository(training_args.output_dir, clone_from=repo_name, token=training_args.hub_token)
# Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/

View File

@@ -39,7 +39,7 @@ from flax import struct, traverse_util
from flax.jax_utils import pad_shard_unpad, replicate, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard
from huggingface_hub import Repository
from huggingface_hub import Repository, create_repo
from transformers import (
AutoConfig,
AutoTokenizer,
@@ -350,7 +350,8 @@ def main():
)
else:
repo_name = training_args.hub_model_id
repo = Repository(training_args.output_dir, clone_from=repo_name)
create_repo(repo_name, exist_ok=True, token=training_args.hub_token)
repo = Repository(training_args.output_dir, clone_from=repo_name, token=training_args.hub_token)
# Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
# or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub).

View File

@@ -41,7 +41,7 @@ from flax import struct, traverse_util
from flax.jax_utils import pad_shard_unpad, replicate, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard
from huggingface_hub import Repository
from huggingface_hub import Repository, create_repo
from transformers import (
AutoConfig,
AutoTokenizer,
@@ -406,7 +406,8 @@ def main():
)
else:
repo_name = training_args.hub_model_id
repo = Repository(training_args.output_dir, clone_from=repo_name)
create_repo(repo_name, exist_ok=True, token=training_args.hub_token)
repo = Repository(training_args.output_dir, clone_from=repo_name, token=training_args.hub_token)
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
# or just provide the name of one of the public datasets for token classification task available on the hub at https://huggingface.co/datasets/

View File

@@ -43,7 +43,7 @@ from flax import jax_utils
from flax.jax_utils import pad_shard_unpad, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from huggingface_hub import Repository
from huggingface_hub import Repository, create_repo
from transformers import (
CONFIG_MAPPING,
FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
@@ -298,7 +298,8 @@ def main():
)
else:
repo_name = training_args.hub_model_id
repo = Repository(training_args.output_dir, clone_from=repo_name)
create_repo(repo_name, exist_ok=True, token=training_args.hub_token)
repo = Repository(training_args.output_dir, clone_from=repo_name, token=training_args.hub_token)
# Initialize datasets and pre-processing transforms
# We use torchvision here for faster pre-processing