diff --git a/examples/summarization/configuration_bertabs.py b/examples/summarization/configuration_bertabs.py index ff3171f9a8..5bcb65b423 100644 --- a/examples/summarization/configuration_bertabs.py +++ b/examples/summarization/configuration_bertabs.py @@ -33,17 +33,6 @@ class BertAbsConfig(PretrainedConfig): r""" Class to store the configuration of the BertAbs model. Arguments: - temp_dir: string - Unused in the current situation. Kept for compatibility but will be removed. - finetune_bert: bool - Whether to fine-tune the model or not. Will be kept for reference - in case we want to add the possibility to fine-tune the model. - large: bool - Whether to use bert-large as a base. - share_emb: book - Whether the embeddings are shared between the encoder and decoder. - encoder: string - Not clear what this does. Leave to "bert" for pre-trained weights. max_pos: int The maximum sequence length that this model will be used with. enc_layer: int @@ -77,11 +66,6 @@ class BertAbsConfig(PretrainedConfig): def __init__( self, vocab_size_or_config_json_file=30522, - temp_dir=".", - finetune_bert=False, - large=False, - share_emb=True, - encoder="bert", max_pos=512, enc_layers=6, enc_hidden_size=512, @@ -104,21 +88,15 @@ class BertAbsConfig(PretrainedConfig): for key, value in json_config.items(): self.__dict__[key] = value elif isinstance(vocab_size_or_config_json_file, int): - self.temp_dir = temp_dir - self.finetune_bert = finetune_bert - self.large = large self.vocab_size = vocab_size_or_config_json_file self.max_pos = max_pos - self.encoder = encoder self.enc_layers = enc_layers self.enc_hidden_size = enc_hidden_size self.enc_heads = enc_heads self.enc_ff_size = enc_ff_size self.enc_dropout = enc_dropout - self.share_emb = share_emb - self.dec_layers = dec_layers self.dec_hidden_size = dec_hidden_size self.dec_heads = dec_heads diff --git a/examples/summarization/modeling_bertabs.py b/examples/summarization/modeling_bertabs.py index 0189a2ad2b..5e51526037 100644 --- a/examples/summarization/modeling_bertabs.py +++ b/examples/summarization/modeling_bertabs.py @@ -53,7 +53,7 @@ class BertAbs(BertAbsPreTrainedModel): def __init__(self, args, checkpoint=None, bert_extractive_checkpoint=None): super(BertAbs, self).__init__(args) self.args = args - self.bert = Bert(args.large, args.temp_dir, args.finetune_bert) + self.bert = Bert() # If pre-trained weights are passed for Bert, load these. load_bert_pretrained_extractive = True if bert_extractive_checkpoint else False @@ -69,18 +69,6 @@ class BertAbs(BertAbsPreTrainedModel): strict=True, ) - if args.encoder == "baseline": - bert_config = BertConfig( - self.bert.model.config.vocab_size, - hidden_size=args.enc_hidden_size, - num_hidden_layers=args.enc_layers, - num_attention_heads=8, - intermediate_size=args.enc_ff_size, - hidden_dropout_prob=args.enc_dropout, - attention_probs_dropout_prob=args.enc_dropout, - ) - self.bert.model = BertModel(bert_config) - self.vocab_size = self.bert.model.config.vocab_size if args.max_pos > 512: @@ -101,10 +89,10 @@ class BertAbs(BertAbsPreTrainedModel): tgt_embeddings = nn.Embedding( self.vocab_size, self.bert.model.config.hidden_size, padding_idx=0 ) - if self.args.share_emb: - tgt_embeddings.weight = copy.deepcopy( - self.bert.model.embeddings.word_embeddings.weight - ) + + tgt_embeddings.weight = copy.deepcopy( + self.bert.model.embeddings.word_embeddings.weight + ) self.decoder = TransformerDecoder( self.args.dec_layers, @@ -141,16 +129,6 @@ class BertAbs(BertAbsPreTrainedModel): else: p.data.zero_() - def maybe_tie_embeddings(self, args): - if args.use_bert_emb: - tgt_embeddings = nn.Embedding( - self.vocab_size, self.bert.model.config.hidden_size, padding_idx=0 - ) - tgt_embeddings.weight = copy.deepcopy( - self.bert.model.embeddings.word_embeddings.weight - ) - self.decoder.embeddings = tgt_embeddings - def forward( self, encoder_input_ids, @@ -178,14 +156,9 @@ class Bert(nn.Module): """ This class is not really necessary and should probably disappear. """ - def __init__(self, large, temp_dir, finetune=False): + def __init__(self): super(Bert, self).__init__() - if large: - self.model = BertModel.from_pretrained("bert-large-uncased", cache_dir=temp_dir) - else: - self.model = BertModel.from_pretrained("bert-base-uncased", cache_dir=temp_dir) - - self.finetune = finetune + self.model = BertModel.from_pretrained("bert-base-uncased") def forward(self, input_ids, attention_mask=None, token_type_ids=None, **kwargs): self.eval() diff --git a/examples/summarization/run_summarization.py b/examples/summarization/run_summarization.py index bbc79227ca..ed663e880b 100644 --- a/examples/summarization/run_summarization.py +++ b/examples/summarization/run_summarization.py @@ -31,9 +31,9 @@ Batch = namedtuple( def evaluate(args): tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True) - model = bertabs = BertAbs.from_pretrained("bertabs-finetuned-cnndm") - bertabs.to(args.device) - bertabs.eval() + model = BertAbs.from_pretrained("bertabs-finetuned-cnndm") + model.to(args.device) + model.eval() symbols = { "BOS": tokenizer.vocab["[unused0]"],