update docstrings; rename lm_labels to more explicit ltr_lm_labels
This commit is contained in:
@@ -26,7 +26,7 @@ import numpy as np
|
||||
from tqdm import tqdm, trange
|
||||
import torch
|
||||
from torch.optim import Adam
|
||||
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
|
||||
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
|
||||
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
@@ -283,14 +283,14 @@ def evaluate(args, model, tokenizer, prefix=""):
|
||||
model.eval()
|
||||
|
||||
for batch in tqdm(eval_dataloader, desc="Evaluating"):
|
||||
source, target, encoder_token_type_ids, encoder_mask, decoder_mask, lm_labels = batch
|
||||
source, target, encoder_token_type_ids, encoder_mask, decoder_mask, ltr_lm_labels = batch
|
||||
|
||||
source = source.to(args.device)
|
||||
target = target.to(args.device)
|
||||
encoder_token_type_ids = encoder_token_type_ids.to(args.device)
|
||||
encoder_mask = encoder_mask.to(args.device)
|
||||
decoder_mask = decoder_mask.to(args.device)
|
||||
lm_labels = lm_labels.to(args.device)
|
||||
ltr_lm_labels = ltr_lm_labels.to(args.device)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(
|
||||
@@ -299,7 +299,7 @@ def evaluate(args, model, tokenizer, prefix=""):
|
||||
encoder_token_type_ids=encoder_token_type_ids,
|
||||
encoder_attention_mask=encoder_mask,
|
||||
decoder_attention_mask=decoder_mask,
|
||||
decoder_lm_labels=lm_labels,
|
||||
decoder_ltr_lm_labels=ltr_lm_labels,
|
||||
)
|
||||
lm_loss = outputs[0]
|
||||
eval_loss += lm_loss.mean().item()
|
||||
|
||||
Reference in New Issue
Block a user