examples: only use keep_linebreaks when reading TXT files (#13320)
* examples: only use keep_linebreaks when reading TXT files for all CLM examples * examples: only use keep_linebreaks when reading TXT files for all CLM examples * examples: only use keep_linebreaks when reading TXT files for all CLM examples
This commit is contained in:
@@ -157,7 +157,7 @@ class DataTrainingArguments:
|
|||||||
metadata={"help": "The number of processes to use for the preprocessing."},
|
metadata={"help": "The number of processes to use for the preprocessing."},
|
||||||
)
|
)
|
||||||
keep_linebreaks: bool = field(
|
keep_linebreaks: bool = field(
|
||||||
default=True, metadata={"help": "Whether to keep line breaks when using CSV/JSON/TXT files or not."}
|
default=True, metadata={"help": "Whether to keep line breaks when using TXT files or not."}
|
||||||
)
|
)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
@@ -305,6 +305,7 @@ def main():
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
data_files = {}
|
data_files = {}
|
||||||
|
dataset_args = {}
|
||||||
if data_args.train_file is not None:
|
if data_args.train_file is not None:
|
||||||
data_files["train"] = data_args.train_file
|
data_files["train"] = data_args.train_file
|
||||||
if data_args.validation_file is not None:
|
if data_args.validation_file is not None:
|
||||||
@@ -312,22 +313,23 @@ def main():
|
|||||||
extension = data_args.train_file.split(".")[-1]
|
extension = data_args.train_file.split(".")[-1]
|
||||||
if extension == "txt":
|
if extension == "txt":
|
||||||
extension = "text"
|
extension = "text"
|
||||||
dataset = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
|
dataset_args["keep_linebreaks"] = data_args.keep_linebreaks
|
||||||
|
dataset = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir, **dataset_args)
|
||||||
|
|
||||||
if "validation" not in dataset.keys():
|
if "validation" not in dataset.keys():
|
||||||
dataset["validation"] = load_dataset(
|
dataset["validation"] = load_dataset(
|
||||||
extension,
|
extension,
|
||||||
keep_linebreaks=data_args.keep_linebreaks,
|
|
||||||
data_files=data_files,
|
data_files=data_files,
|
||||||
split=f"train[:{data_args.validation_split_percentage}%]",
|
split=f"train[:{data_args.validation_split_percentage}%]",
|
||||||
cache_dir=model_args.cache_dir,
|
cache_dir=model_args.cache_dir,
|
||||||
|
**dataset_args,
|
||||||
)
|
)
|
||||||
dataset["train"] = load_dataset(
|
dataset["train"] = load_dataset(
|
||||||
extension,
|
extension,
|
||||||
keep_linebreaks=data_args.keep_linebreaks,
|
|
||||||
data_files=data_files,
|
data_files=data_files,
|
||||||
split=f"train[{data_args.validation_split_percentage}%:]",
|
split=f"train[{data_args.validation_split_percentage}%:]",
|
||||||
cache_dir=model_args.cache_dir,
|
cache_dir=model_args.cache_dir,
|
||||||
|
**dataset_args,
|
||||||
)
|
)
|
||||||
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
|
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
|
||||||
# https://huggingface.co/docs/datasets/loading_datasets.html.
|
# https://huggingface.co/docs/datasets/loading_datasets.html.
|
||||||
|
|||||||
@@ -173,7 +173,7 @@ class DataTrainingArguments:
|
|||||||
metadata={"help": "The number of processes to use for the preprocessing."},
|
metadata={"help": "The number of processes to use for the preprocessing."},
|
||||||
)
|
)
|
||||||
keep_linebreaks: bool = field(
|
keep_linebreaks: bool = field(
|
||||||
default=True, metadata={"help": "Whether to keep line breaks when using CSV/JSON/TXT files or not."}
|
default=True, metadata={"help": "Whether to keep line breaks when using TXT files or not."}
|
||||||
)
|
)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
@@ -269,6 +269,7 @@ def main():
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
data_files = {}
|
data_files = {}
|
||||||
|
dataset_args = {}
|
||||||
if data_args.train_file is not None:
|
if data_args.train_file is not None:
|
||||||
data_files["train"] = data_args.train_file
|
data_files["train"] = data_args.train_file
|
||||||
if data_args.validation_file is not None:
|
if data_args.validation_file is not None:
|
||||||
@@ -280,22 +281,23 @@ def main():
|
|||||||
)
|
)
|
||||||
if extension == "txt":
|
if extension == "txt":
|
||||||
extension = "text"
|
extension = "text"
|
||||||
raw_datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
|
dataset_args["keep_linebreaks"] = data_args.keep_linebreaks
|
||||||
|
raw_datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir, **dataset_args)
|
||||||
# If no validation data is there, validation_split_percentage will be used to divide the dataset.
|
# If no validation data is there, validation_split_percentage will be used to divide the dataset.
|
||||||
if "validation" not in raw_datasets.keys():
|
if "validation" not in raw_datasets.keys():
|
||||||
raw_datasets["validation"] = load_dataset(
|
raw_datasets["validation"] = load_dataset(
|
||||||
extension,
|
extension,
|
||||||
keep_linebreaks=data_args.keep_linebreaks,
|
|
||||||
data_files=data_files,
|
data_files=data_files,
|
||||||
split=f"train[:{data_args.validation_split_percentage}%]",
|
split=f"train[:{data_args.validation_split_percentage}%]",
|
||||||
cache_dir=model_args.cache_dir,
|
cache_dir=model_args.cache_dir,
|
||||||
|
**dataset_args,
|
||||||
)
|
)
|
||||||
raw_datasets["train"] = load_dataset(
|
raw_datasets["train"] = load_dataset(
|
||||||
extension,
|
extension,
|
||||||
keep_linebreaks=data_args.keep_linebreaks,
|
|
||||||
data_files=data_files,
|
data_files=data_files,
|
||||||
split=f"train[{data_args.validation_split_percentage}%:]",
|
split=f"train[{data_args.validation_split_percentage}%:]",
|
||||||
cache_dir=model_args.cache_dir,
|
cache_dir=model_args.cache_dir,
|
||||||
|
**dataset_args,
|
||||||
)
|
)
|
||||||
|
|
||||||
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
|
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
|
||||||
|
|||||||
@@ -174,7 +174,7 @@ def parse_args():
|
|||||||
"--overwrite_cache", type=bool, default=False, help="Overwrite the cached training and evaluation sets"
|
"--overwrite_cache", type=bool, default=False, help="Overwrite the cached training and evaluation sets"
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--no_keep_linebreaks", action="store_true", help="Do not keep line breaks when using CSV/JSON/TXT files."
|
"--no_keep_linebreaks", action="store_true", help="Do not keep line breaks when using TXT files."
|
||||||
)
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
@@ -248,6 +248,7 @@ def main():
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
data_files = {}
|
data_files = {}
|
||||||
|
dataset_args = {}
|
||||||
if args.train_file is not None:
|
if args.train_file is not None:
|
||||||
data_files["train"] = args.train_file
|
data_files["train"] = args.train_file
|
||||||
if args.validation_file is not None:
|
if args.validation_file is not None:
|
||||||
@@ -255,20 +256,21 @@ def main():
|
|||||||
extension = args.train_file.split(".")[-1]
|
extension = args.train_file.split(".")[-1]
|
||||||
if extension == "txt":
|
if extension == "txt":
|
||||||
extension = "text"
|
extension = "text"
|
||||||
raw_datasets = load_dataset(extension, data_files=data_files)
|
dataset_args["keep_linebreaks"] = not args.no_keep_linebreaks
|
||||||
|
raw_datasets = load_dataset(extension, data_files=data_files, **dataset_args)
|
||||||
# If no validation data is there, validation_split_percentage will be used to divide the dataset.
|
# If no validation data is there, validation_split_percentage will be used to divide the dataset.
|
||||||
if "validation" not in raw_datasets.keys():
|
if "validation" not in raw_datasets.keys():
|
||||||
raw_datasets["validation"] = load_dataset(
|
raw_datasets["validation"] = load_dataset(
|
||||||
extension,
|
extension,
|
||||||
keep_linebreaks=not args.no_keep_linebreaks,
|
|
||||||
data_files=data_files,
|
data_files=data_files,
|
||||||
split=f"train[:{args.validation_split_percentage}%]",
|
split=f"train[:{args.validation_split_percentage}%]",
|
||||||
|
**dataset_args,
|
||||||
)
|
)
|
||||||
raw_datasets["train"] = load_dataset(
|
raw_datasets["train"] = load_dataset(
|
||||||
extension,
|
extension,
|
||||||
keep_linebreaks=not args.no_keep_linebreaks,
|
|
||||||
data_files=data_files,
|
data_files=data_files,
|
||||||
split=f"train[{args.validation_split_percentage}%:]",
|
split=f"train[{args.validation_split_percentage}%:]",
|
||||||
|
**dataset_args,
|
||||||
)
|
)
|
||||||
|
|
||||||
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
|
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
|
||||||
|
|||||||
@@ -187,7 +187,7 @@ class DataTrainingArguments:
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
keep_linebreaks: bool = field(
|
keep_linebreaks: bool = field(
|
||||||
default=True, metadata={"help": "Whether to keep line breaks when using CSV/JSON/TXT files or not."}
|
default=True, metadata={"help": "Whether to keep line breaks when using TXT files or not."}
|
||||||
)
|
)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
@@ -321,6 +321,7 @@ def main():
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
data_files = {}
|
data_files = {}
|
||||||
|
dataset_args = {}
|
||||||
if data_args.train_file is not None:
|
if data_args.train_file is not None:
|
||||||
data_files["train"] = data_args.train_file
|
data_files["train"] = data_args.train_file
|
||||||
if data_args.validation_file is not None:
|
if data_args.validation_file is not None:
|
||||||
@@ -328,7 +329,8 @@ def main():
|
|||||||
extension = data_args.train_file.split(".")[-1]
|
extension = data_args.train_file.split(".")[-1]
|
||||||
if extension == "txt":
|
if extension == "txt":
|
||||||
extension = "text"
|
extension = "text"
|
||||||
raw_datasets = load_dataset(extension, keep_linebreaks=data_args.keep_linebreaks, data_files=data_files)
|
dataset_args["keep_linebreaks"] = data_args.keep_linebreaks
|
||||||
|
raw_datasets = load_dataset(extension, data_files=data_files, **dataset_args)
|
||||||
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
|
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
|
||||||
# https://huggingface.co/docs/datasets/loading_datasets.html.
|
# https://huggingface.co/docs/datasets/loading_datasets.html.
|
||||||
# endregion
|
# endregion
|
||||||
|
|||||||
Reference in New Issue
Block a user