diff --git a/examples/rag/eval_rag.py b/examples/rag/eval_rag.py index baa956ecab..73913c1acd 100644 --- a/examples/rag/eval_rag.py +++ b/examples/rag/eval_rag.py @@ -72,7 +72,7 @@ def get_precision_at_k(args, preds_path, gold_data_path): em = total = 0 for hypo, reference in zip(hypos, references): hypo_provenance = set(hypo.split("\t")[:k]) - ref_provenance = set(reference.split("\t")[1 : (k + 1)]) + ref_provenance = set(reference.split("\t")) total += 1 em += len(hypo_provenance & ref_provenance) / k