diff --git a/examples/tensorflow/text-classification/run_text_classification.py b/examples/tensorflow/text-classification/run_text_classification.py index ab4f005ee3..32e020d7bf 100644 --- a/examples/tensorflow/text-classification/run_text_classification.py +++ b/examples/tensorflow/text-classification/run_text_classification.py @@ -522,7 +522,7 @@ def main(): # region Prediction losses # This section is outside the scope() because it's very quick to compute, but behaves badly inside it - if "label" in datasets["test"].features: + if "test" in datasets and "label" in datasets["test"].features: print("Computing prediction loss on test labels...") labels = datasets["test"]["label"] loss = float(loss_fn(labels, predictions).numpy())