fix a bug in eval_batch_retrieval (#9089)
This commit is contained in:
committed by
GitHub
parent
c19d04623e
commit
44c340f45f
@@ -96,7 +96,7 @@ def evaluate_batch_retrieval(args, rag_model, questions):
|
|||||||
)["input_ids"].to(args.device)
|
)["input_ids"].to(args.device)
|
||||||
|
|
||||||
question_enc_outputs = rag_model.rag.question_encoder(retriever_input_ids)
|
question_enc_outputs = rag_model.rag.question_encoder(retriever_input_ids)
|
||||||
question_enc_pool_output = question_enc_outputs.pooler_output
|
question_enc_pool_output = question_enc_outputs[0]
|
||||||
|
|
||||||
result = rag_model.retriever(
|
result = rag_model.retriever(
|
||||||
retriever_input_ids,
|
retriever_input_ids,
|
||||||
|
|||||||
Reference in New Issue
Block a user