[run_(clm|mlm).py examples] add streaming dataset support (#21343)
* [run_clm example] add streaming dataset support * unrefactor kwargs * fix * fix * require datasets>=2.0.0 * port to mlm
This commit is contained in:
@@ -173,7 +173,7 @@ class DataTrainingArguments:
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
streaming: bool = field(default=False, metadata={"help": "Enable streaming mode"})
|
||||
block_size: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
@@ -202,6 +202,9 @@ class DataTrainingArguments:
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.streaming:
|
||||
require_version("datasets>=2.0.0", "The streaming feature requires `datasets>=2.0.0`")
|
||||
|
||||
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
|
||||
raise ValueError("Need either a dataset name or a training/validation file.")
|
||||
else:
|
||||
@@ -285,6 +288,7 @@ def main():
|
||||
data_args.dataset_config_name,
|
||||
cache_dir=model_args.cache_dir,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
streaming=data_args.streaming,
|
||||
)
|
||||
if "validation" not in raw_datasets.keys():
|
||||
raw_datasets["validation"] = load_dataset(
|
||||
@@ -293,6 +297,7 @@ def main():
|
||||
split=f"train[:{data_args.validation_split_percentage}%]",
|
||||
cache_dir=model_args.cache_dir,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
streaming=data_args.streaming,
|
||||
)
|
||||
raw_datasets["train"] = load_dataset(
|
||||
data_args.dataset_name,
|
||||
@@ -300,6 +305,7 @@ def main():
|
||||
split=f"train[{data_args.validation_split_percentage}%:]",
|
||||
cache_dir=model_args.cache_dir,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
streaming=data_args.streaming,
|
||||
)
|
||||
else:
|
||||
data_files = {}
|
||||
@@ -413,9 +419,15 @@ def main():
|
||||
# Preprocessing the datasets.
|
||||
# First we tokenize all the texts.
|
||||
if training_args.do_train:
|
||||
column_names = raw_datasets["train"].column_names
|
||||
if data_args.streaming:
|
||||
column_names = raw_datasets["train"].features.keys()
|
||||
else:
|
||||
column_names = raw_datasets["train"].column_names
|
||||
else:
|
||||
column_names = raw_datasets["validation"].column_names
|
||||
if data_args.streaming:
|
||||
column_names = raw_datasets["validation"].features.keys()
|
||||
else:
|
||||
column_names = raw_datasets["validation"].column_names
|
||||
text_column_name = "text" if "text" in column_names else column_names[0]
|
||||
|
||||
# since this will be pickled to avoid _LazyModule error in Hasher force logger loading before tokenize_function
|
||||
@@ -433,14 +445,21 @@ def main():
|
||||
return output
|
||||
|
||||
with training_args.main_process_first(desc="dataset map tokenization"):
|
||||
tokenized_datasets = raw_datasets.map(
|
||||
tokenize_function,
|
||||
batched=True,
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
remove_columns=column_names,
|
||||
load_from_cache_file=not data_args.overwrite_cache,
|
||||
desc="Running tokenizer on dataset",
|
||||
)
|
||||
if not data_args.streaming:
|
||||
tokenized_datasets = raw_datasets.map(
|
||||
tokenize_function,
|
||||
batched=True,
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
remove_columns=column_names,
|
||||
load_from_cache_file=not data_args.overwrite_cache,
|
||||
desc="Running tokenizer on dataset",
|
||||
)
|
||||
else:
|
||||
tokenized_datasets = raw_datasets.map(
|
||||
tokenize_function,
|
||||
batched=True,
|
||||
remove_columns=column_names,
|
||||
)
|
||||
|
||||
if data_args.block_size is None:
|
||||
block_size = tokenizer.model_max_length
|
||||
@@ -483,13 +502,19 @@ def main():
|
||||
# https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
|
||||
|
||||
with training_args.main_process_first(desc="grouping texts together"):
|
||||
lm_datasets = tokenized_datasets.map(
|
||||
group_texts,
|
||||
batched=True,
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
load_from_cache_file=not data_args.overwrite_cache,
|
||||
desc=f"Grouping texts in chunks of {block_size}",
|
||||
)
|
||||
if not data_args.streaming:
|
||||
lm_datasets = tokenized_datasets.map(
|
||||
group_texts,
|
||||
batched=True,
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
load_from_cache_file=not data_args.overwrite_cache,
|
||||
desc=f"Grouping texts in chunks of {block_size}",
|
||||
)
|
||||
else:
|
||||
lm_datasets = tokenized_datasets.map(
|
||||
group_texts,
|
||||
batched=True,
|
||||
)
|
||||
|
||||
if training_args.do_train:
|
||||
if "train" not in tokenized_datasets:
|
||||
|
||||
Reference in New Issue
Block a user