diff --git a/docs/source/training.rst b/docs/source/training.rst index 03b4107a8d..86c05a9429 100644 --- a/docs/source/training.rst +++ b/docs/source/training.rst @@ -240,11 +240,11 @@ Then we convert everything in big tensors and use the :obj:`tf.data.Dataset.from .. code-block:: python - train_features = {x: tf_train_dataset[x].to_tensor() for x in tokenizer.model_input_names} + train_features = {x: tf_train_dataset[x] for x in tokenizer.model_input_names} train_tf_dataset = tf.data.Dataset.from_tensor_slices((train_features, tf_train_dataset["label"])) train_tf_dataset = train_tf_dataset.shuffle(len(tf_train_dataset)).batch(8) - eval_features = {x: tf_eval_dataset[x].to_tensor() for x in tokenizer.model_input_names} + eval_features = {x: tf_eval_dataset[x] for x in tokenizer.model_input_names} eval_tf_dataset = tf.data.Dataset.from_tensor_slices((eval_features, tf_eval_dataset["label"])) eval_tf_dataset = eval_tf_dataset.batch(8)