Fix bugs in pytorch example run_clm when streaming is enabled (#39286)
This commit is contained in:
@@ -31,7 +31,7 @@ from typing import Optional
|
||||
import datasets
|
||||
import evaluate
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from datasets import IterableDataset, IterableDatasetDict, load_dataset
|
||||
|
||||
import transformers
|
||||
from transformers import (
|
||||
@@ -225,6 +225,45 @@ class DataTrainingArguments:
|
||||
assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
|
||||
|
||||
|
||||
def split_streaming_dataset(
|
||||
full_streaming_dataset,
|
||||
validation_percentage: int = 5,
|
||||
) -> IterableDatasetDict:
|
||||
"""
|
||||
Splits a streaming dataset into
|
||||
training and validation IterableDatasets, and supports methods like .map(), .filter(),
|
||||
.take() and properties like .features on the resulting streams.
|
||||
|
||||
Args:
|
||||
full_streaming_dataset (Dataset): The name of the dataset to load (e.g., "HuggingFaceFW/fineweb").
|
||||
validation_percentage (int): The proportion of the dataset to be used for validation split.
|
||||
|
||||
Returns:
|
||||
IterableDatasetDict: An IterableDatasetDict containing two IterableDataset objects: (train_stream, validation_stream).
|
||||
"""
|
||||
if not (0 < validation_percentage < 100):
|
||||
raise ValueError(
|
||||
f"validation_percentage must be between 0 and 100 (exclusive). Passed: {validation_percentage}"
|
||||
)
|
||||
|
||||
def split_generator(is_train: bool):
|
||||
for i, example in enumerate(full_streaming_dataset):
|
||||
if is_train:
|
||||
if i % 100 > validation_percentage:
|
||||
yield example
|
||||
else:
|
||||
if i % 100 < validation_percentage:
|
||||
yield example
|
||||
|
||||
features = full_streaming_dataset.features
|
||||
train_stream = IterableDataset.from_generator(split_generator, gen_kwargs={"is_train": True}, features=features)
|
||||
validation_stream = IterableDataset.from_generator(
|
||||
split_generator, gen_kwargs={"is_train": False}, features=features
|
||||
)
|
||||
|
||||
return IterableDatasetDict({"train": train_stream, "validation": validation_stream})
|
||||
|
||||
|
||||
def main():
|
||||
# See all possible arguments in src/transformers/training_args.py
|
||||
# or by passing the --help flag to this script.
|
||||
@@ -305,24 +344,36 @@ def main():
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
)
|
||||
if "validation" not in raw_datasets.keys():
|
||||
raw_datasets["validation"] = load_dataset(
|
||||
data_args.dataset_name,
|
||||
data_args.dataset_config_name,
|
||||
split=f"train[:{data_args.validation_split_percentage}%]",
|
||||
cache_dir=model_args.cache_dir,
|
||||
token=model_args.token,
|
||||
streaming=data_args.streaming,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
)
|
||||
raw_datasets["train"] = load_dataset(
|
||||
data_args.dataset_name,
|
||||
data_args.dataset_config_name,
|
||||
split=f"train[{data_args.validation_split_percentage}%:]",
|
||||
cache_dir=model_args.cache_dir,
|
||||
token=model_args.token,
|
||||
streaming=data_args.streaming,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
)
|
||||
if data_args.streaming:
|
||||
dataset_stream = load_dataset(
|
||||
data_args.dataset_name,
|
||||
data_args.dataset_config_name,
|
||||
split="train",
|
||||
cache_dir=model_args.cache_dir,
|
||||
token=model_args.token,
|
||||
streaming=data_args.streaming,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
)
|
||||
raw_datasets = split_streaming_dataset(dataset_stream, data_args.validation_split_percentage)
|
||||
else:
|
||||
raw_datasets["validation"] = load_dataset(
|
||||
data_args.dataset_name,
|
||||
data_args.dataset_config_name,
|
||||
split=f"train[:{data_args.validation_split_percentage}%]",
|
||||
cache_dir=model_args.cache_dir,
|
||||
token=model_args.token,
|
||||
streaming=data_args.streaming,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
)
|
||||
raw_datasets["train"] = load_dataset(
|
||||
data_args.dataset_name,
|
||||
data_args.dataset_config_name,
|
||||
split=f"train[{data_args.validation_split_percentage}%:]",
|
||||
cache_dir=model_args.cache_dir,
|
||||
token=model_args.token,
|
||||
streaming=data_args.streaming,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
)
|
||||
else:
|
||||
data_files = {}
|
||||
dataset_args = {}
|
||||
@@ -347,22 +398,34 @@ def main():
|
||||
)
|
||||
# If no validation data is there, validation_split_percentage will be used to divide the dataset.
|
||||
if "validation" not in raw_datasets.keys():
|
||||
raw_datasets["validation"] = load_dataset(
|
||||
extension,
|
||||
data_files=data_files,
|
||||
split=f"train[:{data_args.validation_split_percentage}%]",
|
||||
cache_dir=model_args.cache_dir,
|
||||
token=model_args.token,
|
||||
**dataset_args,
|
||||
)
|
||||
raw_datasets["train"] = load_dataset(
|
||||
extension,
|
||||
data_files=data_files,
|
||||
split=f"train[{data_args.validation_split_percentage}%:]",
|
||||
cache_dir=model_args.cache_dir,
|
||||
token=model_args.token,
|
||||
**dataset_args,
|
||||
)
|
||||
if data_args.streaming:
|
||||
dataset_stream = load_dataset(
|
||||
extension,
|
||||
data_files=data_files,
|
||||
split="train",
|
||||
cache_dir=model_args.cache_dir,
|
||||
token=model_args.token,
|
||||
**dataset_args,
|
||||
)
|
||||
raw_datasets = split_streaming_dataset(dataset_stream, data_args.validation_split_percentage)
|
||||
else:
|
||||
raw_datasets["validation"] = load_dataset(
|
||||
extension,
|
||||
data_files=data_files,
|
||||
split=f"train[:{data_args.validation_split_percentage}%]",
|
||||
cache_dir=model_args.cache_dir,
|
||||
token=model_args.token,
|
||||
**dataset_args,
|
||||
)
|
||||
|
||||
raw_datasets["train"] = load_dataset(
|
||||
extension,
|
||||
data_files=data_files,
|
||||
split=f"train[{data_args.validation_split_percentage}%:]",
|
||||
cache_dir=model_args.cache_dir,
|
||||
token=model_args.token,
|
||||
**dataset_args,
|
||||
)
|
||||
|
||||
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
|
||||
# https://huggingface.co/docs/datasets/loading_datasets.
|
||||
@@ -541,16 +604,22 @@ def main():
|
||||
raise ValueError("--do_train requires a train dataset")
|
||||
train_dataset = lm_datasets["train"]
|
||||
if data_args.max_train_samples is not None:
|
||||
max_train_samples = min(len(train_dataset), data_args.max_train_samples)
|
||||
train_dataset = train_dataset.select(range(max_train_samples))
|
||||
if data_args.streaming:
|
||||
train_dataset = train_dataset.take(data_args.max_train_samples)
|
||||
else:
|
||||
max_train_samples = min(len(train_dataset), data_args.max_train_samples)
|
||||
train_dataset = train_dataset.select(range(max_train_samples))
|
||||
|
||||
if training_args.do_eval:
|
||||
if "validation" not in tokenized_datasets:
|
||||
raise ValueError("--do_eval requires a validation dataset")
|
||||
eval_dataset = lm_datasets["validation"]
|
||||
if data_args.max_eval_samples is not None:
|
||||
max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
|
||||
eval_dataset = eval_dataset.select(range(max_eval_samples))
|
||||
if data_args.streaming:
|
||||
eval_dataset = eval_dataset.take(data_args.max_eval_samples)
|
||||
else:
|
||||
max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
|
||||
eval_dataset = eval_dataset.select(range(max_eval_samples))
|
||||
|
||||
def preprocess_logits_for_metrics(logits, labels):
|
||||
if isinstance(logits, tuple):
|
||||
@@ -599,7 +668,10 @@ def main():
|
||||
max_train_samples = (
|
||||
data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
|
||||
)
|
||||
metrics["train_samples"] = min(max_train_samples, len(train_dataset))
|
||||
if data_args.streaming:
|
||||
metrics["train_samples"] = max_train_samples
|
||||
else:
|
||||
metrics["train_samples"] = min(max_train_samples, len(train_dataset))
|
||||
|
||||
trainer.log_metrics("train", metrics)
|
||||
trainer.save_metrics("train", metrics)
|
||||
@@ -612,7 +684,11 @@ def main():
|
||||
metrics = trainer.evaluate()
|
||||
|
||||
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
|
||||
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
|
||||
if data_args.streaming:
|
||||
metrics["eval_samples"] = max_eval_samples
|
||||
else:
|
||||
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
|
||||
|
||||
try:
|
||||
perplexity = math.exp(metrics["eval_loss"])
|
||||
except OverflowError:
|
||||
|
||||
Reference in New Issue
Block a user