revert renaming of lm_labels to ltr_lm_labels

This commit is contained in:
Rémi Louf
2019-10-30 10:43:13 +01:00
parent 098a89f312
commit 9c1bdb5b61
3 changed files with 23 additions and 19 deletions

View File

@@ -283,14 +283,14 @@ def evaluate(args, model, tokenizer, prefix=""):
model.eval()
for batch in tqdm(eval_dataloader, desc="Evaluating"):
source, target, encoder_token_type_ids, encoder_mask, decoder_mask, ltr_lm_labels = batch
source, target, encoder_token_type_ids, encoder_mask, decoder_mask, lm_labels = batch
source = source.to(args.device)
target = target.to(args.device)
encoder_token_type_ids = encoder_token_type_ids.to(args.device)
encoder_mask = encoder_mask.to(args.device)
decoder_mask = decoder_mask.to(args.device)
ltr_lm_labels = ltr_lm_labels.to(args.device)
lm_labels = lm_labels.to(args.device)
with torch.no_grad():
outputs = model(
@@ -299,7 +299,7 @@ def evaluate(args, model, tokenizer, prefix=""):
encoder_token_type_ids=encoder_token_type_ids,
encoder_attention_mask=encoder_mask,
decoder_attention_mask=decoder_mask,
decoder_ltr_lm_labels=ltr_lm_labels,
decoder_lm_labels=lm_labels,
)
lm_loss = outputs[0]
eval_loss += lm_loss.mean().item()