Fix bugs in pytorch example run_clm when streaming is enabled (#39286)

This commit is contained in:
Hosein Rezaei
2025-07-15 14:37:28 +01:00
committed by GitHub
parent a989bf8d84
commit d8e05951b8

View File

@@ -31,7 +31,7 @@ from typing import Optional
import datasets
import evaluate
import torch
from datasets import load_dataset
from datasets import IterableDataset, IterableDatasetDict, load_dataset
import transformers
from transformers import (
@@ -225,6 +225,45 @@ class DataTrainingArguments:
assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
def split_streaming_dataset(
full_streaming_dataset,
validation_percentage: int = 5,
) -> IterableDatasetDict:
"""
Splits a streaming dataset into
training and validation IterableDatasets, and supports methods like .map(), .filter(),
.take() and properties like .features on the resulting streams.
Args:
full_streaming_dataset (Dataset): The name of the dataset to load (e.g., "HuggingFaceFW/fineweb").
validation_percentage (int): The proportion of the dataset to be used for validation split.
Returns:
IterableDatasetDict: An IterableDatasetDict containing two IterableDataset objects: (train_stream, validation_stream).
"""
if not (0 < validation_percentage < 100):
raise ValueError(
f"validation_percentage must be between 0 and 100 (exclusive). Passed: {validation_percentage}"
)
def split_generator(is_train: bool):
for i, example in enumerate(full_streaming_dataset):
if is_train:
if i % 100 > validation_percentage:
yield example
else:
if i % 100 < validation_percentage:
yield example
features = full_streaming_dataset.features
train_stream = IterableDataset.from_generator(split_generator, gen_kwargs={"is_train": True}, features=features)
validation_stream = IterableDataset.from_generator(
split_generator, gen_kwargs={"is_train": False}, features=features
)
return IterableDatasetDict({"train": train_stream, "validation": validation_stream})
def main():
# See all possible arguments in src/transformers/training_args.py
# or by passing the --help flag to this script.
@@ -305,6 +344,18 @@ def main():
trust_remote_code=model_args.trust_remote_code,
)
if "validation" not in raw_datasets.keys():
if data_args.streaming:
dataset_stream = load_dataset(
data_args.dataset_name,
data_args.dataset_config_name,
split="train",
cache_dir=model_args.cache_dir,
token=model_args.token,
streaming=data_args.streaming,
trust_remote_code=model_args.trust_remote_code,
)
raw_datasets = split_streaming_dataset(dataset_stream, data_args.validation_split_percentage)
else:
raw_datasets["validation"] = load_dataset(
data_args.dataset_name,
data_args.dataset_config_name,
@@ -347,6 +398,17 @@ def main():
)
# If no validation data is there, validation_split_percentage will be used to divide the dataset.
if "validation" not in raw_datasets.keys():
if data_args.streaming:
dataset_stream = load_dataset(
extension,
data_files=data_files,
split="train",
cache_dir=model_args.cache_dir,
token=model_args.token,
**dataset_args,
)
raw_datasets = split_streaming_dataset(dataset_stream, data_args.validation_split_percentage)
else:
raw_datasets["validation"] = load_dataset(
extension,
data_files=data_files,
@@ -355,6 +417,7 @@ def main():
token=model_args.token,
**dataset_args,
)
raw_datasets["train"] = load_dataset(
extension,
data_files=data_files,
@@ -541,6 +604,9 @@ def main():
raise ValueError("--do_train requires a train dataset")
train_dataset = lm_datasets["train"]
if data_args.max_train_samples is not None:
if data_args.streaming:
train_dataset = train_dataset.take(data_args.max_train_samples)
else:
max_train_samples = min(len(train_dataset), data_args.max_train_samples)
train_dataset = train_dataset.select(range(max_train_samples))
@@ -549,6 +615,9 @@ def main():
raise ValueError("--do_eval requires a validation dataset")
eval_dataset = lm_datasets["validation"]
if data_args.max_eval_samples is not None:
if data_args.streaming:
eval_dataset = eval_dataset.take(data_args.max_eval_samples)
else:
max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
eval_dataset = eval_dataset.select(range(max_eval_samples))
@@ -599,6 +668,9 @@ def main():
max_train_samples = (
data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
)
if data_args.streaming:
metrics["train_samples"] = max_train_samples
else:
metrics["train_samples"] = min(max_train_samples, len(train_dataset))
trainer.log_metrics("train", metrics)
@@ -612,7 +684,11 @@ def main():
metrics = trainer.evaluate()
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
if data_args.streaming:
metrics["eval_samples"] = max_eval_samples
else:
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
try:
perplexity = math.exp(metrics["eval_loss"])
except OverflowError: