Add magic method to our TF models to convert datasets with column inference (#17160)
* Add method to call to_tf_dataset() with column inference
* Add test for dataset creation
* Add a default arg for data collator
* Fix test
* Fix call with non-dev version of datasets
* Test correct column removal too
* make fixup
* More tests to make sure we remove unwanted columns
* Fix test to avoid predicting on unbuilt models
* Fix test to avoid predicting on unbuilt models
* Fix test to remove unwanted head mask columns from inputs
* Stop pushing your debug breakpoints to the main repo of the $2bn company you work for
* Skip the test in convnext because no grouped conv support
* Drop bools from the dataset dict
* Make style
* Skip the training test for models whose input dicts don't give us labels
* Skip transformerXL in the test because it doesn't return a simple loss
* Skip TFTapas because of some odd NaN losses
* make style
* make fixup
* Add docstring
* fixup
* Update src/transformers/modeling_tf_utils.py
  Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* Update src/transformers/modeling_tf_utils.py
  Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* Update src/transformers/modeling_tf_utils.py
  Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* Update src/transformers/modeling_tf_utils.py
  Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* Update src/transformers/modeling_tf_utils.py
  Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* Remove breakpoint from tests
* Fix assert, add requires_backends
* Protect tokenizer import with if TYPE_CHECKING
* make fixup
* Add noqa, more fixup
* More rearranging for ~* aesthetics *~
* Adding defaults for shuffle and batch_size to match to_tf_dataset()
* Update src/transformers/modeling_tf_utils.py
  Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -22,7 +22,7 @@ import pickle
|
||||
import re
|
||||
import warnings
|
||||
from collections.abc import Mapping
|
||||
from typing import Dict, List, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import h5py
|
||||
import numpy as np
|
||||
@@ -35,6 +35,7 @@ from tensorflow.python.keras.saving import hdf5_format
|
||||
from huggingface_hub import Repository, list_repo_files
|
||||
from requests import HTTPError
|
||||
|
||||
from . import DataCollatorWithPadding, DefaultDataCollator
|
||||
from .activations_tf import get_tf_activation
|
||||
from .configuration_utils import PretrainedConfig
|
||||
from .dynamic_module_utils import custom_object_save
|
||||
@@ -58,9 +59,14 @@ from .utils import (
|
||||
is_offline_mode,
|
||||
is_remote_url,
|
||||
logging,
|
||||
requires_backends,
|
||||
)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from . import PreTrainedTokenizerBase
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
tf_logger = tf.get_logger()
|
||||
|
||||
@@ -892,6 +898,94 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||
# set it directly, but the user can pass it to fit().
|
||||
return {"epoch": extra_data["epoch"]}
|
||||
|
||||
def prepare_tf_dataset(
    self,
    dataset: "datasets.Dataset",  # noqa:F821
    batch_size: int = 8,
    shuffle: bool = True,
    tokenizer: Optional["PreTrainedTokenizerBase"] = None,
    collate_fn: Optional[Callable] = None,
    collate_fn_args: Optional[Dict[str, Any]] = None,
    drop_remainder: Optional[bool] = None,
    prefetch: bool = True,
):
    """
    Wraps a HuggingFace `datasets.Dataset` as a `tf.data.Dataset` with collation and batching. This method is
    designed to create a "ready-to-use" dataset that can be passed directly to Keras methods like `fit()` without
    further modification. The method will drop columns from the dataset if they don't match input names for the
    model. If you want to specify the column names to return rather than using the names that match this model, we
    recommend using `Dataset.to_tf_dataset()` instead.

    Args:
        dataset (`datasets.Dataset`):
            The dataset to be wrapped as a `tf.data.Dataset`.
        batch_size (`int`, defaults to 8):
            The size of batches to return.
        shuffle (`bool`, defaults to `True`):
            Whether to return samples from the dataset in random order. Usually `True` for training datasets and
            `False` for validation/test datasets.
        tokenizer ([`PreTrainedTokenizerBase`], *optional*):
            A `PreTrainedTokenizer` that will be used to pad samples to create batches. Has no effect if a specific
            `collate_fn` is passed instead.
        collate_fn (`Callable`, *optional*):
            A function that collates samples from the dataset into a single batch. Defaults to
            `DefaultDataCollator` if no `tokenizer` is supplied or `DataCollatorWithPadding` if a `tokenizer` is
            passed.
        collate_fn_args (`Dict[str, Any]`, *optional*):
            A dict of arguments to pass to the `collate_fn` alongside the list of samples.
        drop_remainder (`bool`, *optional*):
            Whether to drop the final batch, if the batch_size does not evenly divide the dataset length. Defaults
            to the same setting as `shuffle`.
        prefetch (`bool`, defaults to `True`):
            Whether to add prefetching to the end of the `tf.data` pipeline. This is almost always beneficial for
            performance, but can be disabled in edge cases.

    Returns:
        `Dataset`: A `tf.data.Dataset` which is ready to pass to the Keras API.

    Raises:
        `TypeError`: If `dataset` is not a `datasets.Dataset`.
    """
    requires_backends(self, ["datasets"])
    import datasets

    # Validate the input before doing any other work so a wrong argument fails fast.
    if not isinstance(dataset, datasets.Dataset):
        raise TypeError("Dataset argument should be a datasets.Dataset!")

    if collate_fn is None:
        if tokenizer is None:
            collate_fn = DefaultDataCollator(return_tensors="tf")
        else:
            collate_fn = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="tf")
    if collate_fn_args is None:
        collate_fn_args = {}
    if drop_remainder is None:
        # Resolve the documented default here instead of relying on how the installed `datasets` version
        # interprets `None`.
        drop_remainder = shuffle

    # Column inference: anything that is neither an argument of `call()` nor a conventional label column is
    # dropped before the collator ever sees it.
    model_inputs = list(inspect.signature(self.call).parameters)
    model_labels = find_labels(self.__class__)
    unwanted_columns = [
        feature
        for feature in dataset.features
        if feature not in model_inputs and feature not in ("label_ids", "label")
    ]
    dataset = dataset.remove_columns(unwanted_columns)

    # The collator may rename or add columns, so the final feature/label split is decided from the collated
    # output signature rather than from the raw dataset features.
    output_signature, _ = dataset._get_output_signature(
        dataset,
        batch_size=None,
        collate_fn=collate_fn,
        collate_fn_args=collate_fn_args,
    )
    output_columns = list(output_signature.keys())
    feature_cols = [col for col in output_columns if col in model_inputs and col not in model_labels]
    label_cols = [col for col in output_columns if col in model_labels]

    return dataset.to_tf_dataset(
        columns=feature_cols,
        label_cols=label_cols,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_remainder=drop_remainder,
        collate_fn=collate_fn,
        collate_fn_args=collate_fn_args,
        prefetch=prefetch,
    )
|
||||
|
||||
def compile(
|
||||
self,
|
||||
optimizer="rmsprop",
|
||||
|
||||
@@ -174,6 +174,13 @@ class TFConvNextModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
def test_attention_outputs(self):
|
||||
pass
|
||||
|
||||
# Only run the shared dataset-conversion test when TensorFlow is installed and can see at least one GPU.
@unittest.skipIf(
    not (is_tf_available() and tf.config.list_physical_devices("GPU")),
    "TF (<=2.8) does not support backprop for grouped convolutions on CPU.",
)
def test_dataset_conversion(self):
    # Nothing model-specific to check here: defer entirely to the common mixin test.
    super().test_dataset_conversion()
|
||||
|
||||
def test_hidden_states_output(self):
|
||||
def check_hidden_states_output(inputs_dict, config, model_class):
|
||||
model = model_class(config)
|
||||
|
||||
@@ -498,6 +498,10 @@ class TFTapasModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)
|
||||
|
||||
@unittest.skip(reason="The default test gets NaN losses with the test-generated inputs")
|
||||
def test_dataset_conversion(self):
|
||||
pass
|
||||
|
||||
|
||||
def prepare_tapas_single_inputs_for_inference():
|
||||
# Here we prepare a single table-question pair to test TAPAS inference on:
|
||||
|
||||
@@ -216,6 +216,10 @@ class TFTransfoXLModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
model = TFTransfoXLModel.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
@unittest.skip(reason="This model doesn't play well with fit() due to not returning a single loss.")
|
||||
def test_dataset_conversion(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_tf
|
||||
class TFTransfoXLModelLanguageGenerationTest(unittest.TestCase):
|
||||
|
||||
@@ -25,6 +25,8 @@ import unittest.mock as mock
|
||||
from importlib import import_module
|
||||
from typing import List, Tuple
|
||||
|
||||
from datasets import Dataset
|
||||
|
||||
from huggingface_hub import delete_repo, login
|
||||
from requests.exceptions import HTTPError
|
||||
from transformers import is_tf_available, is_torch_available
|
||||
@@ -1509,6 +1511,56 @@ class TFModelTesterMixin:
|
||||
observed_main_input_name = list(model_signature.parameters.keys())[1]
|
||||
self.assertEqual(model_class.main_input_name, observed_main_input_name)
|
||||
|
||||
def test_dataset_conversion(self):
    # Checks `prepare_tf_dataset()` for every model class in two passes:
    #   1. inference: the returned batch keeps every real input, drops the unwanted column, and can be fed
    #      straight to the model;
    #   2. training (only for models that accept and are given `labels`): the dataset yields a
    #      (features, labels) pair that `train_on_batch()` accepts.
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
    for model_class in self.all_model_classes:
        model = model_class(config)

        # ---- Pass 1: inference, no labels ----
        tf_inputs_dict = self._prepare_for_class(inputs_dict, model_class, return_labels=False)
        # Head masks and non-tensor values (e.g. bools) are left out of the dataset.
        tf_inputs_dict = {
            key: val
            for key, val in tf_inputs_dict.items()
            if "head_mask" not in key and isinstance(val, tf.Tensor)
        }
        tf_inputs_dict["extra_unwanted_column"] = list(tf_inputs_dict.values())[0]  # Use a random other tensor
        input_dataset = Dataset.from_dict(tf_inputs_dict)
        tf_dataset = model.prepare_tf_dataset(
            input_dataset, batch_size=len(input_dataset), drop_remainder=False, shuffle=False
        )
        test_batch = next(iter(tf_dataset))
        if isinstance(test_batch, tf.Tensor):
            self.assertEqual(len(test_batch), len(input_dataset))  # Assert we didn't lose any data
        else:
            # Assert we discarded the unwanted extra column but kept everything else
            self.assertEqual(len(test_batch), len(input_dataset.features) - 1)
            self.assertNotIn("extra_unwanted_column", test_batch)
            for tensor in test_batch.values():
                self.assertIsInstance(tensor, tf.Tensor)
                self.assertEqual(len(tensor), len(input_dataset))  # Assert we didn't lose any data
        model(test_batch, training=False)

        # ---- Pass 2: training, with labels ----
        if "labels" not in inspect.signature(model_class.call).parameters:
            continue
        tf_inputs_dict = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
        if "labels" not in tf_inputs_dict:
            # This model isn't giving us labels after all, don't try training with it. Move on to the next
            # model class instead of returning, so the remaining classes are still exercised.
            continue
        tf_inputs_dict = {key: val for key, val in tf_inputs_dict.items() if "head_mask" not in key}
        tf_inputs_dict["extra_unwanted_column"] = list(tf_inputs_dict.values())[0]  # Use a random other tensor
        input_dataset = Dataset.from_dict(tf_inputs_dict)
        tf_dataset = model.prepare_tf_dataset(
            input_dataset, batch_size=len(input_dataset), drop_remainder=False, shuffle=False
        )
        test_batch, test_batch_labels = next(iter(tf_dataset))
        self.assertGreater(len(test_batch_labels), 0)  # Assert the labels are present
        # A single-column result comes back as a bare tensor rather than a dict, so count it as one column.
        feature_columns = 1 if isinstance(test_batch, tf.Tensor) else len(test_batch)
        label_columns = 1 if isinstance(test_batch_labels, tf.Tensor) else len(test_batch_labels)
        # Assert we discarded the unwanted extra column but kept everything else
        self.assertEqual(feature_columns + label_columns, len(input_dataset.features) - 1)
        if isinstance(test_batch, dict):
            self.assertNotIn("extra_unwanted_column", test_batch)
        if isinstance(test_batch_labels, dict):
            self.assertNotIn("extra_unwanted_column", test_batch_labels)
        model.compile(optimizer="sgd", run_eagerly=True)
        model.train_on_batch(test_batch, test_batch_labels)
|
||||
|
||||
def _generate_random_bad_tokens(self, num_bad_tokens, model):
|
||||
# special tokens cannot be bad tokens
|
||||
special_tokens = []
|
||||
|
||||
Reference in New Issue
Block a user