diff --git a/examples/run_lm_finetuning.py b/examples/run_lm_finetuning.py index 95cbb7c163..c2a0f8c08e 100644 --- a/examples/run_lm_finetuning.py +++ b/examples/run_lm_finetuning.py @@ -28,6 +28,7 @@ import pickle import random import re import shutil +from typing import Tuple import numpy as np import torch @@ -53,6 +54,7 @@ from transformers import ( OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer, + PreTrainedTokenizer, RobertaConfig, RobertaForMaskedLM, RobertaTokenizer, @@ -164,7 +166,7 @@ def _rotate_checkpoints(args, checkpoint_prefix, use_mtime=False): shutil.rmtree(checkpoint) -def mask_tokens(inputs, tokenizer, args): +def mask_tokens(inputs: torch.Tensor, tokenizer: PreTrainedTokenizer, args) -> Tuple[torch.Tensor, torch.Tensor]: """ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. """ labels = inputs.clone() # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)