correct attention mask (#7373)

This commit is contained in:
Patrick von Platen
2020-09-24 23:22:04 +02:00
committed by GitHub
parent a8cbc4269c
commit 0804d077c6
3 changed files with 30 additions and 12 deletions

View File

@@ -115,11 +115,15 @@ def evaluate_batch_retrieval(args, rag_model, questions):
def evaluate_batch_e2e(args, rag_model, questions):
with torch.no_grad():
input_ids = rag_model.retriever.question_encoder_tokenizer.batch_encode_plus(
inputs_dict = rag_model.retriever.question_encoder_tokenizer.batch_encode_plus(
questions, return_tensors="pt", padding=True, truncation=True
)["input_ids"].to(args.device)
)
input_ids = inputs_dict.input_ids.to(args.device)
attention_mask = inputs_dict.attention_mask.to(args.device)
outputs = rag_model.generate( # rag_model overwrites generate
input_ids,
attention_mask=attention_mask,
num_beams=args.num_beams,
min_length=args.min_length,
max_length=args.max_length,