From 84b9579da70d2195774f072644dc1c4a2f1e2344 Mon Sep 17 00:00:00 2001 From: Matt Date: Mon, 25 Oct 2021 15:04:36 +0100 Subject: [PATCH] Remove unneeded `to_tensor()` in TF inline example (#14140) --- docs/source/training.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/training.rst b/docs/source/training.rst index 03b4107a8d..86c05a9429 100644 --- a/docs/source/training.rst +++ b/docs/source/training.rst @@ -240,11 +240,11 @@ Then we convert everything in big tensors and use the :obj:`tf.data.Dataset.from .. code-block:: python - train_features = {x: tf_train_dataset[x].to_tensor() for x in tokenizer.model_input_names} + train_features = {x: tf_train_dataset[x] for x in tokenizer.model_input_names} train_tf_dataset = tf.data.Dataset.from_tensor_slices((train_features, tf_train_dataset["label"])) train_tf_dataset = train_tf_dataset.shuffle(len(tf_train_dataset)).batch(8) - eval_features = {x: tf_eval_dataset[x].to_tensor() for x in tokenizer.model_input_names} + eval_features = {x: tf_eval_dataset[x] for x in tokenizer.model_input_names} eval_tf_dataset = tf.data.Dataset.from_tensor_slices((eval_features, tf_eval_dataset["label"])) eval_tf_dataset = eval_tf_dataset.batch(8)