Fix BERT example code for NSP and Multiple Choice (#3953)
Change the example code to use encode_plus since the token_type_id wasn't being correctly set.
This commit is contained in:
@@ -1036,11 +1036,12 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
|
|||||||
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
||||||
model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased')
|
model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased')
|
||||||
|
|
||||||
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
|
prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
|
||||||
outputs = model(input_ids)
|
next_sentence = "The sky is blue due to the shorter wavelength of blue light."
|
||||||
|
encoding = tokenizer.encode_plus(prompt, next_sentence, return_tensors='pt')
|
||||||
seq_relationship_scores = outputs[0]
|
|
||||||
|
|
||||||
|
loss, logits = model(**encoding, next_sentence_label=torch.LongTensor([1]))
|
||||||
|
assert logits[0, 0] < logits[0, 1] # next sentence was random
|
||||||
"""
|
"""
|
||||||
|
|
||||||
outputs = self.bert(
|
outputs = self.bert(
|
||||||
@@ -1191,7 +1192,7 @@ class BertForMultipleChoice(BertPreTrainedModel):
|
|||||||
r"""
|
r"""
|
||||||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
|
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
|
||||||
Labels for computing the multiple choice classification loss.
|
Labels for computing the multiple choice classification loss.
|
||||||
Indices should be in ``[0, ..., num_choices]`` where `num_choices` is the size of the second dimension
|
Indices should be in ``[0, ..., num_choices-1]`` where `num_choices` is the size of the second dimension
|
||||||
of the input tensors. (see `input_ids` above)
|
of the input tensors. (see `input_ids` above)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -1221,14 +1222,17 @@ class BertForMultipleChoice(BertPreTrainedModel):
|
|||||||
|
|
||||||
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
||||||
model = BertForMultipleChoice.from_pretrained('bert-base-uncased')
|
model = BertForMultipleChoice.from_pretrained('bert-base-uncased')
|
||||||
choices = ["Hello, my dog is cute", "Hello, my cat is amazing"]
|
|
||||||
|
|
||||||
input_ids = torch.tensor([tokenizer.encode(s, add_special_tokens=True) for s in choices]).unsqueeze(0) # Batch size 1, 2 choices
|
prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
|
||||||
labels = torch.tensor(1).unsqueeze(0) # Batch size 1
|
choice0 = "It is eaten with a fork and a knife."
|
||||||
outputs = model(input_ids, labels=labels)
|
choice1 = "It is eaten while held in the hand."
|
||||||
|
labels = torch.tensor(0) # choice0 is correct (according to Wikipedia ;))
|
||||||
|
|
||||||
loss, classification_scores = outputs[:2]
|
encoding = tokenizer.batch_encode_plus([[prompt, choice0], [prompt, choice1]], return_tensors='pt', pad_to_max_length=True)
|
||||||
|
outputs = model(**{k: v.unsqueeze(0) for k,v in encoding.items()}, labels=labels) # batch size is 1
|
||||||
|
|
||||||
|
# the linear classifier still needs to be trained
|
||||||
|
loss, logits = outputs[:2]
|
||||||
"""
|
"""
|
||||||
num_choices = input_ids.shape[1]
|
num_choices = input_ids.shape[1]
|
||||||
|
|
||||||
|
|||||||
@@ -857,10 +857,13 @@ class TFBertForNextSentencePrediction(TFBertPreTrainedModel):
|
|||||||
|
|
||||||
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
||||||
model = TFBertForNextSentencePrediction.from_pretrained('bert-base-uncased')
|
model = TFBertForNextSentencePrediction.from_pretrained('bert-base-uncased')
|
||||||
input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :] # Batch size 1
|
|
||||||
outputs = model(input_ids)
|
|
||||||
seq_relationship_scores = outputs[0]
|
|
||||||
|
|
||||||
|
prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
|
||||||
|
next_sentence = "The sky is blue due to the shorter wavelength of blue light."
|
||||||
|
encoding = tokenizer.encode_plus(prompt, next_sentence, return_tensors='tf')
|
||||||
|
|
||||||
|
logits = model(encoding['input_ids'], token_type_ids=encoding['token_type_ids'])[0]
|
||||||
|
assert logits[0][0] < logits[0][1] # the next sentence was random
|
||||||
"""
|
"""
|
||||||
outputs = self.bert(inputs, **kwargs)
|
outputs = self.bert(inputs, **kwargs)
|
||||||
|
|
||||||
@@ -990,11 +993,15 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel):
|
|||||||
|
|
||||||
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
||||||
model = TFBertForMultipleChoice.from_pretrained('bert-base-uncased')
|
model = TFBertForMultipleChoice.from_pretrained('bert-base-uncased')
|
||||||
choices = ["Hello, my dog is cute", "Hello, my cat is amazing"]
|
|
||||||
input_ids = tf.constant([tokenizer.encode(s) for s in choices])[None, :] # Batch size 1, 2 choices
|
|
||||||
outputs = model(input_ids)
|
|
||||||
classification_scores = outputs[0]
|
|
||||||
|
|
||||||
|
prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
|
||||||
|
choice0 = "It is eaten with a fork and a knife."
|
||||||
|
choice1 = "It is eaten while held in the hand."
|
||||||
|
encoding = tokenizer.batch_encode_plus([[prompt, choice0], [prompt, choice1]], return_tensors='tf', pad_to_max_length=True)
|
||||||
|
|
||||||
|
# linear classifier on the output is not yet trained
|
||||||
|
outputs = model(encoding['input_ids'][None, :])
|
||||||
|
logits = outputs[0]
|
||||||
"""
|
"""
|
||||||
if isinstance(inputs, (tuple, list)):
|
if isinstance(inputs, (tuple, list)):
|
||||||
input_ids = inputs[0]
|
input_ids = inputs[0]
|
||||||
|
|||||||
Reference in New Issue
Block a user