Add magic method to our TF models to convert datasets with column inference (#17160)
* Add method to call to_tf_dataset() with column inference * Add test for dataset creation * Add a default arg for data collator * Fix test * Fix call with non-dev version of datasets * Test correct column removal too * make fixup * More tests to make sure we remove unwanted columns * Fix test to avoid predicting on unbuilt models * Fix test to avoid predicting on unbuilt models * Fix test to remove unwanted head mask columns from inputs * Stop pushing your debug breakpoints to the main repo of the $2bn company you work for * Skip the test in convnext because no grouped conv support * Drop bools from the dataset dict * Make style * Skip the training test for models whose input dicts don't give us labels * Skip transformerXL in the test because it doesn't return a simple loss * Skip TFTapas because of some odd NaN losses * make style * make fixup * Add docstring * fixup * Update src/transformers/modeling_tf_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/modeling_tf_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/modeling_tf_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/modeling_tf_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/modeling_tf_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Remove breakpoint from tests * Fix assert, add requires_backends * Protect tokenizer import with if TYPE_CHECKING * make fixup * Add noqa, more fixup * More rearranging for ~* aesthetics *~ * Adding defaults for shuffle and batch_size to match to_tf_dataset() * Update src/transformers/modeling_tf_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -22,7 +22,7 @@ import pickle
|
|||||||
import re
|
import re
|
||||||
import warnings
|
import warnings
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from typing import Dict, List, Optional, Union
|
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
|
||||||
|
|
||||||
import h5py
|
import h5py
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -35,6 +35,7 @@ from tensorflow.python.keras.saving import hdf5_format
|
|||||||
from huggingface_hub import Repository, list_repo_files
|
from huggingface_hub import Repository, list_repo_files
|
||||||
from requests import HTTPError
|
from requests import HTTPError
|
||||||
|
|
||||||
|
from . import DataCollatorWithPadding, DefaultDataCollator
|
||||||
from .activations_tf import get_tf_activation
|
from .activations_tf import get_tf_activation
|
||||||
from .configuration_utils import PretrainedConfig
|
from .configuration_utils import PretrainedConfig
|
||||||
from .dynamic_module_utils import custom_object_save
|
from .dynamic_module_utils import custom_object_save
|
||||||
@@ -58,9 +59,14 @@ from .utils import (
|
|||||||
is_offline_mode,
|
is_offline_mode,
|
||||||
is_remote_url,
|
is_remote_url,
|
||||||
logging,
|
logging,
|
||||||
|
requires_backends,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from . import PreTrainedTokenizerBase
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
tf_logger = tf.get_logger()
|
tf_logger = tf.get_logger()
|
||||||
|
|
||||||
@@ -892,6 +898,94 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
# set it directly, but the user can pass it to fit().
|
# set it directly, but the user can pass it to fit().
|
||||||
return {"epoch": extra_data["epoch"]}
|
return {"epoch": extra_data["epoch"]}
|
||||||
|
|
||||||
|
def prepare_tf_dataset(
|
||||||
|
self,
|
||||||
|
dataset: "datasets.Dataset", # noqa:F821
|
||||||
|
batch_size: int = 8,
|
||||||
|
shuffle: bool = True,
|
||||||
|
tokenizer: Optional["PreTrainedTokenizerBase"] = None,
|
||||||
|
collate_fn: Optional[Callable] = None,
|
||||||
|
collate_fn_args: Optional[Dict[str, Any]] = None,
|
||||||
|
drop_remainder: Optional[bool] = None,
|
||||||
|
prefetch: bool = True,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Wraps a HuggingFace `datasets.Dataset` as a `tf.data.Dataset` with collation and batching. This method is
|
||||||
|
designed to create a "ready-to-use" dataset that can be passed directly to Keras methods like `fit()` without
|
||||||
|
further modification. The method will drop columns from the dataset if they don't match input names for the
|
||||||
|
model. If you want to specify the column names to return rather than using the names that match this model, we
|
||||||
|
recommend using `Dataset.to_tf_dataset()` instead.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset (`Any`):
|
||||||
|
A `datasets.Dataset` to be wrapped as a `tf.data.Dataset`.
|
||||||
|
batch_size (`int`, defaults to 8):
|
||||||
|
The size of batches to return.
|
||||||
|
shuffle (`bool`, defaults to `True`):
|
||||||
|
Whether to return samples from the dataset in random order. Usually `True` for training datasets and
|
||||||
|
`False` for validation/test datasets.
|
||||||
|
tokenizer ([`PreTrainedTokenizerBase`], *optional*):
|
||||||
|
A `PreTrainedTokenizer` that will be used to pad samples to create batches. Has no effect if a specific
|
||||||
|
`collate_fn` is passed instead.
|
||||||
|
collate_fn (`Callable`, *optional*):
|
||||||
|
A function that collates samples from the dataset into a single batch. Defaults to
|
||||||
|
`DefaultDataCollator` if no `tokenizer` is supplied or `DataCollatorWithPadding` if a `tokenizer` is
|
||||||
|
passed.
|
||||||
|
collate_fn_args (`Dict[str, Any]`, *optional*):
|
||||||
|
A dict of arguments to pass to the `collate_fn` alongside the list of samples.
|
||||||
|
drop_remainder (`bool`, *optional*):
|
||||||
|
Whether to drop the final batch, if the batch_size does not evenly divide the dataset length. Defaults
|
||||||
|
to the same setting as `shuffle`.
|
||||||
|
prefetch (`bool`, defaults to `True`):
|
||||||
|
Whether to add prefetching to the end of the `tf.data` pipeline. This is almost always beneficial for
|
||||||
|
performance, but can be disabled in edge cases.
|
||||||
|
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`Dataset`: A `tf.data.Dataset` which is ready to pass to the Keras API.
|
||||||
|
"""
|
||||||
|
requires_backends(self, ["datasets"])
|
||||||
|
import datasets
|
||||||
|
|
||||||
|
if collate_fn is None:
|
||||||
|
if tokenizer is None:
|
||||||
|
collate_fn = DefaultDataCollator(return_tensors="tf")
|
||||||
|
else:
|
||||||
|
collate_fn = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="tf")
|
||||||
|
if collate_fn_args is None:
|
||||||
|
collate_fn_args = dict()
|
||||||
|
|
||||||
|
if not isinstance(dataset, datasets.Dataset):
|
||||||
|
raise TypeError("Dataset argument should be a datasets.Dataset!")
|
||||||
|
model_inputs = list(dict(inspect.signature(self.call).parameters).keys())
|
||||||
|
model_labels = find_labels(self.__class__)
|
||||||
|
unwanted_columns = [
|
||||||
|
feature
|
||||||
|
for feature in dataset.features
|
||||||
|
if feature not in model_inputs and feature not in ("label_ids", "label")
|
||||||
|
]
|
||||||
|
dataset = dataset.remove_columns(unwanted_columns)
|
||||||
|
output_signature, _ = dataset._get_output_signature(
|
||||||
|
dataset,
|
||||||
|
batch_size=None,
|
||||||
|
collate_fn=collate_fn,
|
||||||
|
collate_fn_args=collate_fn_args,
|
||||||
|
)
|
||||||
|
output_columns = list(output_signature.keys())
|
||||||
|
feature_cols = [col for col in output_columns if col in model_inputs and col not in model_labels]
|
||||||
|
label_cols = [col for col in output_columns if col in model_labels]
|
||||||
|
tf_dataset = dataset.to_tf_dataset(
|
||||||
|
columns=feature_cols,
|
||||||
|
label_cols=label_cols,
|
||||||
|
batch_size=batch_size,
|
||||||
|
shuffle=shuffle,
|
||||||
|
drop_remainder=drop_remainder,
|
||||||
|
collate_fn=collate_fn,
|
||||||
|
collate_fn_args=collate_fn_args,
|
||||||
|
prefetch=prefetch,
|
||||||
|
)
|
||||||
|
return tf_dataset
|
||||||
|
|
||||||
def compile(
|
def compile(
|
||||||
self,
|
self,
|
||||||
optimizer="rmsprop",
|
optimizer="rmsprop",
|
||||||
|
|||||||
@@ -174,6 +174,13 @@ class TFConvNextModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
def test_attention_outputs(self):
|
def test_attention_outputs(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@unittest.skipIf(
|
||||||
|
not is_tf_available() or len(tf.config.list_physical_devices("GPU")) == 0,
|
||||||
|
reason="TF (<=2.8) does not support backprop for grouped convolutions on CPU.",
|
||||||
|
)
|
||||||
|
def test_dataset_conversion(self):
|
||||||
|
super().test_dataset_conversion()
|
||||||
|
|
||||||
def test_hidden_states_output(self):
|
def test_hidden_states_output(self):
|
||||||
def check_hidden_states_output(inputs_dict, config, model_class):
|
def check_hidden_states_output(inputs_dict, config, model_class):
|
||||||
model = model_class(config)
|
model = model_class(config)
|
||||||
|
|||||||
@@ -498,6 +498,10 @@ class TFTapasModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)
|
self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)
|
||||||
|
|
||||||
|
@unittest.skip(reason="The default test gets NaN losses with the test-generated inputs")
|
||||||
|
def test_dataset_conversion(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def prepare_tapas_single_inputs_for_inference():
|
def prepare_tapas_single_inputs_for_inference():
|
||||||
# Here we prepare a single table-question pair to test TAPAS inference on:
|
# Here we prepare a single table-question pair to test TAPAS inference on:
|
||||||
|
|||||||
@@ -216,6 +216,10 @@ class TFTransfoXLModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
model = TFTransfoXLModel.from_pretrained(model_name)
|
model = TFTransfoXLModel.from_pretrained(model_name)
|
||||||
self.assertIsNotNone(model)
|
self.assertIsNotNone(model)
|
||||||
|
|
||||||
|
@unittest.skip(reason="This model doesn't play well with fit() due to not returning a single loss.")
|
||||||
|
def test_dataset_conversion(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
@require_tf
|
@require_tf
|
||||||
class TFTransfoXLModelLanguageGenerationTest(unittest.TestCase):
|
class TFTransfoXLModelLanguageGenerationTest(unittest.TestCase):
|
||||||
|
|||||||
@@ -25,6 +25,8 @@ import unittest.mock as mock
|
|||||||
from importlib import import_module
|
from importlib import import_module
|
||||||
from typing import List, Tuple
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
from datasets import Dataset
|
||||||
|
|
||||||
from huggingface_hub import delete_repo, login
|
from huggingface_hub import delete_repo, login
|
||||||
from requests.exceptions import HTTPError
|
from requests.exceptions import HTTPError
|
||||||
from transformers import is_tf_available, is_torch_available
|
from transformers import is_tf_available, is_torch_available
|
||||||
@@ -1509,6 +1511,56 @@ class TFModelTesterMixin:
|
|||||||
observed_main_input_name = list(model_signature.parameters.keys())[1]
|
observed_main_input_name = list(model_signature.parameters.keys())[1]
|
||||||
self.assertEqual(model_class.main_input_name, observed_main_input_name)
|
self.assertEqual(model_class.main_input_name, observed_main_input_name)
|
||||||
|
|
||||||
|
def test_dataset_conversion(self):
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model = model_class(config)
|
||||||
|
tf_inputs_dict = self._prepare_for_class(inputs_dict, model_class, return_labels=False)
|
||||||
|
tf_inputs_dict = {
|
||||||
|
key: val
|
||||||
|
for key, val in tf_inputs_dict.items()
|
||||||
|
if "head_mask" not in key and isinstance(val, tf.Tensor)
|
||||||
|
}
|
||||||
|
tf_inputs_dict["extra_unwanted_column"] = list(tf_inputs_dict.values())[0] # Use a random other tensor
|
||||||
|
input_dataset = Dataset.from_dict(tf_inputs_dict)
|
||||||
|
tf_dataset = model.prepare_tf_dataset(
|
||||||
|
input_dataset, batch_size=len(input_dataset), drop_remainder=False, shuffle=False
|
||||||
|
)
|
||||||
|
test_batch = next(iter(tf_dataset))
|
||||||
|
if isinstance(test_batch, tf.Tensor):
|
||||||
|
self.assertEqual(len(test_batch), len(input_dataset)) # Assert we didn't lose any data
|
||||||
|
else:
|
||||||
|
# Assert we discarded the unwanted extra column but kept everything else
|
||||||
|
self.assertEqual(len(test_batch), len(input_dataset.features) - 1)
|
||||||
|
self.assertNotIn("extra_unwanted_column", test_batch)
|
||||||
|
for tensor in test_batch.values():
|
||||||
|
self.assertTrue(isinstance(tensor, tf.Tensor))
|
||||||
|
self.assertEqual(len(tensor), len(input_dataset)) # Assert we didn't lose any data
|
||||||
|
model(test_batch, training=False)
|
||||||
|
|
||||||
|
if "labels" in inspect.signature(model_class.call).parameters.keys():
|
||||||
|
tf_inputs_dict = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||||
|
if "labels" not in tf_inputs_dict:
|
||||||
|
return # This model isn't giving us labels after all, don't try training with it
|
||||||
|
tf_inputs_dict = {key: val for key, val in tf_inputs_dict.items() if "head_mask" not in key}
|
||||||
|
tf_inputs_dict["extra_unwanted_column"] = list(tf_inputs_dict.values())[0] # Use a random other tensor
|
||||||
|
input_dataset = Dataset.from_dict(tf_inputs_dict)
|
||||||
|
tf_dataset = model.prepare_tf_dataset(
|
||||||
|
input_dataset, batch_size=len(input_dataset), drop_remainder=False, shuffle=False
|
||||||
|
)
|
||||||
|
test_batch, test_batch_labels = next(iter(tf_dataset))
|
||||||
|
self.assertGreater(len(test_batch_labels), 0) # Assert the labels are present
|
||||||
|
feature_columns = 1 if isinstance(test_batch, tf.Tensor) else len(test_batch)
|
||||||
|
label_columns = 1 if isinstance(test_batch_labels, tf.Tensor) else len(test_batch_labels)
|
||||||
|
# Assert we discarded the unwanted extra column but kept everything else
|
||||||
|
self.assertEqual(feature_columns + label_columns, len(input_dataset.features) - 1)
|
||||||
|
if isinstance(test_batch, dict):
|
||||||
|
self.assertNotIn("extra_unwanted_column", test_batch)
|
||||||
|
if isinstance(test_batch_labels, dict):
|
||||||
|
self.assertNotIn("extra_unwanted_column", test_batch_labels)
|
||||||
|
model.compile(optimizer="sgd", run_eagerly=True)
|
||||||
|
model.train_on_batch(test_batch, test_batch_labels)
|
||||||
|
|
||||||
def _generate_random_bad_tokens(self, num_bad_tokens, model):
|
def _generate_random_bad_tokens(self, num_bad_tokens, model):
|
||||||
# special tokens cannot be bad tokens
|
# special tokens cannot be bad tokens
|
||||||
special_tokens = []
|
special_tokens = []
|
||||||
|
|||||||
Reference in New Issue
Block a user