simplified model and configuration
This commit is contained in:
Committed by: Julien Chaumond
Parent: 3a9a9f7861
Commit: a1994a71ee
@@ -53,7 +53,7 @@ class BertAbs(BertAbsPreTrainedModel):
|
||||
def __init__(self, args, checkpoint=None, bert_extractive_checkpoint=None):
|
||||
super(BertAbs, self).__init__(args)
|
||||
self.args = args
|
||||
self.bert = Bert(args.large, args.temp_dir, args.finetune_bert)
|
||||
self.bert = Bert()
|
||||
|
||||
# If pre-trained weights are passed for Bert, load these.
|
||||
load_bert_pretrained_extractive = True if bert_extractive_checkpoint else False
|
||||
@@ -69,18 +69,6 @@ class BertAbs(BertAbsPreTrainedModel):
|
||||
strict=True,
|
||||
)
|
||||
|
||||
if args.encoder == "baseline":
|
||||
bert_config = BertConfig(
|
||||
self.bert.model.config.vocab_size,
|
||||
hidden_size=args.enc_hidden_size,
|
||||
num_hidden_layers=args.enc_layers,
|
||||
num_attention_heads=8,
|
||||
intermediate_size=args.enc_ff_size,
|
||||
hidden_dropout_prob=args.enc_dropout,
|
||||
attention_probs_dropout_prob=args.enc_dropout,
|
||||
)
|
||||
self.bert.model = BertModel(bert_config)
|
||||
|
||||
self.vocab_size = self.bert.model.config.vocab_size
|
||||
|
||||
if args.max_pos > 512:
|
||||
@@ -101,10 +89,10 @@ class BertAbs(BertAbsPreTrainedModel):
|
||||
tgt_embeddings = nn.Embedding(
|
||||
self.vocab_size, self.bert.model.config.hidden_size, padding_idx=0
|
||||
)
|
||||
if self.args.share_emb:
|
||||
tgt_embeddings.weight = copy.deepcopy(
|
||||
self.bert.model.embeddings.word_embeddings.weight
|
||||
)
|
||||
|
||||
tgt_embeddings.weight = copy.deepcopy(
|
||||
self.bert.model.embeddings.word_embeddings.weight
|
||||
)
|
||||
|
||||
self.decoder = TransformerDecoder(
|
||||
self.args.dec_layers,
|
||||
@@ -141,16 +129,6 @@ class BertAbs(BertAbsPreTrainedModel):
|
||||
else:
|
||||
p.data.zero_()
|
||||
|
||||
def maybe_tie_embeddings(self, args):
|
||||
if args.use_bert_emb:
|
||||
tgt_embeddings = nn.Embedding(
|
||||
self.vocab_size, self.bert.model.config.hidden_size, padding_idx=0
|
||||
)
|
||||
tgt_embeddings.weight = copy.deepcopy(
|
||||
self.bert.model.embeddings.word_embeddings.weight
|
||||
)
|
||||
self.decoder.embeddings = tgt_embeddings
|
||||
|
||||
def forward(
|
||||
self,
|
||||
encoder_input_ids,
|
||||
@@ -178,14 +156,9 @@ class Bert(nn.Module):
|
||||
""" This class is not really necessary and should probably disappear.
|
||||
"""
|
||||
|
||||
def __init__(self, large, temp_dir, finetune=False):
|
||||
def __init__(self):
|
||||
super(Bert, self).__init__()
|
||||
if large:
|
||||
self.model = BertModel.from_pretrained("bert-large-uncased", cache_dir=temp_dir)
|
||||
else:
|
||||
self.model = BertModel.from_pretrained("bert-base-uncased", cache_dir=temp_dir)
|
||||
|
||||
self.finetune = finetune
|
||||
self.model = BertModel.from_pretrained("bert-base-uncased")
|
||||
|
||||
def forward(self, input_ids, attention_mask=None, token_type_ids=None, **kwargs):
|
||||
self.eval()
|
||||
|
||||
Reference in New Issue
Block a user