updated example template (#12365)
This commit is contained in:
@@ -27,6 +27,7 @@ import sys
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
import datasets
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
|
|
||||||
import transformers
|
import transformers
|
||||||
@@ -226,16 +227,19 @@ def main():
|
|||||||
datefmt="%m/%d/%Y %H:%M:%S",
|
datefmt="%m/%d/%Y %H:%M:%S",
|
||||||
handlers=[logging.StreamHandler(sys.stdout)],
|
handlers=[logging.StreamHandler(sys.stdout)],
|
||||||
)
|
)
|
||||||
logger.setLevel(logging.INFO if training_args.should_log else logging.WARN)
|
|
||||||
|
log_level = training_args.get_process_log_level()
|
||||||
|
logger.setLevel(log_level)
|
||||||
|
datasets.utils.logging.set_verbosity(log_level)
|
||||||
|
transformers.utils.logging.set_verbosity(log_level)
|
||||||
|
transformers.utils.logging.enable_default_handler()
|
||||||
|
transformers.utils.logging.enable_explicit_format()
|
||||||
|
|
||||||
# Log on each process the small summary:
|
# Log on each process the small summary:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
|
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
|
||||||
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
|
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
|
||||||
)
|
)
|
||||||
# Set the verbosity to info of the Transformers logger (on main process only):
|
|
||||||
if training_args.should_log:
|
|
||||||
transformers.utils.logging.set_verbosity_info()
|
|
||||||
logger.info(f"Training/evaluation parameters {training_args}")
|
logger.info(f"Training/evaluation parameters {training_args}")
|
||||||
|
|
||||||
# Set seed before initializing model.
|
# Set seed before initializing model.
|
||||||
@@ -252,7 +256,7 @@ def main():
|
|||||||
# download the dataset.
|
# download the dataset.
|
||||||
if data_args.dataset_name is not None:
|
if data_args.dataset_name is not None:
|
||||||
# Downloading and loading a dataset from the hub.
|
# Downloading and loading a dataset from the hub.
|
||||||
datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name)
|
raw_datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name)
|
||||||
else:
|
else:
|
||||||
data_files = {}
|
data_files = {}
|
||||||
if data_args.train_file is not None:
|
if data_args.train_file is not None:
|
||||||
@@ -266,7 +270,7 @@ def main():
|
|||||||
extension = data_args.test_file.split(".")[-1]
|
extension = data_args.test_file.split(".")[-1]
|
||||||
if extension == "txt":
|
if extension == "txt":
|
||||||
extension = "text"
|
extension = "text"
|
||||||
datasets = load_dataset(extension, data_files=data_files)
|
raw_datasets = load_dataset(extension, data_files=data_files)
|
||||||
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
|
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
|
||||||
# https://huggingface.co/docs/datasets/loading_datasets.html.
|
# https://huggingface.co/docs/datasets/loading_datasets.html.
|
||||||
|
|
||||||
@@ -348,20 +352,20 @@ def main():
|
|||||||
# Preprocessing the datasets.
|
# Preprocessing the datasets.
|
||||||
# First we tokenize all the texts.
|
# First we tokenize all the texts.
|
||||||
if training_args.do_train:
|
if training_args.do_train:
|
||||||
column_names = datasets["train"].column_names
|
column_names = raw_datasets["train"].column_names
|
||||||
elif training_args.do_eval:
|
elif training_args.do_eval:
|
||||||
column_names = datasets["validation"].column_names
|
column_names = raw_datasets["validation"].column_names
|
||||||
elif training_args.do_predict:
|
elif training_args.do_predict:
|
||||||
column_names = datasets["test"].column_names
|
column_names = raw_datasets["test"].column_names
|
||||||
text_column_name = "text" if "text" in column_names else column_names[0]
|
text_column_name = "text" if "text" in column_names else column_names[0]
|
||||||
|
|
||||||
def tokenize_function(examples):
|
def tokenize_function(examples):
|
||||||
return tokenizer(examples[text_column_name], padding="max_length", truncation=True)
|
return tokenizer(examples[text_column_name], padding="max_length", truncation=True)
|
||||||
|
|
||||||
if training_args.do_train:
|
if training_args.do_train:
|
||||||
if "train" not in datasets:
|
if "train" not in raw_datasets:
|
||||||
raise ValueError("--do_train requires a train dataset")
|
raise ValueError("--do_train requires a train dataset")
|
||||||
train_dataset = datasets["train"]
|
train_dataset = raw_datasets["train"]
|
||||||
if data_args.max_train_samples is not None:
|
if data_args.max_train_samples is not None:
|
||||||
# Select Sample from Dataset
|
# Select Sample from Dataset
|
||||||
train_dataset = train_dataset.select(range(data_args.max_train_samples))
|
train_dataset = train_dataset.select(range(data_args.max_train_samples))
|
||||||
@@ -375,9 +379,9 @@ def main():
|
|||||||
)
|
)
|
||||||
|
|
||||||
if training_args.do_eval:
|
if training_args.do_eval:
|
||||||
if "validation" not in datasets:
|
if "validation" not in raw_datasets:
|
||||||
raise ValueError("--do_eval requires a validation dataset")
|
raise ValueError("--do_eval requires a validation dataset")
|
||||||
eval_dataset = datasets["validation"]
|
eval_dataset = raw_datasets["validation"]
|
||||||
# Selecting samples from dataset
|
# Selecting samples from dataset
|
||||||
if data_args.max_eval_samples is not None:
|
if data_args.max_eval_samples is not None:
|
||||||
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
|
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
|
||||||
@@ -391,9 +395,9 @@ def main():
|
|||||||
)
|
)
|
||||||
|
|
||||||
if training_args.do_predict:
|
if training_args.do_predict:
|
||||||
if "test" not in datasets:
|
if "test" not in raw_datasets:
|
||||||
raise ValueError("--do_predict requires a test dataset")
|
raise ValueError("--do_predict requires a test dataset")
|
||||||
predict_dataset = datasets["test"]
|
predict_dataset = raw_datasets["test"]
|
||||||
# Selecting samples from dataset
|
# Selecting samples from dataset
|
||||||
if data_args.max_predict_samples is not None:
|
if data_args.max_predict_samples is not None:
|
||||||
predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
|
predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
|
||||||
@@ -754,7 +758,7 @@ def main():
|
|||||||
|
|
||||||
# Preprocessing the datasets.
|
# Preprocessing the datasets.
|
||||||
# First we tokenize all the texts.
|
# First we tokenize all the texts.
|
||||||
column_names = datasets["train"].column_names
|
column_names = raw_datasets["train"].column_names
|
||||||
text_column_name = "text" if "text" in column_names else column_names[0]
|
text_column_name = "text" if "text" in column_names else column_names[0]
|
||||||
|
|
||||||
padding = "max_length" if args.pad_to_max_length else False
|
padding = "max_length" if args.pad_to_max_length else False
|
||||||
|
|||||||
Reference in New Issue
Block a user