update desc for map in all examples (#12226)
* update desc for map in all examples * added plm * suggestions
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
datasets >= 1.1.3
|
||||
datasets >= 1.8.0
|
||||
sentencepiece != 0.1.92
|
||||
protobuf
|
||||
rouge-score
|
||||
|
||||
@@ -43,10 +43,12 @@ from transformers import (
|
||||
from transformers.file_utils import is_offline_mode
|
||||
from transformers.trainer_utils import get_last_checkpoint
|
||||
from transformers.utils import check_min_version
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.8.0.dev0")
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -433,6 +435,7 @@ def main():
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
remove_columns=column_names,
|
||||
load_from_cache_file=not data_args.overwrite_cache,
|
||||
desc="Running tokenizer on train dataset",
|
||||
)
|
||||
|
||||
if training_args.do_eval:
|
||||
@@ -448,6 +451,7 @@ def main():
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
remove_columns=column_names,
|
||||
load_from_cache_file=not data_args.overwrite_cache,
|
||||
desc="Running tokenizer on validation dataset",
|
||||
)
|
||||
|
||||
if training_args.do_predict:
|
||||
@@ -463,6 +467,7 @@ def main():
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
remove_columns=column_names,
|
||||
load_from_cache_file=not data_args.overwrite_cache,
|
||||
desc="Running tokenizer on prediction dataset",
|
||||
)
|
||||
|
||||
# Data collator
|
||||
|
||||
@@ -48,9 +48,12 @@ from transformers import (
|
||||
set_seed,
|
||||
)
|
||||
from transformers.file_utils import is_offline_mode
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
|
||||
|
||||
# You should update this to your particular problem to have better documentation of `model_type`
|
||||
MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())
|
||||
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
||||
@@ -419,7 +422,11 @@ def main():
|
||||
return model_inputs
|
||||
|
||||
processed_datasets = raw_datasets.map(
|
||||
preprocess_function, batched=True, remove_columns=column_names, load_from_cache_file=not args.overwrite_cache
|
||||
preprocess_function,
|
||||
batched=True,
|
||||
remove_columns=column_names,
|
||||
load_from_cache_file=not args.overwrite_cache,
|
||||
desc="Running tokenizer on dataset",
|
||||
)
|
||||
|
||||
train_dataset = processed_datasets["train"]
|
||||
|
||||
Reference in New Issue
Block a user