From 319d840b46fd3a13e0434de9de69bd74a2f22f43 Mon Sep 17 00:00:00 2001 From: Stefan Schweter Date: Fri, 27 Aug 2021 11:35:45 +0200 Subject: [PATCH] examples: add keep_linebreaks option to CLM examples (#13150) * examples: add keep_linebreaks option to text dataset loader for all CLM examples * examples: introduce new keep_linebreaks option as data argument in CLM examples --- examples/flax/language-modeling/run_clm_flax.py | 5 +++++ examples/pytorch/language-modeling/run_clm.py | 5 +++++ examples/pytorch/language-modeling/run_clm_no_trainer.py | 5 +++++ examples/tensorflow/language-modeling/run_clm.py | 5 ++++- 4 files changed, 19 insertions(+), 1 deletion(-) diff --git a/examples/flax/language-modeling/run_clm_flax.py b/examples/flax/language-modeling/run_clm_flax.py index d12e9c11f8..f22da85934 100755 --- a/examples/flax/language-modeling/run_clm_flax.py +++ b/examples/flax/language-modeling/run_clm_flax.py @@ -156,6 +156,9 @@ class DataTrainingArguments: default=None, metadata={"help": "The number of processes to use for the preprocessing."}, ) + keep_linebreaks: bool = field( + default=True, metadata={"help": "Whether to keep line breaks when using CSV/JSON/TXT files or not."} + ) def __post_init__(self): if self.dataset_name is None and self.train_file is None and self.validation_file is None: @@ -314,12 +317,14 @@ def main(): if "validation" not in dataset.keys(): dataset["validation"] = load_dataset( extension, + keep_linebreaks=data_args.keep_linebreaks, data_files=data_files, split=f"train[:{data_args.validation_split_percentage}%]", cache_dir=model_args.cache_dir, ) dataset["train"] = load_dataset( extension, + keep_linebreaks=data_args.keep_linebreaks, data_files=data_files, split=f"train[{data_args.validation_split_percentage}%:]", cache_dir=model_args.cache_dir, diff --git a/examples/pytorch/language-modeling/run_clm.py b/examples/pytorch/language-modeling/run_clm.py index 01f4780a0a..c5e872a4ec 100755 --- a/examples/pytorch/language-modeling/run_clm.py +++ b/examples/pytorch/language-modeling/run_clm.py @@ -172,6 +172,9 @@ class DataTrainingArguments: default=None, metadata={"help": "The number of processes to use for the preprocessing."}, ) + keep_linebreaks: bool = field( + default=True, metadata={"help": "Whether to keep line breaks when using CSV/JSON/TXT files or not."} + ) def __post_init__(self): if self.dataset_name is None and self.train_file is None and self.validation_file is None: @@ -282,12 +285,14 @@ def main(): if "validation" not in raw_datasets.keys(): raw_datasets["validation"] = load_dataset( extension, + keep_linebreaks=data_args.keep_linebreaks, data_files=data_files, split=f"train[:{data_args.validation_split_percentage}%]", cache_dir=model_args.cache_dir, ) raw_datasets["train"] = load_dataset( extension, + keep_linebreaks=data_args.keep_linebreaks, data_files=data_files, split=f"train[{data_args.validation_split_percentage}%:]", cache_dir=model_args.cache_dir, diff --git a/examples/pytorch/language-modeling/run_clm_no_trainer.py b/examples/pytorch/language-modeling/run_clm_no_trainer.py index 010d99a4b9..723c64c3d0 100755 --- a/examples/pytorch/language-modeling/run_clm_no_trainer.py +++ b/examples/pytorch/language-modeling/run_clm_no_trainer.py @@ -173,6 +173,9 @@ def parse_args(): parser.add_argument( "--overwrite_cache", type=bool, default=False, help="Overwrite the cached training and evaluation sets" ) + parser.add_argument( + "--no_keep_linebreaks", action="store_true", help="Do not keep line breaks when using CSV/JSON/TXT files." + ) args = parser.parse_args() @@ -257,11 +260,13 @@ def main(): if "validation" not in raw_datasets.keys(): raw_datasets["validation"] = load_dataset( extension, + keep_linebreaks=not args.no_keep_linebreaks, data_files=data_files, split=f"train[:{args.validation_split_percentage}%]", ) raw_datasets["train"] = load_dataset( extension, + keep_linebreaks=not args.no_keep_linebreaks, data_files=data_files, split=f"train[{args.validation_split_percentage}%:]", ) diff --git a/examples/tensorflow/language-modeling/run_clm.py b/examples/tensorflow/language-modeling/run_clm.py index 97ac093bb8..c9e5bc0536 100755 --- a/examples/tensorflow/language-modeling/run_clm.py +++ b/examples/tensorflow/language-modeling/run_clm.py @@ -186,6 +186,9 @@ class DataTrainingArguments: "value if set." }, ) + keep_linebreaks: bool = field( + default=True, metadata={"help": "Whether to keep line breaks when using CSV/JSON/TXT files or not."} + ) def __post_init__(self): if self.dataset_name is None and self.train_file is None and self.validation_file is None: @@ -325,7 +328,7 @@ def main(): extension = data_args.train_file.split(".")[-1] if extension == "txt": extension = "text" - raw_datasets = load_dataset(extension, data_files=data_files) + raw_datasets = load_dataset(extension, keep_linebreaks=data_args.keep_linebreaks, data_files=data_files) # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at # https://huggingface.co/docs/datasets/loading_datasets.html. # endregion