Add use_auth to load_datasets for private datasets to PT and TF examples (#16521)

* fix formatting and remove use_auth

* Add use_auth_token to Flax examples
This commit is contained in:
Karim Foda
2022-04-04 15:27:45 +01:00
committed by GitHub
parent b9a768b3ff
commit 24a85cca61
36 changed files with 544 additions and 92 deletions

View File

@@ -337,7 +337,11 @@ def main():
# download the dataset.
if data_args.task_name is not None:
# Downloading and loading a dataset from the hub.
raw_datasets = load_dataset("glue", data_args.task_name)
raw_datasets = load_dataset(
"glue",
data_args.task_name,
use_auth_token=True if model_args.use_auth_token else None,
)
else:
# Loading the dataset from local csv or json file.
data_files = {}
@@ -346,7 +350,11 @@ def main():
if data_args.validation_file is not None:
data_files["validation"] = data_args.validation_file
extension = (data_args.train_file if data_args.train_file is not None else data_args.valid_file).split(".")[-1]
raw_datasets = load_dataset(extension, data_files=data_files)
raw_datasets = load_dataset(
extension,
data_files=data_files,
use_auth_token=True if model_args.use_auth_token else None,
)
# See more about loading any type of standard or custom dataset at
# https://huggingface.co/docs/datasets/loading_datasets.html.
@@ -372,12 +380,21 @@ def main():
# Load pretrained model and tokenizer
config = AutoConfig.from_pretrained(
model_args.model_name_or_path, num_labels=num_labels, finetuning_task=data_args.task_name
model_args.model_name_or_path,
num_labels=num_labels,
finetuning_task=data_args.task_name,
use_auth_token=True if model_args.use_auth_token else None,
)
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path, use_fast=not model_args.use_slow_tokenizer
model_args.model_name_or_path,
use_fast=not model_args.use_slow_tokenizer,
use_auth_token=True if model_args.use_auth_token else None,
)
model = FlaxAutoModelForSequenceClassification.from_pretrained(
model_args.model_name_or_path,
config=config,
use_auth_token=True if model_args.use_auth_token else None,
)
model = FlaxAutoModelForSequenceClassification.from_pretrained(model_args.model_name_or_path, config=config)
# Preprocessing the datasets
if data_args.task_name is not None: