Validation split added: custom data files @sgugger, @patil-suraj (#12407)
* Validation split added: custom data files Validation split added in case of no validation file and loading custom data * Updated documentation with custom file usage Updated documentation with custom file usage * Update README.md * Update README.md * Update README.md * Made some suggested stylistic changes * Used logger instead of print. Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Made similar changes to add validation split In case of a missing validation file, a validation split will be used now. * max_train_samples to be used for training only max_train_samples got misplaced, now corrected so that it is applied on training data only, not whole data. * styled * changed ordering * Improved language of documentation Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Improved language of documentation Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Fixed styling issue * Update run_mlm.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
f929462b25
commit
d5b8fe3b90
@@ -49,6 +49,14 @@ python run_mlm.py \
|
|||||||
--dataset_config_name wikitext-103-raw-v1
|
--dataset_config_name wikitext-103-raw-v1
|
||||||
```
|
```
|
||||||
|
|
||||||
|
When using a custom dataset, the validation file can be separately passed as an input argument. Otherwise some split (customizable) of training data is used as validation.
|
||||||
|
```
|
||||||
|
python run_mlm.py \
|
||||||
|
--model_name_or_path distilbert-base-cased \
|
||||||
|
--output_dir output \
|
||||||
|
--train_file train_file_path
|
||||||
|
```
|
||||||
|
|
||||||
## run_clm.py
|
## run_clm.py
|
||||||
|
|
||||||
This script trains a causal language model.
|
This script trains a causal language model.
|
||||||
@@ -61,3 +69,12 @@ python run_clm.py \
|
|||||||
--dataset_name wikitext \
|
--dataset_name wikitext \
|
||||||
--dataset_config_name wikitext-103-raw-v1
|
--dataset_config_name wikitext-103-raw-v1
|
||||||
```
|
```
|
||||||
|
|
||||||
|
When using a custom dataset, the validation file can be separately passed as an input argument. Otherwise some split (customizable) of training data is used as validation.
|
||||||
|
|
||||||
|
```
|
||||||
|
python run_clm.py \
|
||||||
|
--model_name_or_path distilgpt2 \
|
||||||
|
--output_dir output \
|
||||||
|
--train_file train_file_path
|
||||||
|
```
|
||||||
|
|||||||
@@ -37,6 +37,7 @@ import datasets
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
|
from sklearn.model_selection import train_test_split
|
||||||
|
|
||||||
import transformers
|
import transformers
|
||||||
from transformers import (
|
from transformers import (
|
||||||
@@ -429,7 +430,18 @@ def main():
|
|||||||
)
|
)
|
||||||
|
|
||||||
train_dataset = lm_datasets["train"]
|
train_dataset = lm_datasets["train"]
|
||||||
|
if data_args.validation_file is not None:
|
||||||
eval_dataset = lm_datasets["validation"]
|
eval_dataset = lm_datasets["validation"]
|
||||||
|
else:
|
||||||
|
logger.info(
|
||||||
|
f"Validation file not found: using {data_args.validation_split_percentage}% of the dataset as validation as provided in data_args"
|
||||||
|
)
|
||||||
|
train_indices, val_indices = train_test_split(
|
||||||
|
list(range(len(train_dataset))), test_size=data_args.validation_split_percentage
|
||||||
|
)
|
||||||
|
|
||||||
|
eval_dataset = train_dataset.select(val_indices)
|
||||||
|
train_dataset = train_dataset.select(train_indices)
|
||||||
|
|
||||||
if data_args.max_train_samples is not None:
|
if data_args.max_train_samples is not None:
|
||||||
train_dataset = train_dataset.select(range(data_args.max_train_samples))
|
train_dataset = train_dataset.select(range(data_args.max_train_samples))
|
||||||
|
|||||||
@@ -39,6 +39,7 @@ import datasets
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
|
from sklearn.model_selection import train_test_split
|
||||||
|
|
||||||
import transformers
|
import transformers
|
||||||
from transformers import (
|
from transformers import (
|
||||||
@@ -363,6 +364,7 @@ def main():
|
|||||||
if extension == "txt":
|
if extension == "txt":
|
||||||
extension = "text"
|
extension = "text"
|
||||||
raw_datasets = load_dataset(extension, data_files=data_files)
|
raw_datasets = load_dataset(extension, data_files=data_files)
|
||||||
|
|
||||||
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
|
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
|
||||||
# https://huggingface.co/docs/datasets/loading_datasets.html.
|
# https://huggingface.co/docs/datasets/loading_datasets.html.
|
||||||
# endregion
|
# endregion
|
||||||
@@ -488,9 +490,22 @@ def main():
|
|||||||
)
|
)
|
||||||
|
|
||||||
train_dataset = tokenized_datasets["train"]
|
train_dataset = tokenized_datasets["train"]
|
||||||
|
|
||||||
|
if data_args.validation_file is not None:
|
||||||
|
eval_dataset = tokenized_datasets["validation"]
|
||||||
|
else:
|
||||||
|
logger.info(
|
||||||
|
f"Validation file not found: using {data_args.validation_split_percentage}% of the dataset as validation as provided in data_args"
|
||||||
|
)
|
||||||
|
train_indices, val_indices = train_test_split(
|
||||||
|
list(range(len(train_dataset))), test_size=data_args.validation_split_percentage
|
||||||
|
)
|
||||||
|
|
||||||
|
eval_dataset = train_dataset.select(val_indices)
|
||||||
|
train_dataset = train_dataset.select(train_indices)
|
||||||
|
|
||||||
if data_args.max_train_samples is not None:
|
if data_args.max_train_samples is not None:
|
||||||
train_dataset = train_dataset.select(range(data_args.max_train_samples))
|
train_dataset = train_dataset.select(range(data_args.max_train_samples))
|
||||||
eval_dataset = tokenized_datasets["validation"]
|
|
||||||
if data_args.max_eval_samples is not None:
|
if data_args.max_eval_samples is not None:
|
||||||
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
|
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user