update docstrings; rename lm_labels to more explicit ltr_lm_labels

This commit is contained in:
Rémi Louf
2019-10-29 20:08:03 +01:00
parent dfce409691
commit 098a89f312
2 changed files with 32 additions and 27 deletions

View File

@@ -26,7 +26,7 @@ import numpy as np
from tqdm import tqdm, trange
import torch
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers import (
AutoTokenizer,
@@ -283,14 +283,14 @@ def evaluate(args, model, tokenizer, prefix=""):
model.eval()
for batch in tqdm(eval_dataloader, desc="Evaluating"):
source, target, encoder_token_type_ids, encoder_mask, decoder_mask, lm_labels = batch
source, target, encoder_token_type_ids, encoder_mask, decoder_mask, ltr_lm_labels = batch
source = source.to(args.device)
target = target.to(args.device)
encoder_token_type_ids = encoder_token_type_ids.to(args.device)
encoder_mask = encoder_mask.to(args.device)
decoder_mask = decoder_mask.to(args.device)
lm_labels = lm_labels.to(args.device)
ltr_lm_labels = ltr_lm_labels.to(args.device)
with torch.no_grad():
outputs = model(
@@ -299,7 +299,7 @@ def evaluate(args, model, tokenizer, prefix=""):
encoder_token_type_ids=encoder_token_type_ids,
encoder_attention_mask=encoder_mask,
decoder_attention_mask=decoder_mask,
decoder_lm_labels=lm_labels,
decoder_ltr_lm_labels=ltr_lm_labels,
)
lm_loss = outputs[0]
eval_loss += lm_loss.mean().item()