From 438f2730a03e19bc21f2823c659ceaed0dfe8ef7 Mon Sep 17 00:00:00 2001 From: altsoph Date: Fri, 25 Oct 2019 13:22:58 +0300 Subject: [PATCH] Evaluation code fixed. --- examples/run_lm_finetuning.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/examples/run_lm_finetuning.py b/examples/run_lm_finetuning.py index 571bcb4391..4d32385e40 100644 --- a/examples/run_lm_finetuning.py +++ b/examples/run_lm_finetuning.py @@ -86,6 +86,7 @@ class TextDataset(Dataset): # Note that we are loosing the last truncated example here for the sake of simplicity (no padding) # If your dataset is small, first you should loook for a bigger one :-) and second you # can change this behavior by adding (model specific) padding. + self.examples.append(tokenizer.build_inputs_with_special_tokens(tokenized_text[-block_size:])) # DIRTY! logger.info("Saving features into cached file %s", cached_features_file) with open(cached_features_file, 'wb') as handle: @@ -309,10 +310,12 @@ def evaluate(args, model, tokenizer, prefix=""): model.eval() for batch in tqdm(eval_dataloader, desc="Evaluating"): - batch = batch.to(args.device) + inputs, labels = mask_tokens(batch, tokenizer, args) if args.mlm else (batch, batch) + inputs = inputs.to(args.device) + labels = labels.to(args.device) with torch.no_grad(): - outputs = model(batch, masked_lm_labels=batch) if args.mlm else model(batch, labels=batch) + outputs = model(inputs, masked_lm_labels=labels) if args.mlm else model(inputs, labels=labels) lm_loss = outputs[0] eval_loss += lm_loss.mean().item() nb_eval_steps += 1 @@ -540,4 +543,4 @@ def main(): if __name__ == "__main__": - main() + main() \ No newline at end of file