forgot bertForPreTraining
This commit is contained in:
@@ -158,6 +158,19 @@ def bertForPreTraining(*args, **kwargs):
|
|||||||
This module comprises the BERT model followed by the two pre-training heads
|
This module comprises the BERT model followed by the two pre-training heads
|
||||||
- the masked language modeling head, and
|
- the masked language modeling head, and
|
||||||
- the next sentence classification head.
|
- the next sentence classification head.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
# Load the tokenizer
|
||||||
|
>>> tokenizer = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'bertTokenizer', 'bert-base-cased', do_basic_tokenize=False)
|
||||||
|
# Prepare tokenized input
|
||||||
|
>>> text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
|
||||||
|
>>> tokenized_text = tokenizer.tokenize(text)
|
||||||
|
>>> segments_ids = [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]
|
||||||
|
>>> tokens_tensor = torch.tensor([indexed_tokens])
|
||||||
|
>>> segments_tensors = torch.tensor([segments_ids])
|
||||||
|
# Load bertForPreTraining
|
||||||
|
>>> model = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'bertForPreTraining', 'bert-base-cased')
|
||||||
|
>>> masked_lm_logits_scores, seq_relationship_logits = model(tokens_tensor, segments_tensors)
|
||||||
"""
|
"""
|
||||||
model = BertForPreTraining.from_pretrained(*args, **kwargs)
|
model = BertForPreTraining.from_pretrained(*args, **kwargs)
|
||||||
return model
|
return model
|
||||||
|
|||||||
Reference in New Issue
Block a user