From 98d88b23f54e5a23e741833f1e973fdf600cc2c5 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Mon, 30 Jan 2023 14:01:35 -0800 Subject: [PATCH] [`run_(clm|mlm).py` examples] add streaming dataset support (#21343) * [run_clm example] add streaming dataset support * unrefactor kwargs * fix * fix * require datasets>=2.0.0 * port to mlm --- examples/pytorch/language-modeling/README.md | 3 + examples/pytorch/language-modeling/run_clm.py | 61 ++++++++++---- examples/pytorch/language-modeling/run_mlm.py | 83 +++++++++++++------ 3 files changed, 104 insertions(+), 43 deletions(-) diff --git a/examples/pytorch/language-modeling/README.md b/examples/pytorch/language-modeling/README.md index 035a6dd696..ff504b5357 100644 --- a/examples/pytorch/language-modeling/README.md +++ b/examples/pytorch/language-modeling/README.md @@ -174,6 +174,9 @@ concatenates all texts and then splits them in blocks of the same length). **Note:** On TPU, you should use the flag `--pad_to_max_length` in conjunction with the `--line_by_line` flag to make sure all your batches have the same length. +## Streaming + +To use the streaming dataset mode which can be very useful for large datasets, add `--streaming` to the command line. This is currently supported by `run_mlm.py` and `run_clm.py`. ## Creating a model on the fly diff --git a/examples/pytorch/language-modeling/run_clm.py b/examples/pytorch/language-modeling/run_clm.py index c4d3008aa1..9a24c55456 100755 --- a/examples/pytorch/language-modeling/run_clm.py +++ b/examples/pytorch/language-modeling/run_clm.py @@ -173,7 +173,7 @@ class DataTrainingArguments: ) }, ) - + streaming: bool = field(default=False, metadata={"help": "Enable streaming mode"}) block_size: Optional[int] = field( default=None, metadata={ @@ -202,6 +202,9 @@ class DataTrainingArguments: ) def __post_init__(self): + if self.streaming: + require_version("datasets>=2.0.0", "The streaming feature requires `datasets>=2.0.0`") + if self.dataset_name is None and self.train_file is None and self.validation_file is None: raise ValueError("Need either a dataset name or a training/validation file.") else: @@ -285,6 +288,7 @@ def main(): data_args.dataset_config_name, cache_dir=model_args.cache_dir, use_auth_token=True if model_args.use_auth_token else None, + streaming=data_args.streaming, ) if "validation" not in raw_datasets.keys(): raw_datasets["validation"] = load_dataset( @@ -293,6 +297,7 @@ def main(): split=f"train[:{data_args.validation_split_percentage}%]", cache_dir=model_args.cache_dir, use_auth_token=True if model_args.use_auth_token else None, + streaming=data_args.streaming, ) raw_datasets["train"] = load_dataset( data_args.dataset_name, @@ -300,6 +305,7 @@ def main(): split=f"train[{data_args.validation_split_percentage}%:]", cache_dir=model_args.cache_dir, use_auth_token=True if model_args.use_auth_token else None, + streaming=data_args.streaming, ) else: data_files = {} @@ -413,9 +419,15 @@ def main(): # Preprocessing the datasets. # First we tokenize all the texts. if training_args.do_train: - column_names = raw_datasets["train"].column_names + if data_args.streaming: + column_names = raw_datasets["train"].features.keys() + else: + column_names = raw_datasets["train"].column_names else: - column_names = raw_datasets["validation"].column_names + if data_args.streaming: + column_names = raw_datasets["validation"].features.keys() + else: + column_names = raw_datasets["validation"].column_names text_column_name = "text" if "text" in column_names else column_names[0] # since this will be pickled to avoid _LazyModule error in Hasher force logger loading before tokenize_function @@ -433,14 +445,21 @@ def main(): return output with training_args.main_process_first(desc="dataset map tokenization"): - tokenized_datasets = raw_datasets.map( - tokenize_function, - batched=True, - num_proc=data_args.preprocessing_num_workers, - remove_columns=column_names, - load_from_cache_file=not data_args.overwrite_cache, - desc="Running tokenizer on dataset", - ) + if not data_args.streaming: + tokenized_datasets = raw_datasets.map( + tokenize_function, + batched=True, + num_proc=data_args.preprocessing_num_workers, + remove_columns=column_names, + load_from_cache_file=not data_args.overwrite_cache, + desc="Running tokenizer on dataset", + ) + else: + tokenized_datasets = raw_datasets.map( + tokenize_function, + batched=True, + remove_columns=column_names, + ) if data_args.block_size is None: block_size = tokenizer.model_max_length @@ -483,13 +502,19 @@ def main(): # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map with training_args.main_process_first(desc="grouping texts together"): - lm_datasets = tokenized_datasets.map( - group_texts, - batched=True, - num_proc=data_args.preprocessing_num_workers, - load_from_cache_file=not data_args.overwrite_cache, - desc=f"Grouping texts in chunks of {block_size}", - ) + if not data_args.streaming: + lm_datasets = tokenized_datasets.map( + group_texts, + batched=True, + num_proc=data_args.preprocessing_num_workers, + load_from_cache_file=not data_args.overwrite_cache, + desc=f"Grouping texts in chunks of {block_size}", + ) + else: + lm_datasets = tokenized_datasets.map( + group_texts, + batched=True, + ) if training_args.do_train: if "train" not in tokenized_datasets: diff --git a/examples/pytorch/language-modeling/run_mlm.py b/examples/pytorch/language-modeling/run_mlm.py index a154b575e5..8cf76896d1 100755 --- a/examples/pytorch/language-modeling/run_mlm.py +++ b/examples/pytorch/language-modeling/run_mlm.py @@ -197,8 +197,12 @@ class DataTrainingArguments: ) }, ) + streaming: bool = field(default=False, metadata={"help": "Enable streaming mode"}) def __post_init__(self): + if self.streaming: + require_version("datasets>=2.0.0", "The streaming feature requires `datasets>=2.0.0`") + if self.dataset_name is None and self.train_file is None and self.validation_file is None: raise ValueError("Need either a dataset name or a training/validation file.") else: @@ -285,6 +289,7 @@ def main(): data_args.dataset_config_name, cache_dir=model_args.cache_dir, use_auth_token=True if model_args.use_auth_token else None, + streaming=data_args.streaming, ) if "validation" not in raw_datasets.keys(): raw_datasets["validation"] = load_dataset( @@ -293,6 +298,7 @@ def main(): split=f"train[:{data_args.validation_split_percentage}%]", cache_dir=model_args.cache_dir, use_auth_token=True if model_args.use_auth_token else None, + streaming=data_args.streaming, ) raw_datasets["train"] = load_dataset( data_args.dataset_name, @@ -300,6 +306,7 @@ def main(): split=f"train[{data_args.validation_split_percentage}%:]", cache_dir=model_args.cache_dir, use_auth_token=True if model_args.use_auth_token else None, + streaming=data_args.streaming, ) else: data_files = {} @@ -398,9 +405,15 @@ def main(): # Preprocessing the datasets. # First we tokenize all the texts. if training_args.do_train: - column_names = raw_datasets["train"].column_names + if data_args.streaming: + column_names = raw_datasets["train"].features.keys() + else: + column_names = raw_datasets["train"].column_names else: - column_names = raw_datasets["validation"].column_names + if data_args.streaming: + column_names = raw_datasets["validation"].features.keys() + else: + column_names = raw_datasets["validation"].column_names text_column_name = "text" if "text" in column_names else column_names[0] if data_args.max_seq_length is None: @@ -439,14 +452,21 @@ def main(): ) with training_args.main_process_first(desc="dataset map tokenization"): - tokenized_datasets = raw_datasets.map( - tokenize_function, - batched=True, - num_proc=data_args.preprocessing_num_workers, - remove_columns=[text_column_name], - load_from_cache_file=not data_args.overwrite_cache, - desc="Running tokenizer on dataset line_by_line", - ) + if not data_args.streaming: + tokenized_datasets = raw_datasets.map( + tokenize_function, + batched=True, + num_proc=data_args.preprocessing_num_workers, + remove_columns=[text_column_name], + load_from_cache_file=not data_args.overwrite_cache, + desc="Running tokenizer on dataset line_by_line", + ) + else: + tokenized_datasets = raw_datasets.map( + tokenize_function, + batched=True, + remove_columns=[text_column_name], + ) else: # Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts. # We use `return_special_tokens_mask=True` because DataCollatorForLanguageModeling (see below) is more @@ -455,14 +475,21 @@ def main(): return tokenizer(examples[text_column_name], return_special_tokens_mask=True) with training_args.main_process_first(desc="dataset map tokenization"): - tokenized_datasets = raw_datasets.map( - tokenize_function, - batched=True, - num_proc=data_args.preprocessing_num_workers, - remove_columns=column_names, - load_from_cache_file=not data_args.overwrite_cache, - desc="Running tokenizer on every text in dataset", - ) + if not data_args.streaming: + tokenized_datasets = raw_datasets.map( + tokenize_function, + batched=True, + num_proc=data_args.preprocessing_num_workers, + remove_columns=column_names, + load_from_cache_file=not data_args.overwrite_cache, + desc="Running tokenizer on every text in dataset", + ) + else: + tokenized_datasets = raw_datasets.map( + tokenize_function, + batched=True, + remove_columns=column_names, + ) # Main data processing function that will concatenate all texts from our dataset and generate chunks of # max_seq_length. @@ -489,13 +516,19 @@ def main(): # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map with training_args.main_process_first(desc="grouping texts together"): - tokenized_datasets = tokenized_datasets.map( - group_texts, - batched=True, - num_proc=data_args.preprocessing_num_workers, - load_from_cache_file=not data_args.overwrite_cache, - desc=f"Grouping texts in chunks of {max_seq_length}", - ) + if not data_args.streaming: + tokenized_datasets = tokenized_datasets.map( + group_texts, + batched=True, + num_proc=data_args.preprocessing_num_workers, + load_from_cache_file=not data_args.overwrite_cache, + desc=f"Grouping texts in chunks of {max_seq_length}", + ) + else: + tokenized_datasets = tokenized_datasets.map( + group_texts, + batched=True, + ) if training_args.do_train: if "train" not in tokenized_datasets: