From 8abfee9ec327aea0005a7ad367639217ca7dd215 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Tue, 8 Oct 2019 16:07:25 +0200 Subject: [PATCH] rename Bert2Bert -> Bert2Rnd --- transformers/__init__.py | 2 +- transformers/modeling_bert.py | 7 ++++--- transformers/tests/modeling_bert_test.py | 4 ++-- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/transformers/__init__.py b/transformers/__init__.py index bf302992b2..006ba9ed16 100644 --- a/transformers/__init__.py +++ b/transformers/__init__.py @@ -64,7 +64,7 @@ if is_torch_available(): BertForMaskedLM, BertForNextSentencePrediction, BertForSequenceClassification, BertForMultipleChoice, BertForTokenClassification, BertForQuestionAnswering, - load_tf_weights_in_bert, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, Bert2Bert) + load_tf_weights_in_bert, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, Bert2Rnd) from .modeling_openai import (OpenAIGPTPreTrainedModel, OpenAIGPTModel, OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel, load_tf_weights_in_openai_gpt, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP) diff --git a/transformers/modeling_bert.py b/transformers/modeling_bert.py index 9ce32d808e..258e4c3430 100644 --- a/transformers/modeling_bert.py +++ b/transformers/modeling_bert.py @@ -1419,7 +1419,7 @@ class BertForQuestionAnswering(BertPreTrainedModel): @add_start_docstrings("Bert encoder-decoder model for sequence generation.", BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING) -class Bert2Bert(BertPreTrainedModel): +class Bert2Rnd(BertPreTrainedModel): r""" Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: @@ -1434,7 +1434,8 @@ class Bert2Bert(BertPreTrainedModel): Examples:: tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') - model = Bert2Bert.from_pretrained('bert-base-uncased') + model = Bert2Rnd.from_pretrained('bert-base-uncased') + # fine-tuning magic happens here input = tokenizer.encode("Hello, how are you?") outputs = model(input) output_text = tokenize.decode(outputs[0]) @@ -1468,4 +1469,4 @@ class Bert2Bert(BertPreTrainedModel): token_type_ids=token_type_ids, position_ids=position_ids, head_mask=head_mask) - return decoder_outputs + return decoder_outputs[0] diff --git a/transformers/tests/modeling_bert_test.py b/transformers/tests/modeling_bert_test.py index 2a2c3e50ea..24acf565e3 100644 --- a/transformers/tests/modeling_bert_test.py +++ b/transformers/tests/modeling_bert_test.py @@ -29,7 +29,7 @@ if is_torch_available(): from transformers import (BertConfig, BertModel, BertForMaskedLM, BertForNextSentencePrediction, BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification, - BertForTokenClassification, BertForMultipleChoice, Bert2Bert) + BertForTokenClassification, BertForMultipleChoice, Bert2Rnd) from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP else: pytestmark = pytest.mark.skip("Require Torch") @@ -257,7 +257,7 @@ class BertModelTest(CommonTestCases.CommonModelTester): def create_and_check_bert2bert(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels): config.num_choices = self.num_choices - model = Bert2Bert(config=config) + model = Bert2Rnd(config=config) model.eval() bert2bert_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() bert2bert_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()