[run_lm_finetuning] mask_tokens: document types
This commit is contained in:
@@ -28,6 +28,7 @@ import pickle
|
|||||||
import random
|
import random
|
||||||
import re
|
import re
|
||||||
import shutil
|
import shutil
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -53,6 +54,7 @@ from transformers import (
|
|||||||
OpenAIGPTConfig,
|
OpenAIGPTConfig,
|
||||||
OpenAIGPTLMHeadModel,
|
OpenAIGPTLMHeadModel,
|
||||||
OpenAIGPTTokenizer,
|
OpenAIGPTTokenizer,
|
||||||
|
PreTrainedTokenizer,
|
||||||
RobertaConfig,
|
RobertaConfig,
|
||||||
RobertaForMaskedLM,
|
RobertaForMaskedLM,
|
||||||
RobertaTokenizer,
|
RobertaTokenizer,
|
||||||
@@ -164,7 +166,7 @@ def _rotate_checkpoints(args, checkpoint_prefix, use_mtime=False):
|
|||||||
shutil.rmtree(checkpoint)
|
shutil.rmtree(checkpoint)
|
||||||
|
|
||||||
|
|
||||||
def mask_tokens(inputs, tokenizer, args):
|
def mask_tokens(inputs: torch.Tensor, tokenizer: PreTrainedTokenizer, args) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
""" Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. """
|
""" Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. """
|
||||||
labels = inputs.clone()
|
labels = inputs.clone()
|
||||||
# We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
|
# We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
|
||||||
|
|||||||
Reference in New Issue
Block a user