simplified model and configuration
This commit is contained in:
committed by
Julien Chaumond
parent
3a9a9f7861
commit
a1994a71ee
@@ -33,17 +33,6 @@ class BertAbsConfig(PretrainedConfig):
|
|||||||
r""" Class to store the configuration of the BertAbs model.
|
r""" Class to store the configuration of the BertAbs model.
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
temp_dir: string
|
|
||||||
Unused in the current situation. Kept for compatibility but will be removed.
|
|
||||||
finetune_bert: bool
|
|
||||||
Whether to fine-tune the model or not. Will be kept for reference
|
|
||||||
in case we want to add the possibility to fine-tune the model.
|
|
||||||
large: bool
|
|
||||||
Whether to use bert-large as a base.
|
|
||||||
share_emb: book
|
|
||||||
Whether the embeddings are shared between the encoder and decoder.
|
|
||||||
encoder: string
|
|
||||||
Not clear what this does. Leave to "bert" for pre-trained weights.
|
|
||||||
max_pos: int
|
max_pos: int
|
||||||
The maximum sequence length that this model will be used with.
|
The maximum sequence length that this model will be used with.
|
||||||
enc_layer: int
|
enc_layer: int
|
||||||
@@ -77,11 +66,6 @@ class BertAbsConfig(PretrainedConfig):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
vocab_size_or_config_json_file=30522,
|
vocab_size_or_config_json_file=30522,
|
||||||
temp_dir=".",
|
|
||||||
finetune_bert=False,
|
|
||||||
large=False,
|
|
||||||
share_emb=True,
|
|
||||||
encoder="bert",
|
|
||||||
max_pos=512,
|
max_pos=512,
|
||||||
enc_layers=6,
|
enc_layers=6,
|
||||||
enc_hidden_size=512,
|
enc_hidden_size=512,
|
||||||
@@ -104,21 +88,15 @@ class BertAbsConfig(PretrainedConfig):
|
|||||||
for key, value in json_config.items():
|
for key, value in json_config.items():
|
||||||
self.__dict__[key] = value
|
self.__dict__[key] = value
|
||||||
elif isinstance(vocab_size_or_config_json_file, int):
|
elif isinstance(vocab_size_or_config_json_file, int):
|
||||||
self.temp_dir = temp_dir
|
|
||||||
self.finetune_bert = finetune_bert
|
|
||||||
self.large = large
|
|
||||||
self.vocab_size = vocab_size_or_config_json_file
|
self.vocab_size = vocab_size_or_config_json_file
|
||||||
self.max_pos = max_pos
|
self.max_pos = max_pos
|
||||||
|
|
||||||
self.encoder = encoder
|
|
||||||
self.enc_layers = enc_layers
|
self.enc_layers = enc_layers
|
||||||
self.enc_hidden_size = enc_hidden_size
|
self.enc_hidden_size = enc_hidden_size
|
||||||
self.enc_heads = enc_heads
|
self.enc_heads = enc_heads
|
||||||
self.enc_ff_size = enc_ff_size
|
self.enc_ff_size = enc_ff_size
|
||||||
self.enc_dropout = enc_dropout
|
self.enc_dropout = enc_dropout
|
||||||
|
|
||||||
self.share_emb = share_emb
|
|
||||||
|
|
||||||
self.dec_layers = dec_layers
|
self.dec_layers = dec_layers
|
||||||
self.dec_hidden_size = dec_hidden_size
|
self.dec_hidden_size = dec_hidden_size
|
||||||
self.dec_heads = dec_heads
|
self.dec_heads = dec_heads
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ class BertAbs(BertAbsPreTrainedModel):
|
|||||||
def __init__(self, args, checkpoint=None, bert_extractive_checkpoint=None):
|
def __init__(self, args, checkpoint=None, bert_extractive_checkpoint=None):
|
||||||
super(BertAbs, self).__init__(args)
|
super(BertAbs, self).__init__(args)
|
||||||
self.args = args
|
self.args = args
|
||||||
self.bert = Bert(args.large, args.temp_dir, args.finetune_bert)
|
self.bert = Bert()
|
||||||
|
|
||||||
# If pre-trained weights are passed for Bert, load these.
|
# If pre-trained weights are passed for Bert, load these.
|
||||||
load_bert_pretrained_extractive = True if bert_extractive_checkpoint else False
|
load_bert_pretrained_extractive = True if bert_extractive_checkpoint else False
|
||||||
@@ -69,18 +69,6 @@ class BertAbs(BertAbsPreTrainedModel):
|
|||||||
strict=True,
|
strict=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
if args.encoder == "baseline":
|
|
||||||
bert_config = BertConfig(
|
|
||||||
self.bert.model.config.vocab_size,
|
|
||||||
hidden_size=args.enc_hidden_size,
|
|
||||||
num_hidden_layers=args.enc_layers,
|
|
||||||
num_attention_heads=8,
|
|
||||||
intermediate_size=args.enc_ff_size,
|
|
||||||
hidden_dropout_prob=args.enc_dropout,
|
|
||||||
attention_probs_dropout_prob=args.enc_dropout,
|
|
||||||
)
|
|
||||||
self.bert.model = BertModel(bert_config)
|
|
||||||
|
|
||||||
self.vocab_size = self.bert.model.config.vocab_size
|
self.vocab_size = self.bert.model.config.vocab_size
|
||||||
|
|
||||||
if args.max_pos > 512:
|
if args.max_pos > 512:
|
||||||
@@ -101,7 +89,7 @@ class BertAbs(BertAbsPreTrainedModel):
|
|||||||
tgt_embeddings = nn.Embedding(
|
tgt_embeddings = nn.Embedding(
|
||||||
self.vocab_size, self.bert.model.config.hidden_size, padding_idx=0
|
self.vocab_size, self.bert.model.config.hidden_size, padding_idx=0
|
||||||
)
|
)
|
||||||
if self.args.share_emb:
|
|
||||||
tgt_embeddings.weight = copy.deepcopy(
|
tgt_embeddings.weight = copy.deepcopy(
|
||||||
self.bert.model.embeddings.word_embeddings.weight
|
self.bert.model.embeddings.word_embeddings.weight
|
||||||
)
|
)
|
||||||
@@ -141,16 +129,6 @@ class BertAbs(BertAbsPreTrainedModel):
|
|||||||
else:
|
else:
|
||||||
p.data.zero_()
|
p.data.zero_()
|
||||||
|
|
||||||
def maybe_tie_embeddings(self, args):
|
|
||||||
if args.use_bert_emb:
|
|
||||||
tgt_embeddings = nn.Embedding(
|
|
||||||
self.vocab_size, self.bert.model.config.hidden_size, padding_idx=0
|
|
||||||
)
|
|
||||||
tgt_embeddings.weight = copy.deepcopy(
|
|
||||||
self.bert.model.embeddings.word_embeddings.weight
|
|
||||||
)
|
|
||||||
self.decoder.embeddings = tgt_embeddings
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
encoder_input_ids,
|
encoder_input_ids,
|
||||||
@@ -178,14 +156,9 @@ class Bert(nn.Module):
|
|||||||
""" This class is not really necessary and should probably disappear.
|
""" This class is not really necessary and should probably disappear.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, large, temp_dir, finetune=False):
|
def __init__(self):
|
||||||
super(Bert, self).__init__()
|
super(Bert, self).__init__()
|
||||||
if large:
|
self.model = BertModel.from_pretrained("bert-base-uncased")
|
||||||
self.model = BertModel.from_pretrained("bert-large-uncased", cache_dir=temp_dir)
|
|
||||||
else:
|
|
||||||
self.model = BertModel.from_pretrained("bert-base-uncased", cache_dir=temp_dir)
|
|
||||||
|
|
||||||
self.finetune = finetune
|
|
||||||
|
|
||||||
def forward(self, input_ids, attention_mask=None, token_type_ids=None, **kwargs):
|
def forward(self, input_ids, attention_mask=None, token_type_ids=None, **kwargs):
|
||||||
self.eval()
|
self.eval()
|
||||||
|
|||||||
@@ -31,9 +31,9 @@ Batch = namedtuple(
|
|||||||
|
|
||||||
def evaluate(args):
|
def evaluate(args):
|
||||||
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)
|
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)
|
||||||
model = bertabs = BertAbs.from_pretrained("bertabs-finetuned-cnndm")
|
model = BertAbs.from_pretrained("bertabs-finetuned-cnndm")
|
||||||
bertabs.to(args.device)
|
model.to(args.device)
|
||||||
bertabs.eval()
|
model.eval()
|
||||||
|
|
||||||
symbols = {
|
symbols = {
|
||||||
"BOS": tokenizer.vocab["[unused0]"],
|
"BOS": tokenizer.vocab["[unused0]"],
|
||||||
|
|||||||
Reference in New Issue
Block a user