From d8e05951b8efd4880acca9a3f291e8b65841a86d Mon Sep 17 00:00:00 2001
From: Hosein Rezaei
Date: Tue, 15 Jul 2025 14:37:28 +0100
Subject: [PATCH] Fix bugs in pytorch example run_clm when streaming is enabled (#39286)

---
 examples/pytorch/language-modeling/run_clm.py | 158 +++++++++++++-----
 1 file changed, 117 insertions(+), 41 deletions(-)

diff --git a/examples/pytorch/language-modeling/run_clm.py b/examples/pytorch/language-modeling/run_clm.py
index dbd0e6e0fa..dad24cd7ef 100755
--- a/examples/pytorch/language-modeling/run_clm.py
+++ b/examples/pytorch/language-modeling/run_clm.py
@@ -31,7 +31,7 @@ from typing import Optional
 import datasets
 import evaluate
 import torch
-from datasets import load_dataset
+from datasets import IterableDataset, IterableDatasetDict, load_dataset
 
 import transformers
 from transformers import (
@@ -225,6 +225,45 @@ class DataTrainingArguments:
                 assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
 
 
+def split_streaming_dataset(
+    full_streaming_dataset,
+    validation_percentage: int = 5,
+) -> IterableDatasetDict:
+    """
+    Splits a streaming dataset into training and validation IterableDatasets by routing, out of every 100
+    consecutive examples, the first `validation_percentage` to validation and all remaining ones to training.
+    The resulting streams support methods like .map(), .filter(), .take() and properties like .features.
+
+    Args:
+        full_streaming_dataset (IterableDataset): The streaming dataset to split; it must expose `.features`.
+        validation_percentage (int): How many of every 100 consecutive examples go to validation (1 to 99).
+
+    Returns:
+        IterableDatasetDict: Two IterableDataset objects stored under the keys "train" and "validation".
+    """
+    if not (0 < validation_percentage < 100):
+        raise ValueError(
+            f"validation_percentage must be between 0 and 100 (exclusive). Passed: {validation_percentage}"
+        )
+
+    def split_generator(is_train: bool):
+        for i, example in enumerate(full_streaming_dataset):
+            if is_train:
+                if i % 100 >= validation_percentage:
+                    yield example
+            else:
+                if i % 100 < validation_percentage:
+                    yield example
+
+    features = full_streaming_dataset.features
+    train_stream = IterableDataset.from_generator(split_generator, gen_kwargs={"is_train": True}, features=features)
+    validation_stream = IterableDataset.from_generator(
+        split_generator, gen_kwargs={"is_train": False}, features=features
+    )
+
+    return IterableDatasetDict({"train": train_stream, "validation": validation_stream})
+
+
 def main():
     # See all possible arguments in src/transformers/training_args.py
     # or by passing the --help flag to this script.
@@ -305,24 +344,36 @@ def main():
             trust_remote_code=model_args.trust_remote_code,
         )
         if "validation" not in raw_datasets.keys():
-            raw_datasets["validation"] = load_dataset(
-                data_args.dataset_name,
-                data_args.dataset_config_name,
-                split=f"train[:{data_args.validation_split_percentage}%]",
-                cache_dir=model_args.cache_dir,
-                token=model_args.token,
-                streaming=data_args.streaming,
-                trust_remote_code=model_args.trust_remote_code,
-            )
-            raw_datasets["train"] = load_dataset(
-                data_args.dataset_name,
-                data_args.dataset_config_name,
-                split=f"train[{data_args.validation_split_percentage}%:]",
-                cache_dir=model_args.cache_dir,
-                token=model_args.token,
-                streaming=data_args.streaming,
-                trust_remote_code=model_args.trust_remote_code,
-            )
+            if data_args.streaming:
+                dataset_stream = load_dataset(
+                    data_args.dataset_name,
+                    data_args.dataset_config_name,
+                    split="train",
+                    cache_dir=model_args.cache_dir,
+                    token=model_args.token,
+                    streaming=data_args.streaming,
+                    trust_remote_code=model_args.trust_remote_code,
+                )
+                raw_datasets = split_streaming_dataset(dataset_stream, data_args.validation_split_percentage)
+            else:
+                raw_datasets["validation"] = load_dataset(
+                    data_args.dataset_name,
+                    data_args.dataset_config_name,
+                    split=f"train[:{data_args.validation_split_percentage}%]",
+                    cache_dir=model_args.cache_dir,
+                    token=model_args.token,
+                    streaming=data_args.streaming,
+                    trust_remote_code=model_args.trust_remote_code,
+                )
+                raw_datasets["train"] = load_dataset(
+                    data_args.dataset_name,
+                    data_args.dataset_config_name,
+                    split=f"train[{data_args.validation_split_percentage}%:]",
+                    cache_dir=model_args.cache_dir,
+                    token=model_args.token,
+                    streaming=data_args.streaming,
+                    trust_remote_code=model_args.trust_remote_code,
+                )
     else:
         data_files = {}
         dataset_args = {}
@@ -347,22 +398,34 @@ def main():
         )
         # If no validation data is there, validation_split_percentage will be used to divide the dataset.
         if "validation" not in raw_datasets.keys():
-            raw_datasets["validation"] = load_dataset(
-                extension,
-                data_files=data_files,
-                split=f"train[:{data_args.validation_split_percentage}%]",
-                cache_dir=model_args.cache_dir,
-                token=model_args.token,
-                **dataset_args,
-            )
-            raw_datasets["train"] = load_dataset(
-                extension,
-                data_files=data_files,
-                split=f"train[{data_args.validation_split_percentage}%:]",
-                cache_dir=model_args.cache_dir,
-                token=model_args.token,
-                **dataset_args,
-            )
+            if data_args.streaming:
+                dataset_stream = load_dataset(
+                    extension,
+                    data_files=data_files,
+                    split="train",
+                    cache_dir=model_args.cache_dir,
+                    token=model_args.token,
+                    **dataset_args,
+                )
+                raw_datasets = split_streaming_dataset(dataset_stream, data_args.validation_split_percentage)
+            else:
+                raw_datasets["validation"] = load_dataset(
+                    extension,
+                    data_files=data_files,
+                    split=f"train[:{data_args.validation_split_percentage}%]",
+                    cache_dir=model_args.cache_dir,
+                    token=model_args.token,
+                    **dataset_args,
+                )
+
+                raw_datasets["train"] = load_dataset(
+                    extension,
+                    data_files=data_files,
+                    split=f"train[{data_args.validation_split_percentage}%:]",
+                    cache_dir=model_args.cache_dir,
+                    token=model_args.token,
+                    **dataset_args,
+                )
 
     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
     # https://huggingface.co/docs/datasets/loading_datasets.
@@ -541,16 +604,22 @@ def main():
             raise ValueError("--do_train requires a train dataset")
         train_dataset = lm_datasets["train"]
         if data_args.max_train_samples is not None:
-            max_train_samples = min(len(train_dataset), data_args.max_train_samples)
-            train_dataset = train_dataset.select(range(max_train_samples))
+            if data_args.streaming:
+                train_dataset = train_dataset.take(data_args.max_train_samples)
+            else:
+                max_train_samples = min(len(train_dataset), data_args.max_train_samples)
+                train_dataset = train_dataset.select(range(max_train_samples))
 
     if training_args.do_eval:
         if "validation" not in tokenized_datasets:
             raise ValueError("--do_eval requires a validation dataset")
         eval_dataset = lm_datasets["validation"]
         if data_args.max_eval_samples is not None:
-            max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
-            eval_dataset = eval_dataset.select(range(max_eval_samples))
+            if data_args.streaming:
+                eval_dataset = eval_dataset.take(data_args.max_eval_samples)
+            else:
+                max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
+                eval_dataset = eval_dataset.select(range(max_eval_samples))
 
         def preprocess_logits_for_metrics(logits, labels):
             if isinstance(logits, tuple):
@@ -599,7 +668,10 @@ def main():
         max_train_samples = (
             data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
         )
-        metrics["train_samples"] = min(max_train_samples, len(train_dataset))
+        if data_args.streaming:
+            metrics["train_samples"] = max_train_samples
+        else:
+            metrics["train_samples"] = min(max_train_samples, len(train_dataset))
 
         trainer.log_metrics("train", metrics)
         trainer.save_metrics("train", metrics)
@@ -612,7 +684,11 @@ def main():
         metrics = trainer.evaluate()
 
         max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
-        metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
+        if data_args.streaming:
+            metrics["eval_samples"] = max_eval_samples
+        else:
+            metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
+
         try:
             perplexity = math.exp(metrics["eval_loss"])
         except OverflowError: