Fix bugs in pytorch example run_clm when streaming is enabled (#39286)
This commit is contained in:
@@ -31,7 +31,7 @@ from typing import Optional
|
|||||||
import datasets
|
import datasets
|
||||||
import evaluate
|
import evaluate
|
||||||
import torch
|
import torch
|
||||||
from datasets import load_dataset
|
from datasets import IterableDataset, IterableDatasetDict, load_dataset
|
||||||
|
|
||||||
import transformers
|
import transformers
|
||||||
from transformers import (
|
from transformers import (
|
||||||
@@ -225,6 +225,45 @@ class DataTrainingArguments:
|
|||||||
assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
|
assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
|
||||||
|
|
||||||
|
|
||||||
|
def split_streaming_dataset(
|
||||||
|
full_streaming_dataset,
|
||||||
|
validation_percentage: int = 5,
|
||||||
|
) -> IterableDatasetDict:
|
||||||
|
"""
|
||||||
|
Splits a streaming dataset into
|
||||||
|
training and validation IterableDatasets, and supports methods like .map(), .filter(),
|
||||||
|
.take() and properties like .features on the resulting streams.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
full_streaming_dataset (Dataset): The name of the dataset to load (e.g., "HuggingFaceFW/fineweb").
|
||||||
|
validation_percentage (int): The proportion of the dataset to be used for validation split.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
IterableDatasetDict: An IterableDatasetDict containing two IterableDataset objects: (train_stream, validation_stream).
|
||||||
|
"""
|
||||||
|
if not (0 < validation_percentage < 100):
|
||||||
|
raise ValueError(
|
||||||
|
f"validation_percentage must be between 0 and 100 (exclusive). Passed: {validation_percentage}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def split_generator(is_train: bool):
|
||||||
|
for i, example in enumerate(full_streaming_dataset):
|
||||||
|
if is_train:
|
||||||
|
if i % 100 > validation_percentage:
|
||||||
|
yield example
|
||||||
|
else:
|
||||||
|
if i % 100 < validation_percentage:
|
||||||
|
yield example
|
||||||
|
|
||||||
|
features = full_streaming_dataset.features
|
||||||
|
train_stream = IterableDataset.from_generator(split_generator, gen_kwargs={"is_train": True}, features=features)
|
||||||
|
validation_stream = IterableDataset.from_generator(
|
||||||
|
split_generator, gen_kwargs={"is_train": False}, features=features
|
||||||
|
)
|
||||||
|
|
||||||
|
return IterableDatasetDict({"train": train_stream, "validation": validation_stream})
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
# See all possible arguments in src/transformers/training_args.py
|
# See all possible arguments in src/transformers/training_args.py
|
||||||
# or by passing the --help flag to this script.
|
# or by passing the --help flag to this script.
|
||||||
@@ -305,6 +344,18 @@ def main():
|
|||||||
trust_remote_code=model_args.trust_remote_code,
|
trust_remote_code=model_args.trust_remote_code,
|
||||||
)
|
)
|
||||||
if "validation" not in raw_datasets.keys():
|
if "validation" not in raw_datasets.keys():
|
||||||
|
if data_args.streaming:
|
||||||
|
dataset_stream = load_dataset(
|
||||||
|
data_args.dataset_name,
|
||||||
|
data_args.dataset_config_name,
|
||||||
|
split="train",
|
||||||
|
cache_dir=model_args.cache_dir,
|
||||||
|
token=model_args.token,
|
||||||
|
streaming=data_args.streaming,
|
||||||
|
trust_remote_code=model_args.trust_remote_code,
|
||||||
|
)
|
||||||
|
raw_datasets = split_streaming_dataset(dataset_stream, data_args.validation_split_percentage)
|
||||||
|
else:
|
||||||
raw_datasets["validation"] = load_dataset(
|
raw_datasets["validation"] = load_dataset(
|
||||||
data_args.dataset_name,
|
data_args.dataset_name,
|
||||||
data_args.dataset_config_name,
|
data_args.dataset_config_name,
|
||||||
@@ -347,6 +398,17 @@ def main():
|
|||||||
)
|
)
|
||||||
# If no validation data is there, validation_split_percentage will be used to divide the dataset.
|
# If no validation data is there, validation_split_percentage will be used to divide the dataset.
|
||||||
if "validation" not in raw_datasets.keys():
|
if "validation" not in raw_datasets.keys():
|
||||||
|
if data_args.streaming:
|
||||||
|
dataset_stream = load_dataset(
|
||||||
|
extension,
|
||||||
|
data_files=data_files,
|
||||||
|
split="train",
|
||||||
|
cache_dir=model_args.cache_dir,
|
||||||
|
token=model_args.token,
|
||||||
|
**dataset_args,
|
||||||
|
)
|
||||||
|
raw_datasets = split_streaming_dataset(dataset_stream, data_args.validation_split_percentage)
|
||||||
|
else:
|
||||||
raw_datasets["validation"] = load_dataset(
|
raw_datasets["validation"] = load_dataset(
|
||||||
extension,
|
extension,
|
||||||
data_files=data_files,
|
data_files=data_files,
|
||||||
@@ -355,6 +417,7 @@ def main():
|
|||||||
token=model_args.token,
|
token=model_args.token,
|
||||||
**dataset_args,
|
**dataset_args,
|
||||||
)
|
)
|
||||||
|
|
||||||
raw_datasets["train"] = load_dataset(
|
raw_datasets["train"] = load_dataset(
|
||||||
extension,
|
extension,
|
||||||
data_files=data_files,
|
data_files=data_files,
|
||||||
@@ -541,6 +604,9 @@ def main():
|
|||||||
raise ValueError("--do_train requires a train dataset")
|
raise ValueError("--do_train requires a train dataset")
|
||||||
train_dataset = lm_datasets["train"]
|
train_dataset = lm_datasets["train"]
|
||||||
if data_args.max_train_samples is not None:
|
if data_args.max_train_samples is not None:
|
||||||
|
if data_args.streaming:
|
||||||
|
train_dataset = train_dataset.take(data_args.max_train_samples)
|
||||||
|
else:
|
||||||
max_train_samples = min(len(train_dataset), data_args.max_train_samples)
|
max_train_samples = min(len(train_dataset), data_args.max_train_samples)
|
||||||
train_dataset = train_dataset.select(range(max_train_samples))
|
train_dataset = train_dataset.select(range(max_train_samples))
|
||||||
|
|
||||||
@@ -549,6 +615,9 @@ def main():
|
|||||||
raise ValueError("--do_eval requires a validation dataset")
|
raise ValueError("--do_eval requires a validation dataset")
|
||||||
eval_dataset = lm_datasets["validation"]
|
eval_dataset = lm_datasets["validation"]
|
||||||
if data_args.max_eval_samples is not None:
|
if data_args.max_eval_samples is not None:
|
||||||
|
if data_args.streaming:
|
||||||
|
eval_dataset = eval_dataset.take(data_args.max_eval_samples)
|
||||||
|
else:
|
||||||
max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
|
max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
|
||||||
eval_dataset = eval_dataset.select(range(max_eval_samples))
|
eval_dataset = eval_dataset.select(range(max_eval_samples))
|
||||||
|
|
||||||
@@ -599,6 +668,9 @@ def main():
|
|||||||
max_train_samples = (
|
max_train_samples = (
|
||||||
data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
|
data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
|
||||||
)
|
)
|
||||||
|
if data_args.streaming:
|
||||||
|
metrics["train_samples"] = max_train_samples
|
||||||
|
else:
|
||||||
metrics["train_samples"] = min(max_train_samples, len(train_dataset))
|
metrics["train_samples"] = min(max_train_samples, len(train_dataset))
|
||||||
|
|
||||||
trainer.log_metrics("train", metrics)
|
trainer.log_metrics("train", metrics)
|
||||||
@@ -612,7 +684,11 @@ def main():
|
|||||||
metrics = trainer.evaluate()
|
metrics = trainer.evaluate()
|
||||||
|
|
||||||
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
|
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
|
||||||
|
if data_args.streaming:
|
||||||
|
metrics["eval_samples"] = max_eval_samples
|
||||||
|
else:
|
||||||
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
|
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
perplexity = math.exp(metrics["eval_loss"])
|
perplexity = math.exp(metrics["eval_loss"])
|
||||||
except OverflowError:
|
except OverflowError:
|
||||||
|
|||||||
Reference in New Issue
Block a user