Add use_auth to load_datasets for private datasets to PT and TF examples (#16521)
* fix formatting and remove use_auth * Add use_auth_token to Flax examples
This commit is contained in:
@@ -154,6 +154,13 @@ class ModelArguments:
|
||||
"help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
|
||||
},
|
||||
)
|
||||
use_auth_token: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
||||
"with private models)."
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -315,6 +322,7 @@ def main():
|
||||
num_labels=len(train_dataset.classes),
|
||||
image_size=data_args.image_size,
|
||||
cache_dir=model_args.cache_dir,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
)
|
||||
elif model_args.model_name_or_path:
|
||||
config = AutoConfig.from_pretrained(
|
||||
@@ -322,6 +330,7 @@ def main():
|
||||
num_labels=len(train_dataset.classes),
|
||||
image_size=data_args.image_size,
|
||||
cache_dir=model_args.cache_dir,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
)
|
||||
else:
|
||||
config = CONFIG_MAPPING[model_args.model_type]()
|
||||
@@ -329,11 +338,18 @@ def main():
|
||||
|
||||
if model_args.model_name_or_path:
|
||||
model = FlaxAutoModelForImageClassification.from_pretrained(
|
||||
model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
|
||||
model_args.model_name_or_path,
|
||||
config=config,
|
||||
seed=training_args.seed,
|
||||
dtype=getattr(jnp, model_args.dtype),
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
)
|
||||
else:
|
||||
model = FlaxAutoModelForImageClassification.from_config(
|
||||
config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
|
||||
config,
|
||||
seed=training_args.seed,
|
||||
dtype=getattr(jnp, model_args.dtype),
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
)
|
||||
|
||||
# Store some constant
|
||||
|
||||
Reference in New Issue
Block a user