[Examples] Added context manager to datasets map (#12367)

* added cotext manager to datasets map

* fixed style and spaces

* fixed warning of deprecation

* changed desc
This commit is contained in:
Bhadresh Savani
2021-06-28 21:44:00 +05:30
committed by GitHub
parent d25ad34c82
commit 04dbea31a9
12 changed files with 242 additions and 213 deletions

View File

@@ -400,12 +400,13 @@ def main():
result["label"] = [(label_to_id[l] if l != -1 else -1) for l in examples["label"]]
return result
raw_datasets = raw_datasets.map(
preprocess_function,
batched=True,
load_from_cache_file=not data_args.overwrite_cache,
desc="Running tokenizer on dataset",
)
with training_args.main_process_first(desc="dataset map pre-processing"):
raw_datasets = raw_datasets.map(
preprocess_function,
batched=True,
load_from_cache_file=not data_args.overwrite_cache,
desc="Running tokenizer on dataset",
)
if training_args.do_train:
if "train" not in raw_datasets:
raise ValueError("--do_train requires a train dataset")
@@ -526,7 +527,7 @@ def main():
for predict_dataset, task in zip(predict_datasets, tasks):
# Removing the `label` columns because it contains -1 and Trainer won't like that.
predict_dataset.remove_columns_("label")
predict_dataset = predict_dataset.remove_columns("label")
predictions = trainer.predict(predict_dataset, metric_key_prefix="predict").predictions
predictions = np.squeeze(predictions) if is_regression else np.argmax(predictions, axis=1)

View File

@@ -280,12 +280,13 @@ def main():
if training_args.do_train:
if data_args.max_train_samples is not None:
train_dataset = train_dataset.select(range(data_args.max_train_samples))
train_dataset = train_dataset.map(
preprocess_function,
batched=True,
load_from_cache_file=not data_args.overwrite_cache,
desc="Running tokenizer on train dataset",
)
with training_args.main_process_first(desc="train dataset map pre-processing"):
train_dataset = train_dataset.map(
preprocess_function,
batched=True,
load_from_cache_file=not data_args.overwrite_cache,
desc="Running tokenizer on train dataset",
)
# Log a few random samples from the training set:
for index in random.sample(range(len(train_dataset)), 3):
logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
@@ -293,22 +294,24 @@ def main():
if training_args.do_eval:
if data_args.max_eval_samples is not None:
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
eval_dataset = eval_dataset.map(
preprocess_function,
batched=True,
load_from_cache_file=not data_args.overwrite_cache,
desc="Running tokenizer on validation dataset",
)
with training_args.main_process_first(desc="validation dataset map pre-processing"):
eval_dataset = eval_dataset.map(
preprocess_function,
batched=True,
load_from_cache_file=not data_args.overwrite_cache,
desc="Running tokenizer on validation dataset",
)
if training_args.do_predict:
if data_args.max_predict_samples is not None:
predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
predict_dataset = predict_dataset.map(
preprocess_function,
batched=True,
load_from_cache_file=not data_args.overwrite_cache,
desc="Running tokenizer on prediction dataset",
)
with training_args.main_process_first(desc="prediction dataset map pre-processing"):
predict_dataset = predict_dataset.map(
preprocess_function,
batched=True,
load_from_cache_file=not data_args.overwrite_cache,
desc="Running tokenizer on prediction dataset",
)
# Get the metric function
metric = load_metric("xnli")