[run_(clm|mlm).py examples] add streaming dataset support (#21343)
* [run_clm example] add streaming dataset support * unrefactor kwargs * fix * fix * require datasets>=2.0.0 * port to mlm
This commit is contained in:
@@ -174,6 +174,9 @@ concatenates all texts and then splits them in blocks of the same length).
|
|||||||
**Note:** On TPU, you should use the flag `--pad_to_max_length` in conjunction with the `--line_by_line` flag to make
|
**Note:** On TPU, you should use the flag `--pad_to_max_length` in conjunction with the `--line_by_line` flag to make
|
||||||
sure all your batches have the same length.
|
sure all your batches have the same length.
|
||||||
|
|
||||||
|
## Streaming
|
||||||
|
|
||||||
|
To use the streaming dataset mode which can be very useful for large datasets, add `--streaming` to the command line. This is currently supported by `run_mlm.py` and `run_clm.py`.
|
||||||
|
|
||||||
## Creating a model on the fly
|
## Creating a model on the fly
|
||||||
|
|
||||||
|
|||||||
@@ -173,7 +173,7 @@ class DataTrainingArguments:
|
|||||||
)
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
streaming: bool = field(default=False, metadata={"help": "Enable streaming mode"})
|
||||||
block_size: Optional[int] = field(
|
block_size: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
@@ -202,6 +202,9 @@ class DataTrainingArguments:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
|
if self.streaming:
|
||||||
|
require_version("datasets>=2.0.0", "The streaming feature requires `datasets>=2.0.0`")
|
||||||
|
|
||||||
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
|
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
|
||||||
raise ValueError("Need either a dataset name or a training/validation file.")
|
raise ValueError("Need either a dataset name or a training/validation file.")
|
||||||
else:
|
else:
|
||||||
@@ -285,6 +288,7 @@ def main():
|
|||||||
data_args.dataset_config_name,
|
data_args.dataset_config_name,
|
||||||
cache_dir=model_args.cache_dir,
|
cache_dir=model_args.cache_dir,
|
||||||
use_auth_token=True if model_args.use_auth_token else None,
|
use_auth_token=True if model_args.use_auth_token else None,
|
||||||
|
streaming=data_args.streaming,
|
||||||
)
|
)
|
||||||
if "validation" not in raw_datasets.keys():
|
if "validation" not in raw_datasets.keys():
|
||||||
raw_datasets["validation"] = load_dataset(
|
raw_datasets["validation"] = load_dataset(
|
||||||
@@ -293,6 +297,7 @@ def main():
|
|||||||
split=f"train[:{data_args.validation_split_percentage}%]",
|
split=f"train[:{data_args.validation_split_percentage}%]",
|
||||||
cache_dir=model_args.cache_dir,
|
cache_dir=model_args.cache_dir,
|
||||||
use_auth_token=True if model_args.use_auth_token else None,
|
use_auth_token=True if model_args.use_auth_token else None,
|
||||||
|
streaming=data_args.streaming,
|
||||||
)
|
)
|
||||||
raw_datasets["train"] = load_dataset(
|
raw_datasets["train"] = load_dataset(
|
||||||
data_args.dataset_name,
|
data_args.dataset_name,
|
||||||
@@ -300,6 +305,7 @@ def main():
|
|||||||
split=f"train[{data_args.validation_split_percentage}%:]",
|
split=f"train[{data_args.validation_split_percentage}%:]",
|
||||||
cache_dir=model_args.cache_dir,
|
cache_dir=model_args.cache_dir,
|
||||||
use_auth_token=True if model_args.use_auth_token else None,
|
use_auth_token=True if model_args.use_auth_token else None,
|
||||||
|
streaming=data_args.streaming,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
data_files = {}
|
data_files = {}
|
||||||
@@ -413,7 +419,13 @@ def main():
|
|||||||
# Preprocessing the datasets.
|
# Preprocessing the datasets.
|
||||||
# First we tokenize all the texts.
|
# First we tokenize all the texts.
|
||||||
if training_args.do_train:
|
if training_args.do_train:
|
||||||
|
if data_args.streaming:
|
||||||
|
column_names = raw_datasets["train"].features.keys()
|
||||||
|
else:
|
||||||
column_names = raw_datasets["train"].column_names
|
column_names = raw_datasets["train"].column_names
|
||||||
|
else:
|
||||||
|
if data_args.streaming:
|
||||||
|
column_names = raw_datasets["validation"].features.keys()
|
||||||
else:
|
else:
|
||||||
column_names = raw_datasets["validation"].column_names
|
column_names = raw_datasets["validation"].column_names
|
||||||
text_column_name = "text" if "text" in column_names else column_names[0]
|
text_column_name = "text" if "text" in column_names else column_names[0]
|
||||||
@@ -433,6 +445,7 @@ def main():
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
with training_args.main_process_first(desc="dataset map tokenization"):
|
with training_args.main_process_first(desc="dataset map tokenization"):
|
||||||
|
if not data_args.streaming:
|
||||||
tokenized_datasets = raw_datasets.map(
|
tokenized_datasets = raw_datasets.map(
|
||||||
tokenize_function,
|
tokenize_function,
|
||||||
batched=True,
|
batched=True,
|
||||||
@@ -441,6 +454,12 @@ def main():
|
|||||||
load_from_cache_file=not data_args.overwrite_cache,
|
load_from_cache_file=not data_args.overwrite_cache,
|
||||||
desc="Running tokenizer on dataset",
|
desc="Running tokenizer on dataset",
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
tokenized_datasets = raw_datasets.map(
|
||||||
|
tokenize_function,
|
||||||
|
batched=True,
|
||||||
|
remove_columns=column_names,
|
||||||
|
)
|
||||||
|
|
||||||
if data_args.block_size is None:
|
if data_args.block_size is None:
|
||||||
block_size = tokenizer.model_max_length
|
block_size = tokenizer.model_max_length
|
||||||
@@ -483,6 +502,7 @@ def main():
|
|||||||
# https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
|
# https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
|
||||||
|
|
||||||
with training_args.main_process_first(desc="grouping texts together"):
|
with training_args.main_process_first(desc="grouping texts together"):
|
||||||
|
if not data_args.streaming:
|
||||||
lm_datasets = tokenized_datasets.map(
|
lm_datasets = tokenized_datasets.map(
|
||||||
group_texts,
|
group_texts,
|
||||||
batched=True,
|
batched=True,
|
||||||
@@ -490,6 +510,11 @@ def main():
|
|||||||
load_from_cache_file=not data_args.overwrite_cache,
|
load_from_cache_file=not data_args.overwrite_cache,
|
||||||
desc=f"Grouping texts in chunks of {block_size}",
|
desc=f"Grouping texts in chunks of {block_size}",
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
lm_datasets = tokenized_datasets.map(
|
||||||
|
group_texts,
|
||||||
|
batched=True,
|
||||||
|
)
|
||||||
|
|
||||||
if training_args.do_train:
|
if training_args.do_train:
|
||||||
if "train" not in tokenized_datasets:
|
if "train" not in tokenized_datasets:
|
||||||
|
|||||||
@@ -197,8 +197,12 @@ class DataTrainingArguments:
|
|||||||
)
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
streaming: bool = field(default=False, metadata={"help": "Enable streaming mode"})
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
|
if self.streaming:
|
||||||
|
require_version("datasets>=2.0.0", "The streaming feature requires `datasets>=2.0.0`")
|
||||||
|
|
||||||
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
|
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
|
||||||
raise ValueError("Need either a dataset name or a training/validation file.")
|
raise ValueError("Need either a dataset name or a training/validation file.")
|
||||||
else:
|
else:
|
||||||
@@ -285,6 +289,7 @@ def main():
|
|||||||
data_args.dataset_config_name,
|
data_args.dataset_config_name,
|
||||||
cache_dir=model_args.cache_dir,
|
cache_dir=model_args.cache_dir,
|
||||||
use_auth_token=True if model_args.use_auth_token else None,
|
use_auth_token=True if model_args.use_auth_token else None,
|
||||||
|
streaming=data_args.streaming,
|
||||||
)
|
)
|
||||||
if "validation" not in raw_datasets.keys():
|
if "validation" not in raw_datasets.keys():
|
||||||
raw_datasets["validation"] = load_dataset(
|
raw_datasets["validation"] = load_dataset(
|
||||||
@@ -293,6 +298,7 @@ def main():
|
|||||||
split=f"train[:{data_args.validation_split_percentage}%]",
|
split=f"train[:{data_args.validation_split_percentage}%]",
|
||||||
cache_dir=model_args.cache_dir,
|
cache_dir=model_args.cache_dir,
|
||||||
use_auth_token=True if model_args.use_auth_token else None,
|
use_auth_token=True if model_args.use_auth_token else None,
|
||||||
|
streaming=data_args.streaming,
|
||||||
)
|
)
|
||||||
raw_datasets["train"] = load_dataset(
|
raw_datasets["train"] = load_dataset(
|
||||||
data_args.dataset_name,
|
data_args.dataset_name,
|
||||||
@@ -300,6 +306,7 @@ def main():
|
|||||||
split=f"train[{data_args.validation_split_percentage}%:]",
|
split=f"train[{data_args.validation_split_percentage}%:]",
|
||||||
cache_dir=model_args.cache_dir,
|
cache_dir=model_args.cache_dir,
|
||||||
use_auth_token=True if model_args.use_auth_token else None,
|
use_auth_token=True if model_args.use_auth_token else None,
|
||||||
|
streaming=data_args.streaming,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
data_files = {}
|
data_files = {}
|
||||||
@@ -398,7 +405,13 @@ def main():
|
|||||||
# Preprocessing the datasets.
|
# Preprocessing the datasets.
|
||||||
# First we tokenize all the texts.
|
# First we tokenize all the texts.
|
||||||
if training_args.do_train:
|
if training_args.do_train:
|
||||||
|
if data_args.streaming:
|
||||||
|
column_names = raw_datasets["train"].features.keys()
|
||||||
|
else:
|
||||||
column_names = raw_datasets["train"].column_names
|
column_names = raw_datasets["train"].column_names
|
||||||
|
else:
|
||||||
|
if data_args.streaming:
|
||||||
|
column_names = raw_datasets["validation"].features.keys()
|
||||||
else:
|
else:
|
||||||
column_names = raw_datasets["validation"].column_names
|
column_names = raw_datasets["validation"].column_names
|
||||||
text_column_name = "text" if "text" in column_names else column_names[0]
|
text_column_name = "text" if "text" in column_names else column_names[0]
|
||||||
@@ -439,6 +452,7 @@ def main():
|
|||||||
)
|
)
|
||||||
|
|
||||||
with training_args.main_process_first(desc="dataset map tokenization"):
|
with training_args.main_process_first(desc="dataset map tokenization"):
|
||||||
|
if not data_args.streaming:
|
||||||
tokenized_datasets = raw_datasets.map(
|
tokenized_datasets = raw_datasets.map(
|
||||||
tokenize_function,
|
tokenize_function,
|
||||||
batched=True,
|
batched=True,
|
||||||
@@ -447,6 +461,12 @@ def main():
|
|||||||
load_from_cache_file=not data_args.overwrite_cache,
|
load_from_cache_file=not data_args.overwrite_cache,
|
||||||
desc="Running tokenizer on dataset line_by_line",
|
desc="Running tokenizer on dataset line_by_line",
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
tokenized_datasets = raw_datasets.map(
|
||||||
|
tokenize_function,
|
||||||
|
batched=True,
|
||||||
|
remove_columns=[text_column_name],
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts.
|
# Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts.
|
||||||
# We use `return_special_tokens_mask=True` because DataCollatorForLanguageModeling (see below) is more
|
# We use `return_special_tokens_mask=True` because DataCollatorForLanguageModeling (see below) is more
|
||||||
@@ -455,6 +475,7 @@ def main():
|
|||||||
return tokenizer(examples[text_column_name], return_special_tokens_mask=True)
|
return tokenizer(examples[text_column_name], return_special_tokens_mask=True)
|
||||||
|
|
||||||
with training_args.main_process_first(desc="dataset map tokenization"):
|
with training_args.main_process_first(desc="dataset map tokenization"):
|
||||||
|
if not data_args.streaming:
|
||||||
tokenized_datasets = raw_datasets.map(
|
tokenized_datasets = raw_datasets.map(
|
||||||
tokenize_function,
|
tokenize_function,
|
||||||
batched=True,
|
batched=True,
|
||||||
@@ -463,6 +484,12 @@ def main():
|
|||||||
load_from_cache_file=not data_args.overwrite_cache,
|
load_from_cache_file=not data_args.overwrite_cache,
|
||||||
desc="Running tokenizer on every text in dataset",
|
desc="Running tokenizer on every text in dataset",
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
tokenized_datasets = raw_datasets.map(
|
||||||
|
tokenize_function,
|
||||||
|
batched=True,
|
||||||
|
remove_columns=column_names,
|
||||||
|
)
|
||||||
|
|
||||||
# Main data processing function that will concatenate all texts from our dataset and generate chunks of
|
# Main data processing function that will concatenate all texts from our dataset and generate chunks of
|
||||||
# max_seq_length.
|
# max_seq_length.
|
||||||
@@ -489,6 +516,7 @@ def main():
|
|||||||
# https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
|
# https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
|
||||||
|
|
||||||
with training_args.main_process_first(desc="grouping texts together"):
|
with training_args.main_process_first(desc="grouping texts together"):
|
||||||
|
if not data_args.streaming:
|
||||||
tokenized_datasets = tokenized_datasets.map(
|
tokenized_datasets = tokenized_datasets.map(
|
||||||
group_texts,
|
group_texts,
|
||||||
batched=True,
|
batched=True,
|
||||||
@@ -496,6 +524,11 @@ def main():
|
|||||||
load_from_cache_file=not data_args.overwrite_cache,
|
load_from_cache_file=not data_args.overwrite_cache,
|
||||||
desc=f"Grouping texts in chunks of {max_seq_length}",
|
desc=f"Grouping texts in chunks of {max_seq_length}",
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
tokenized_datasets = tokenized_datasets.map(
|
||||||
|
group_texts,
|
||||||
|
batched=True,
|
||||||
|
)
|
||||||
|
|
||||||
if training_args.do_train:
|
if training_args.do_train:
|
||||||
if "train" not in tokenized_datasets:
|
if "train" not in tokenized_datasets:
|
||||||
|
|||||||
Reference in New Issue
Block a user