From 1e68c28670cc8d0e8d20ca9fadc697f03908015b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Thu, 10 Oct 2019 18:07:11 +0200 Subject: [PATCH] add test for initialization of Bert2Rnd --- examples/run_summarization.py | 49 ++++++++++++++++++++++++ transformers/tests/modeling_bert_test.py | 12 +++--- 2 files changed, 55 insertions(+), 6 deletions(-) create mode 100644 examples/run_summarization.py diff --git a/examples/run_summarization.py b/examples/run_summarization.py new file mode 100644 index 0000000000..0a367551d6 --- /dev/null +++ b/examples/run_summarization.py @@ -0,0 +1,49 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Finetuning seq2seq models for abstractive summarization. + +The finetuning method for abstractive summarization is inspired by [1]. We +concatenate the document and summary, mask words of the summary at random and +maximize the likelihood of masked words. + +[1] Li Dong, Nan Yang, Wenhui Wang, Furu Wei, Xiaodong Liu, Yu Wang, Jianfeng +Gao, Ming Zhou, and Hsiao-Wuen Hon. 
“Unified Language Model Pre-Training for +Natural Language Understanding and Generation.” (May 2019) ArXiv:1905.03197 +""" + +import logging +import random + +import numpy as np +import torch + +logger = logging.getLogger(__name__) + + +def set_seed(args): + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + if args.n_gpu > 0: + torch.cuda.manual_seed_all(args.seed) + + +def train(args, train_dataset, model, tokenizer): + raise NotImplementedError + + +def evaluate(args, model, tokenizer, prefix=""): + raise NotImplementedError diff --git a/transformers/tests/modeling_bert_test.py b/transformers/tests/modeling_bert_test.py index fe9e039983..e649cd8ce8 100644 --- a/transformers/tests/modeling_bert_test.py +++ b/transformers/tests/modeling_bert_test.py @@ -259,12 +259,12 @@ class BertModelTest(CommonTestCases.CommonModelTester): config.num_choices = self.num_choices model = Bert2Rnd(config=config) model.eval() - bert2bert_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() - bert2bert_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() - bert2bert_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() - _ = model(bert2bert_inputs_ids, - attention_mask=bert2bert_input_mask, - token_type_ids=bert2bert_token_type_ids) + bert2rnd_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() + bert2rnd_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() + bert2rnd_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() + _ = model(bert2rnd_inputs_ids, + attention_mask=bert2rnd_input_mask, + token_type_ids=bert2rnd_token_type_ids) def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs()