Pass the tokenizer's attention mask to generate() in evaluate_batch_e2e (#7373)
This commit is contained in:
committed by
GitHub
parent
a8cbc4269c
commit
0804d077c6
@@ -115,11 +115,15 @@ def evaluate_batch_retrieval(args, rag_model, questions):
|
||||
|
||||
def evaluate_batch_e2e(args, rag_model, questions):
|
||||
with torch.no_grad():
|
||||
input_ids = rag_model.retriever.question_encoder_tokenizer.batch_encode_plus(
|
||||
inputs_dict = rag_model.retriever.question_encoder_tokenizer.batch_encode_plus(
|
||||
questions, return_tensors="pt", padding=True, truncation=True
|
||||
)["input_ids"].to(args.device)
|
||||
)
|
||||
|
||||
input_ids = inputs_dict.input_ids.to(args.device)
|
||||
attention_mask = inputs_dict.attention_mask.to(args.device)
|
||||
outputs = rag_model.generate( # rag_model overwrites generate
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
num_beams=args.num_beams,
|
||||
min_length=args.min_length,
|
||||
max_length=args.max_length,
|
||||
|
||||
Reference in New Issue
Block a user