[Examples] Added context manager to datasets map (#12367)
* added context manager to datasets map * fixed style and spaces * fixed warning of deprecation * changed desc
This commit is contained in:
@@ -400,12 +400,13 @@ def main():
|
||||
result["label"] = [(label_to_id[l] if l != -1 else -1) for l in examples["label"]]
|
||||
return result
|
||||
|
||||
raw_datasets = raw_datasets.map(
|
||||
preprocess_function,
|
||||
batched=True,
|
||||
load_from_cache_file=not data_args.overwrite_cache,
|
||||
desc="Running tokenizer on dataset",
|
||||
)
|
||||
with training_args.main_process_first(desc="dataset map pre-processing"):
|
||||
raw_datasets = raw_datasets.map(
|
||||
preprocess_function,
|
||||
batched=True,
|
||||
load_from_cache_file=not data_args.overwrite_cache,
|
||||
desc="Running tokenizer on dataset",
|
||||
)
|
||||
if training_args.do_train:
|
||||
if "train" not in raw_datasets:
|
||||
raise ValueError("--do_train requires a train dataset")
|
||||
@@ -526,7 +527,7 @@ def main():
|
||||
|
||||
for predict_dataset, task in zip(predict_datasets, tasks):
|
||||
# Removing the `label` columns because it contains -1 and Trainer won't like that.
|
||||
predict_dataset.remove_columns_("label")
|
||||
predict_dataset = predict_dataset.remove_columns("label")
|
||||
predictions = trainer.predict(predict_dataset, metric_key_prefix="predict").predictions
|
||||
predictions = np.squeeze(predictions) if is_regression else np.argmax(predictions, axis=1)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user