rename Bert2Bert -> Bert2Rnd
This commit is contained in:
@@ -64,7 +64,7 @@ if is_torch_available():
|
|||||||
BertForMaskedLM, BertForNextSentencePrediction,
|
BertForMaskedLM, BertForNextSentencePrediction,
|
||||||
BertForSequenceClassification, BertForMultipleChoice,
|
BertForSequenceClassification, BertForMultipleChoice,
|
||||||
BertForTokenClassification, BertForQuestionAnswering,
|
BertForTokenClassification, BertForQuestionAnswering,
|
||||||
load_tf_weights_in_bert, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, Bert2Bert)
|
load_tf_weights_in_bert, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, Bert2Rnd)
|
||||||
from .modeling_openai import (OpenAIGPTPreTrainedModel, OpenAIGPTModel,
|
from .modeling_openai import (OpenAIGPTPreTrainedModel, OpenAIGPTModel,
|
||||||
OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel,
|
OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel,
|
||||||
load_tf_weights_in_openai_gpt, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP)
|
load_tf_weights_in_openai_gpt, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||||
|
|||||||
@@ -1419,7 +1419,7 @@ class BertForQuestionAnswering(BertPreTrainedModel):
|
|||||||
@add_start_docstrings("Bert encoder-decoder model for sequence generation.",
|
@add_start_docstrings("Bert encoder-decoder model for sequence generation.",
|
||||||
BERT_START_DOCSTRING,
|
BERT_START_DOCSTRING,
|
||||||
BERT_INPUTS_DOCSTRING)
|
BERT_INPUTS_DOCSTRING)
|
||||||
class Bert2Bert(BertPreTrainedModel):
|
class Bert2Rnd(BertPreTrainedModel):
|
||||||
r"""
|
r"""
|
||||||
|
|
||||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||||
@@ -1434,7 +1434,8 @@ class Bert2Bert(BertPreTrainedModel):
|
|||||||
Examples::
|
Examples::
|
||||||
|
|
||||||
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
||||||
model = Bert2Bert.from_pretrained('bert-base-uncased')
|
model = Bert2Rnd.from_pretrained('bert-base-uncased')
|
||||||
|
# fine-tuning magic happens here
|
||||||
input = tokenizer.encode("Hello, how are you?")
|
input = tokenizer.encode("Hello, how are you?")
|
||||||
outputs = model(input)
|
outputs = model(input)
|
||||||
output_text = tokenize.decode(outputs[0])
|
output_text = tokenize.decode(outputs[0])
|
||||||
@@ -1468,4 +1469,4 @@ class Bert2Bert(BertPreTrainedModel):
|
|||||||
token_type_ids=token_type_ids,
|
token_type_ids=token_type_ids,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
head_mask=head_mask)
|
head_mask=head_mask)
|
||||||
return decoder_outputs
|
return decoder_outputs[0]
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ if is_torch_available():
|
|||||||
from transformers import (BertConfig, BertModel, BertForMaskedLM,
|
from transformers import (BertConfig, BertModel, BertForMaskedLM,
|
||||||
BertForNextSentencePrediction, BertForPreTraining,
|
BertForNextSentencePrediction, BertForPreTraining,
|
||||||
BertForQuestionAnswering, BertForSequenceClassification,
|
BertForQuestionAnswering, BertForSequenceClassification,
|
||||||
BertForTokenClassification, BertForMultipleChoice, Bert2Bert)
|
BertForTokenClassification, BertForMultipleChoice, Bert2Rnd)
|
||||||
from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||||
else:
|
else:
|
||||||
pytestmark = pytest.mark.skip("Require Torch")
|
pytestmark = pytest.mark.skip("Require Torch")
|
||||||
@@ -257,7 +257,7 @@ class BertModelTest(CommonTestCases.CommonModelTester):
|
|||||||
|
|
||||||
def create_and_check_bert2bert(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
def create_and_check_bert2bert(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||||
config.num_choices = self.num_choices
|
config.num_choices = self.num_choices
|
||||||
model = Bert2Bert(config=config)
|
model = Bert2Rnd(config=config)
|
||||||
model.eval()
|
model.eval()
|
||||||
bert2bert_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
bert2bert_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||||
bert2bert_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
bert2bert_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||||
|
|||||||
Reference in New Issue
Block a user