Fix bugs in pytorch example run_clm when streaming is enabled (#39286)

This commit is contained in:
Hosein Rezaei
2025-07-15 14:37:28 +01:00
committed by GitHub
parent a989bf8d84
commit d8e05951b8

View File

@@ -31,7 +31,7 @@ from typing import Optional
import datasets import datasets
import evaluate import evaluate
import torch import torch
from datasets import load_dataset from datasets import IterableDataset, IterableDatasetDict, load_dataset
import transformers import transformers
from transformers import ( from transformers import (
@@ -225,6 +225,45 @@ class DataTrainingArguments:
assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file." assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
def split_streaming_dataset(
full_streaming_dataset,
validation_percentage: int = 5,
) -> IterableDatasetDict:
"""
Splits a streaming dataset into
training and validation IterableDatasets, and supports methods like .map(), .filter(),
.take() and properties like .features on the resulting streams.
Args:
full_streaming_dataset (Dataset): The name of the dataset to load (e.g., "HuggingFaceFW/fineweb").
validation_percentage (int): The proportion of the dataset to be used for validation split.
Returns:
IterableDatasetDict: An IterableDatasetDict containing two IterableDataset objects: (train_stream, validation_stream).
"""
if not (0 < validation_percentage < 100):
raise ValueError(
f"validation_percentage must be between 0 and 100 (exclusive). Passed: {validation_percentage}"
)
def split_generator(is_train: bool):
for i, example in enumerate(full_streaming_dataset):
if is_train:
if i % 100 > validation_percentage:
yield example
else:
if i % 100 < validation_percentage:
yield example
features = full_streaming_dataset.features
train_stream = IterableDataset.from_generator(split_generator, gen_kwargs={"is_train": True}, features=features)
validation_stream = IterableDataset.from_generator(
split_generator, gen_kwargs={"is_train": False}, features=features
)
return IterableDatasetDict({"train": train_stream, "validation": validation_stream})
def main(): def main():
# See all possible arguments in src/transformers/training_args.py # See all possible arguments in src/transformers/training_args.py
# or by passing the --help flag to this script. # or by passing the --help flag to this script.
@@ -305,24 +344,36 @@ def main():
trust_remote_code=model_args.trust_remote_code, trust_remote_code=model_args.trust_remote_code,
) )
if "validation" not in raw_datasets.keys(): if "validation" not in raw_datasets.keys():
raw_datasets["validation"] = load_dataset( if data_args.streaming:
data_args.dataset_name, dataset_stream = load_dataset(
data_args.dataset_config_name, data_args.dataset_name,
split=f"train[:{data_args.validation_split_percentage}%]", data_args.dataset_config_name,
cache_dir=model_args.cache_dir, split="train",
token=model_args.token, cache_dir=model_args.cache_dir,
streaming=data_args.streaming, token=model_args.token,
trust_remote_code=model_args.trust_remote_code, streaming=data_args.streaming,
) trust_remote_code=model_args.trust_remote_code,
raw_datasets["train"] = load_dataset( )
data_args.dataset_name, raw_datasets = split_streaming_dataset(dataset_stream, data_args.validation_split_percentage)
data_args.dataset_config_name, else:
split=f"train[{data_args.validation_split_percentage}%:]", raw_datasets["validation"] = load_dataset(
cache_dir=model_args.cache_dir, data_args.dataset_name,
token=model_args.token, data_args.dataset_config_name,
streaming=data_args.streaming, split=f"train[:{data_args.validation_split_percentage}%]",
trust_remote_code=model_args.trust_remote_code, cache_dir=model_args.cache_dir,
) token=model_args.token,
streaming=data_args.streaming,
trust_remote_code=model_args.trust_remote_code,
)
raw_datasets["train"] = load_dataset(
data_args.dataset_name,
data_args.dataset_config_name,
split=f"train[{data_args.validation_split_percentage}%:]",
cache_dir=model_args.cache_dir,
token=model_args.token,
streaming=data_args.streaming,
trust_remote_code=model_args.trust_remote_code,
)
else: else:
data_files = {} data_files = {}
dataset_args = {} dataset_args = {}
@@ -347,22 +398,34 @@ def main():
) )
# If no validation data is there, validation_split_percentage will be used to divide the dataset. # If no validation data is there, validation_split_percentage will be used to divide the dataset.
if "validation" not in raw_datasets.keys(): if "validation" not in raw_datasets.keys():
raw_datasets["validation"] = load_dataset( if data_args.streaming:
extension, dataset_stream = load_dataset(
data_files=data_files, extension,
split=f"train[:{data_args.validation_split_percentage}%]", data_files=data_files,
cache_dir=model_args.cache_dir, split="train",
token=model_args.token, cache_dir=model_args.cache_dir,
**dataset_args, token=model_args.token,
) **dataset_args,
raw_datasets["train"] = load_dataset( )
extension, raw_datasets = split_streaming_dataset(dataset_stream, data_args.validation_split_percentage)
data_files=data_files, else:
split=f"train[{data_args.validation_split_percentage}%:]", raw_datasets["validation"] = load_dataset(
cache_dir=model_args.cache_dir, extension,
token=model_args.token, data_files=data_files,
**dataset_args, split=f"train[:{data_args.validation_split_percentage}%]",
) cache_dir=model_args.cache_dir,
token=model_args.token,
**dataset_args,
)
raw_datasets["train"] = load_dataset(
extension,
data_files=data_files,
split=f"train[{data_args.validation_split_percentage}%:]",
cache_dir=model_args.cache_dir,
token=model_args.token,
**dataset_args,
)
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
# https://huggingface.co/docs/datasets/loading_datasets. # https://huggingface.co/docs/datasets/loading_datasets.
@@ -541,16 +604,22 @@ def main():
raise ValueError("--do_train requires a train dataset") raise ValueError("--do_train requires a train dataset")
train_dataset = lm_datasets["train"] train_dataset = lm_datasets["train"]
if data_args.max_train_samples is not None: if data_args.max_train_samples is not None:
max_train_samples = min(len(train_dataset), data_args.max_train_samples) if data_args.streaming:
train_dataset = train_dataset.select(range(max_train_samples)) train_dataset = train_dataset.take(data_args.max_train_samples)
else:
max_train_samples = min(len(train_dataset), data_args.max_train_samples)
train_dataset = train_dataset.select(range(max_train_samples))
if training_args.do_eval: if training_args.do_eval:
if "validation" not in tokenized_datasets: if "validation" not in tokenized_datasets:
raise ValueError("--do_eval requires a validation dataset") raise ValueError("--do_eval requires a validation dataset")
eval_dataset = lm_datasets["validation"] eval_dataset = lm_datasets["validation"]
if data_args.max_eval_samples is not None: if data_args.max_eval_samples is not None:
max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples) if data_args.streaming:
eval_dataset = eval_dataset.select(range(max_eval_samples)) eval_dataset = eval_dataset.take(data_args.max_eval_samples)
else:
max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
eval_dataset = eval_dataset.select(range(max_eval_samples))
def preprocess_logits_for_metrics(logits, labels): def preprocess_logits_for_metrics(logits, labels):
if isinstance(logits, tuple): if isinstance(logits, tuple):
@@ -599,7 +668,10 @@ def main():
max_train_samples = ( max_train_samples = (
data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
) )
metrics["train_samples"] = min(max_train_samples, len(train_dataset)) if data_args.streaming:
metrics["train_samples"] = max_train_samples
else:
metrics["train_samples"] = min(max_train_samples, len(train_dataset))
trainer.log_metrics("train", metrics) trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics) trainer.save_metrics("train", metrics)
@@ -612,7 +684,11 @@ def main():
metrics = trainer.evaluate() metrics = trainer.evaluate()
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) if data_args.streaming:
metrics["eval_samples"] = max_eval_samples
else:
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
try: try:
perplexity = math.exp(metrics["eval_loss"]) perplexity = math.exp(metrics["eval_loss"])
except OverflowError: except OverflowError: