Add token arugment in example scripts (#25172)

* fix

* fix

* fix

* fix

* fix

* fix

* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2023-08-02 11:17:31 +02:00
committed by GitHub
parent c6a8768dab
commit 149cb0cce2
43 changed files with 987 additions and 420 deletions

View File

@@ -16,6 +16,7 @@
import logging
import os
import sys
import warnings
from dataclasses import dataclass, field
from typing import Optional
@@ -133,15 +134,21 @@ class ModelArguments:
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
)
image_processor_name: str = field(default=None, metadata={"help": "Name or path of preprocessor config."})
use_auth_token: bool = field(
default=False,
token: str = field(
default=None,
metadata={
"help": (
"Will use the token generated when running `huggingface-cli login` (necessary to use this script "
"with private models)."
"The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
"generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
)
},
)
use_auth_token: bool = field(
default=None,
metadata={
"help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token`."
},
)
mask_ratio: float = field(
default=0.75, metadata={"help": "The ratio of the number of masked tokens in the input sequence."}
)
@@ -175,6 +182,12 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
if model_args.use_auth_token is not None:
warnings.warn("The `use_auth_token` argument is deprecated and will be removed in v4.34.", FutureWarning)
if model_args.token is not None:
raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
model_args.token = model_args.use_auth_token
# Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
# information sent is the one passed as arguments along with your Python/PyTorch versions.
send_example_telemetry("run_mae", model_args, data_args)
@@ -224,7 +237,7 @@ def main():
data_args.dataset_config_name,
data_files=data_args.data_files,
cache_dir=model_args.cache_dir,
use_auth_token=True if model_args.use_auth_token else None,
token=model_args.token,
)
# If we don't have a validation split, split off a percentage of train as validation.
@@ -242,7 +255,7 @@ def main():
config_kwargs = {
"cache_dir": model_args.cache_dir,
"revision": model_args.model_revision,
"use_auth_token": True if model_args.use_auth_token else None,
"token": model_args.token,
}
if model_args.config_name:
config = ViTMAEConfig.from_pretrained(model_args.config_name, **config_kwargs)
@@ -280,7 +293,7 @@ def main():
config=config,
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
token=True if model_args.use_auth_token else None,
token=model_args.token,
)
else:
logger.info("Training new model from scratch")