Flax MLM: Allow validation split when loading dataset from local file (#12689)
* Allow validation split when loading dataset from local file * Flax clm & t5, enable validation split for datasets loaded from local file
This commit is contained in:
@@ -307,6 +307,20 @@ def main():
|
|||||||
if extension == "txt":
|
if extension == "txt":
|
||||||
extension = "text"
|
extension = "text"
|
||||||
dataset = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
|
dataset = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
|
||||||
|
|
||||||
|
if "validation" not in datasets.keys():
|
||||||
|
datasets["validation"] = load_dataset(
|
||||||
|
extension,
|
||||||
|
data_files=data_files,
|
||||||
|
split=f"train[:{data_args.validation_split_percentage}%]",
|
||||||
|
cache_dir=model_args.cache_dir,
|
||||||
|
)
|
||||||
|
datasets["train"] = load_dataset(
|
||||||
|
extension,
|
||||||
|
data_files=data_files,
|
||||||
|
split=f"train[{data_args.validation_split_percentage}%:]",
|
||||||
|
cache_dir=model_args.cache_dir,
|
||||||
|
)
|
||||||
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
|
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
|
||||||
# https://huggingface.co/docs/datasets/loading_datasets.html.
|
# https://huggingface.co/docs/datasets/loading_datasets.html.
|
||||||
|
|
||||||
|
|||||||
@@ -344,6 +344,20 @@ if __name__ == "__main__":
|
|||||||
if extension == "txt":
|
if extension == "txt":
|
||||||
extension = "text"
|
extension = "text"
|
||||||
datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
|
datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
|
||||||
|
|
||||||
|
if "validation" not in datasets.keys():
|
||||||
|
datasets["validation"] = load_dataset(
|
||||||
|
extension,
|
||||||
|
data_files=data_files,
|
||||||
|
split=f"train[:{data_args.validation_split_percentage}%]",
|
||||||
|
cache_dir=model_args.cache_dir,
|
||||||
|
)
|
||||||
|
datasets["train"] = load_dataset(
|
||||||
|
extension,
|
||||||
|
data_files=data_files,
|
||||||
|
split=f"train[{data_args.validation_split_percentage}%:]",
|
||||||
|
cache_dir=model_args.cache_dir,
|
||||||
|
)
|
||||||
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
|
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
|
||||||
# https://huggingface.co/docs/datasets/loading_datasets.html.
|
# https://huggingface.co/docs/datasets/loading_datasets.html.
|
||||||
|
|
||||||
|
|||||||
@@ -471,6 +471,19 @@ if __name__ == "__main__":
|
|||||||
extension = "text"
|
extension = "text"
|
||||||
datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
|
datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
|
||||||
|
|
||||||
|
if "validation" not in datasets.keys():
|
||||||
|
datasets["validation"] = load_dataset(
|
||||||
|
extension,
|
||||||
|
data_files=data_files,
|
||||||
|
split=f"train[:{data_args.validation_split_percentage}%]",
|
||||||
|
cache_dir=model_args.cache_dir,
|
||||||
|
)
|
||||||
|
datasets["train"] = load_dataset(
|
||||||
|
extension,
|
||||||
|
data_files=data_files,
|
||||||
|
split=f"train[{data_args.validation_split_percentage}%:]",
|
||||||
|
cache_dir=model_args.cache_dir,
|
||||||
|
)
|
||||||
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
|
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
|
||||||
# https://huggingface.co/docs/datasets/loading_datasets.html.
|
# https://huggingface.co/docs/datasets/loading_datasets.html.
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user