[Examples] Added context manager to datasets map (#12367)

* added context manager to datasets map

* fixed style and spaces

* fixed deprecation warning

* changed desc
This commit is contained in:
Bhadresh Savani
2021-06-28 21:44:00 +05:30
committed by GitHub
parent d25ad34c82
commit 04dbea31a9
12 changed files with 242 additions and 213 deletions

View File

@@ -370,13 +370,14 @@ def main():
# Select Sample from Dataset
train_dataset = train_dataset.select(range(data_args.max_train_samples))
# tokenize train dataset in batch
train_dataset = train_dataset.map(
tokenize_function,
batched=True,
num_proc=data_args.preprocessing_num_workers,
remove_columns=[text_column_name],
load_from_cache_file=not data_args.overwrite_cache,
)
with training_args.main_process_first(desc="train dataset map tokenization"):
train_dataset = train_dataset.map(
tokenize_function,
batched=True,
num_proc=data_args.preprocessing_num_workers,
remove_columns=[text_column_name],
load_from_cache_file=not data_args.overwrite_cache,
)
if training_args.do_eval:
if "validation" not in raw_datasets:
@@ -386,13 +387,14 @@ def main():
if data_args.max_eval_samples is not None:
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
# tokenize validation dataset
eval_dataset = eval_dataset.map(
tokenize_function,
batched=True,
num_proc=data_args.preprocessing_num_workers,
remove_columns=[text_column_name],
load_from_cache_file=not data_args.overwrite_cache,
)
with training_args.main_process_first(desc="validation dataset map tokenization"):
eval_dataset = eval_dataset.map(
tokenize_function,
batched=True,
num_proc=data_args.preprocessing_num_workers,
remove_columns=[text_column_name],
load_from_cache_file=not data_args.overwrite_cache,
)
if training_args.do_predict:
if "test" not in raw_datasets:
@@ -402,13 +404,14 @@ def main():
if data_args.max_predict_samples is not None:
predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
# tokenize predict dataset
predict_dataset = predict_dataset.map(
tokenize_function,
batched=True,
num_proc=data_args.preprocessing_num_workers,
remove_columns=[text_column_name],
load_from_cache_file=not data_args.overwrite_cache,
)
with training_args.main_process_first(desc="prediction dataset map tokenization"):
predict_dataset = predict_dataset.map(
tokenize_function,
batched=True,
num_proc=data_args.preprocessing_num_workers,
remove_columns=[text_column_name],
load_from_cache_file=not data_args.overwrite_cache,
)
# Data collator
data_collator=default_data_collator if not training_args.fp16 else DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)