From 66197adc9880da10d690a09b6dc3fc371c43abdc Mon Sep 17 00:00:00 2001 From: fgaim Date: Tue, 20 Jul 2021 13:38:25 +0200 Subject: [PATCH] Flax MLM: Allow validation split when loading dataset from local file (#12689) * Allow validation split when loading dataset from local file * Flax clm & t5, enable validation split for datasets loaded from local file --- examples/flax/language-modeling/run_clm_flax.py | 14 ++++++++++++++ examples/flax/language-modeling/run_mlm_flax.py | 14 ++++++++++++++ examples/flax/language-modeling/run_t5_mlm_flax.py | 13 +++++++++++++ 3 files changed, 41 insertions(+) diff --git a/examples/flax/language-modeling/run_clm_flax.py b/examples/flax/language-modeling/run_clm_flax.py index bddd5b9905..23e550f51d 100755 --- a/examples/flax/language-modeling/run_clm_flax.py +++ b/examples/flax/language-modeling/run_clm_flax.py @@ -307,6 +307,20 @@ def main(): if extension == "txt": extension = "text" dataset = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir) + + if "validation" not in datasets.keys(): + datasets["validation"] = load_dataset( + extension, + data_files=data_files, + split=f"train[:{data_args.validation_split_percentage}%]", + cache_dir=model_args.cache_dir, + ) + datasets["train"] = load_dataset( + extension, + data_files=data_files, + split=f"train[{data_args.validation_split_percentage}%:]", + cache_dir=model_args.cache_dir, + ) # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at # https://huggingface.co/docs/datasets/loading_datasets.html. diff --git a/examples/flax/language-modeling/run_mlm_flax.py b/examples/flax/language-modeling/run_mlm_flax.py index 4282560dac..85e577d1bf 100755 --- a/examples/flax/language-modeling/run_mlm_flax.py +++ b/examples/flax/language-modeling/run_mlm_flax.py @@ -344,6 +344,20 @@ if __name__ == "__main__": if extension == "txt": extension = "text" datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir) + + if "validation" not in datasets.keys(): + datasets["validation"] = load_dataset( + extension, + data_files=data_files, + split=f"train[:{data_args.validation_split_percentage}%]", + cache_dir=model_args.cache_dir, + ) + datasets["train"] = load_dataset( + extension, + data_files=data_files, + split=f"train[{data_args.validation_split_percentage}%:]", + cache_dir=model_args.cache_dir, + ) # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at # https://huggingface.co/docs/datasets/loading_datasets.html. diff --git a/examples/flax/language-modeling/run_t5_mlm_flax.py b/examples/flax/language-modeling/run_t5_mlm_flax.py index c206d76bec..e50381e4d9 100755 --- a/examples/flax/language-modeling/run_t5_mlm_flax.py +++ b/examples/flax/language-modeling/run_t5_mlm_flax.py @@ -471,6 +471,19 @@ if __name__ == "__main__": extension = "text" datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir) + if "validation" not in datasets.keys(): + datasets["validation"] = load_dataset( + extension, + data_files=data_files, + split=f"train[:{data_args.validation_split_percentage}%]", + cache_dir=model_args.cache_dir, + ) + datasets["train"] = load_dataset( + extension, + data_files=data_files, + split=f"train[{data_args.validation_split_percentage}%:]", + cache_dir=model_args.cache_dir, + ) # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at # https://huggingface.co/docs/datasets/loading_datasets.html.