CLM for BERT, beginning of CLM for RoBERTa; still needs a better masking token mechanism.
This commit is contained in:
@@ -13,7 +13,11 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" Finetuning the library models for language modeling on WikiText-2 (GPT, GPT-2, XLM)."""
|
"""
|
||||||
|
Fine-tuning the library models for language modeling on WikiText-2 (GPT, GPT-2, BERT, RoBERTa).
|
||||||
|
GPT and GPT-2 are fine-tuned using a causal language modeling (CLM) loss while BERT and RoBERTa are fine-tuned
|
||||||
|
using a masked language modeling (MLM) loss.
|
||||||
|
"""
|
||||||
|
|
||||||
from __future__ import absolute_import, division, print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
@@ -30,8 +34,10 @@ from torch.utils.data.distributed import DistributedSampler
|
|||||||
from tensorboardX import SummaryWriter
|
from tensorboardX import SummaryWriter
|
||||||
from tqdm import tqdm, trange
|
from tqdm import tqdm, trange
|
||||||
|
|
||||||
from pytorch_transformers import (WEIGHTS_NAME, GPT2Config, GPT2LMHeadModel, GPT2Tokenizer,
|
from pytorch_transformers import (WEIGHTS_NAME, GPT2Config, GPT2LMHeadModel, GPT2Tokenizer, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer)
|
OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
|
BertConfig, BertForMaskedLM, BertTokenizer, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
|
RobertaConfig, RobertaForMaskedLM, RobertaTokenizer, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||||
from pytorch_transformers import AdamW, WarmupLinearSchedule
|
from pytorch_transformers import AdamW, WarmupLinearSchedule
|
||||||
|
|
||||||
from utils_lm import WikiTextDataset
|
from utils_lm import WikiTextDataset
|
||||||
@@ -42,7 +48,9 @@ ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (
|
|||||||
|
|
||||||
MODEL_CLASSES = {
|
MODEL_CLASSES = {
|
||||||
'gpt2': (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
|
'gpt2': (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
|
||||||
'openai-gpt': (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer)
|
'openai-gpt': (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
|
||||||
|
"bert": (BertConfig, BertForMaskedLM, BertTokenizer),
|
||||||
|
"roberta": (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -53,6 +61,18 @@ def set_seed(args):
|
|||||||
if args.n_gpu > 0:
|
if args.n_gpu > 0:
|
||||||
torch.cuda.manual_seed_all(args.seed)
|
torch.cuda.manual_seed_all(args.seed)
|
||||||
|
|
||||||
|
# Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original
|
||||||
|
def mask_tokens(inputs, tokenizer, args):
    """Prepare masked inputs and labels for masked language modeling.

    A fraction ``args.mlm_probability`` of the positions is selected. Of the
    selected positions, 80% are replaced by the tokenizer's mask token, 10% by
    a random token id drawn from ``[0, args.num_embeddings)``, and 10% are left
    unchanged.

    Args:
        inputs: ``torch.LongTensor`` of token ids. It is modified IN PLACE
            (this matches the original behaviour, callers receive the same
            tensor object back).
        tokenizer: Tokenizer used to resolve the mask token id. If it exposes
            ``mask_token`` and ``convert_tokens_to_ids`` those are used, so
            tokenizers whose mask token is not literally ``"[MASK]"`` work;
            otherwise the lookup falls back to ``tokenizer.vocab["[MASK]"]``.
        args: Namespace providing ``mlm_probability`` (float) and
            ``num_embeddings`` (int, exclusive upper bound for random ids).

    Returns:
        A tuple ``(inputs, labels)``. ``labels`` is a copy of the original
        ``inputs`` in which every non-selected position is set to ``-1`` so
        that the loss is only computed on selected tokens.
        NOTE(review): assumes the model's loss ignores label ``-1`` -- confirm
        against the installed library version.

    Raises:
        KeyError: if the fallback lookup is used and the vocabulary has no
            ``"[MASK]"`` entry.
    """
    # Resolve the mask token id in a tokenizer-agnostic way when possible.
    mask_token = getattr(tokenizer, "mask_token", None)
    if mask_token is not None and hasattr(tokenizer, "convert_tokens_to_ids"):
        mask_token_id = tokenizer.convert_tokens_to_ids(mask_token)
    else:
        mask_token_id = tokenizer.vocab["[MASK]"]

    labels = inputs.clone()
    # Build every helper tensor on the device of `inputs` so that indexing and
    # assignment never mix devices, whatever `args.device` says.
    device = inputs.device

    def _sample_mask(probability):
        # Boolean (not uint8) mask: on uint8 tensors `~` is a bitwise NOT and
        # yields 254/255, i.e. "all selected", which silently disables the loss.
        probabilities = torch.full(labels.shape, float(probability), dtype=torch.float, device=device)
        return torch.bernoulli(probabilities).bool()

    masked_indices = _sample_mask(args.mlm_probability)
    labels[~masked_indices] = -1  # We only compute loss on masked tokens

    # 80% of the time, replace masked input tokens with the mask token
    indices_replaced = _sample_mask(0.8) & masked_indices
    inputs[indices_replaced] = mask_token_id

    # 10% of the time (half of the remaining 20%), replace with a random word
    indices_random = _sample_mask(0.5) & masked_indices & ~indices_replaced
    random_words = torch.randint(args.num_embeddings, labels.shape, dtype=torch.long, device=device)
    inputs[indices_random] = random_words[indices_random]

    # The remaining 10% of the masked tokens keep their original id
    return inputs, labels
|
||||||
|
|
||||||
def train(args, train_dataset, model, tokenizer):
|
def train(args, train_dataset, model, tokenizer):
|
||||||
""" Train the model """
|
""" Train the model """
|
||||||
@@ -108,13 +128,14 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
tr_loss, logging_loss = 0.0, 0.0
|
tr_loss, logging_loss = 0.0, 0.0
|
||||||
model.zero_grad()
|
model.zero_grad()
|
||||||
train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
|
train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
|
||||||
set_seed(args) # Added here for reproductibility (even between python 2 and 3)
|
set_seed(args) # Added here for reproducibility (even between python 2 and 3)
|
||||||
for _ in train_iterator:
|
for _ in train_iterator:
|
||||||
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
|
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
|
||||||
for step, batch in enumerate(epoch_iterator):
|
for step, batch in enumerate(epoch_iterator):
|
||||||
batch.to(args.device)
|
batch.to(args.device)
|
||||||
model.train()
|
model.train()
|
||||||
outputs = model(batch, labels=batch)
|
inputs, labels = mask_tokens(batch, tokenizer, args) if args.mlm else (batch, batch)
|
||||||
|
outputs = model(inputs, masked_lm_labels=labels) if args.mlm else model(inputs, labels=labels)
|
||||||
loss = outputs[0] # model outputs are always tuple in pytorch-transformers (see doc)
|
loss = outputs[0] # model outputs are always tuple in pytorch-transformers (see doc)
|
||||||
|
|
||||||
if args.n_gpu > 1:
|
if args.n_gpu > 1:
|
||||||
@@ -132,8 +153,8 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
|
|
||||||
tr_loss += loss.item()
|
tr_loss += loss.item()
|
||||||
if (step + 1) % args.gradient_accumulation_steps == 0:
|
if (step + 1) % args.gradient_accumulation_steps == 0:
|
||||||
scheduler.step() # Update learning rate schedule
|
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
scheduler.step() # Update learning rate schedule
|
||||||
model.zero_grad()
|
model.zero_grad()
|
||||||
global_step += 1
|
global_step += 1
|
||||||
|
|
||||||
@@ -196,7 +217,7 @@ def evaluate(args, model, tokenizer, prefix=""):
|
|||||||
batch.to(args.device)
|
batch.to(args.device)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
outputs = model(batch, labels=batch)
|
outputs = model(batch)
|
||||||
lm_loss = outputs[0]
|
lm_loss = outputs[0]
|
||||||
eval_loss += lm_loss.mean().item()
|
eval_loss += lm_loss.mean().item()
|
||||||
nb_eval_steps += 1
|
nb_eval_steps += 1
|
||||||
@@ -236,8 +257,16 @@ def main():
|
|||||||
help="The output directory where the model predictions and checkpoints will be written.")
|
help="The output directory where the model predictions and checkpoints will be written.")
|
||||||
|
|
||||||
## Other parameters
|
## Other parameters
|
||||||
parser.add_argument("--model_name_or_path", default="gpt2", type=str,
|
parser.add_argument("--model_name", default="bert", type=str,
|
||||||
help="The model to be fine-tuned.")
|
help="The model architecture to be fine-tuned.")
|
||||||
|
parser.add_argument("--model_checkpoint", default="bert-base-cased", type=str,
|
||||||
|
help="The model checkpoint for weights initialization.")
|
||||||
|
|
||||||
|
parser.add_argument("--mlm", action='store_true',
|
||||||
|
help="Train with masked-language modeling loss instead of language modeling.")
|
||||||
|
parser.add_argument("--mlm_probability", type=float, default=0.15,
|
||||||
|
help="Ratio of tokens to mask for masked language modeling loss")
|
||||||
|
|
||||||
parser.add_argument("--config_name", default="", type=str,
|
parser.add_argument("--config_name", default="", type=str,
|
||||||
help="Pretrained config name or path if not the same as model_name")
|
help="Pretrained config name or path if not the same as model_name")
|
||||||
parser.add_argument("--tokenizer_name", default="", type=str,
|
parser.add_argument("--tokenizer_name", default="", type=str,
|
||||||
@@ -303,6 +332,10 @@ def main():
|
|||||||
parser.add_argument('--server_port', type=str, default='', help="For distant debugging.")
|
parser.add_argument('--server_port', type=str, default='', help="For distant debugging.")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.model_name in ["bert", "roberta"] and not args.mlm:
|
||||||
|
raise ValueError("BERT and RoBERTa do not have LM heads but masked LM heads. They must be run using the --mlm "
|
||||||
|
"flag (masked language modeling).")
|
||||||
|
|
||||||
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
|
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
|
||||||
raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))
|
raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))
|
||||||
|
|
||||||
@@ -339,10 +372,11 @@ def main():
|
|||||||
if args.local_rank not in [-1, 0]:
|
if args.local_rank not in [-1, 0]:
|
||||||
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
||||||
|
|
||||||
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_name_or_path]
|
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_name]
|
||||||
config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path)
|
config = config_class.from_pretrained(args.config_name if args.config_name else args.model_checkpoint)
|
||||||
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, do_lower_case=args.do_lower_case)
|
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_checkpoint, do_lower_case=args.do_lower_case)
|
||||||
model = model_class.from_pretrained(args.model_name_or_path, from_tf=bool('.ckpt' in args.model_name_or_path), config=config)
|
model = model_class.from_pretrained(args.model_checkpoint, from_tf=bool('.ckpt' in args.model_checkpoint), config=config)
|
||||||
|
args.num_embeddings = config.vocab_size # We need this to create the model at next line (number of embeddings to use)
|
||||||
|
|
||||||
if args.local_rank == 0:
|
if args.local_rank == 0:
|
||||||
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
||||||
|
|||||||
Reference in New Issue
Block a user