resolve PR comments

This commit is contained in:
Rémi Louf
2019-10-29 17:10:20 +01:00
parent 4c3ac4a7d8
commit dfce409691
7 changed files with 647 additions and 473 deletions

View File

@@ -26,189 +26,220 @@ import torch
from torch import nn
class ModelWithBeamSearch(nn.Module):
class TransformerBeamSearch(nn.Module):
def __init__(
self,
model,
tokenizer,
batch_size,
beam_size,
start_token_id,
end_token_id,
pad_token_id,
min_length,
max_length,
alpha,
block_trigram=True,
alpha=0,
block_repeating_trigram=True,
):
"""
Attributes:
mask_word_id: token id that corresponds to the mask
"""
super(ModelWithBeamSearch, self).__init__()
super(TransformerBeamSearch, self).__init__()
self.model = model
self.tokenizer = tokenizer
self.start_token_id = tokenizer.start_token_id
self.end_token_id = tokenizer.end_token_id
self.pad_token_id = tokenizer.pad_token_id
self.beam_size = beam_size
self.start_token_id = start_token_id
self.end_token_id = end_token_id
self.pad_token_id = pad_token_id
self.min_length = min_length
self.max_length = max_length
self.alpha = alpha
self.block_trigram = block_trigram
def forward(self, input_ids, **kwargs):
# Separate the encoder- and decoder- specific kwargs. A kwarg is
# decoder-specific it the key starts with `decoder_`
self.block_repeating_trigram = block_repeating_trigram
self.apply_length_penalty = False if alpha == 0 else True
self.alpha = alpha
# State of the beam
self.hypotheses = [[] for _ in range(batch_size)]
self.batch_offset = torch.arange(batch_size, dtype=torch.long)
self.beam_offset = torch.arange(
0, batch_size * self.beam_size, step=self.beam_size, dtype=torch.long
)
self.growing_beam = torch.full(
(batch_size * self.beam_size, 1), self.start_token_id, dtype=torch.long
)
self.topk_log_probabilities = torch.tensor(
[0.0] + [float("-inf")] * (self.beam_size - 1), dtype=torch.float
).repeat(batch_size)
self.results = {
"prediction": [[] for _ in batch_size],
"scores": [[] for _ in batch_size],
}
self._step = 0
self.is_done = False
def step(self, log_probabilities):
""" Grows the beam by one step. """
self._step += 1
# The batch size changes as some beams finish so we define _B
vocab_size = log_probabilities.size(-1)
_B = log_probabilities.size(0) // self.beam_size
# Multiply each beam probability with the probability of the
# next token (conditioned on the words in the beam).
log_probabilities += self.topk_log_probabilities.view(-1, 1)
self.enforce_min_length(log_probabilities)
if self.block_repeating_trigram:
self.remove_repeating_trigrams(log_probabilities, _B)
# Find the `beam_size` (previous_beam + token) combinations with
# the highest score
topk_log_probabilities, topk_ids = log_probabilities.topk(
log_probabilities.view(_B, self.beam_size * vocab_size),
self.beam_size,
dim=1,
)
# Apply the length penalty. The +1 accounts for the [EOS] token
# that will be added if the beam ends.
topk_scores = topk_log_probabilities / self.length_penalty()
# Retrieve the corresponding respective beam and token id
# topk_token_ids[i] will be added to topk_beam_ids[i]
topk_beam_ids = topk_ids.div(vocab_size)
topk_token_ids = topk_ids.fmod(vocab_size)
# Retrieve the row index of the surviving beams in the original
# view of the log_probabilities tensor
surviving_beams_rows = (topk_beam_ids + self.beam_offset[:_B].view(-1, 1)).view(
-1
)
# Append the last predictions
self.growing_beam = torch.cat(
[
self.growing_beam.index_select(0, surviving_beams_rows),
topk_token_ids.view(-1, 1),
],
1,
)
# Check if any of the beam searches has ended during this
# growth step. Also if top beam (most probable) has ended
# for one element of the batch.
is_finished = topk_token_ids.eq(self.end_token_id)
self.enforce_max_length()
is_top_beam_finished = is_finished[:, 0].eq(1)
# Save the finished searches
if is_finished.any():
predictions = self.growing_beam.view(
-1, self.beam_size, self.growing_beam.size(1)
)
for i in range(is_finished.size(0)):
if is_top_beam_finished[i]:
is_finished[i].fill_(1)
finished_hyp = is_finished[i].nonzero().view(-1)
# Store finished hypotheses for this batch.
b = self.batch_offset[i]
for j in finished_hyp:
self.hypotheses[b].append((topk_scores[i, j], predictions[i, j, :]))
# If the batch reached the end, save the best hypotheses
# in terms of length-penalized score.
if is_top_beam_finished[i]:
best_hyp = sorted(
self.hypotheses[b], key=lambda x: x[0], reverse=True
)
best_score, best_prediction = best_hyp[0]
self.results["scores"][b].append(best_score)
self.results["predictions"][b].append(best_prediction)
non_finished = is_top_beam_finished.eq(0).nonzero().view(-1)
if len(non_finished) == 0:
self.is_done = True
# Remove finished batches for the next step.
topk_log_probabilities = topk_log_probabilities.index_select(
0, non_finished
)
self.batch_offset = self.batch_offset.index_select(0, non_finished)
self.growing_beam = predictions.index_select(0, non_finished).view(
-1, self.growing_beam.size(-1)
)
surviving_beams_rows = surviving_beams_rows.index_select(0, non_finished)
return surviving_beams_rows
def forward(self, encoder_input_ids, **kwargs):
# keyword arguments come in 3 flavors: encoder-specific (prefixed by
# `encoder_`), decoder-specific (prefixed by `decoder_`) and those
# that apply to the model as whole.
# We let the specific kwargs override the common ones in case of conflict.
kwargs_encoder = {
argument: value
argument[len("encoder_"):]: value
for argument, value in kwargs.items()
if not argument.startswith("decoder_")
if argument.startswith("encoder_")
}
kwargs_decoder = {
argument[len("decoder_"):]: value
for argument, value in kwargs.items()
if argument.startswith("decoder_")
}
kwargs_common = {
argument: value
for argument, value in kwargs.items()
if not (argument.startswith("encoder_") or argument.startswith("decoder_"))
}
kwargs_decoder = dict(kwargs_common, **kwargs_decoder)
kwargs_encoder = dict(kwargs_common, **kwargs_encoder)
batch_size, _ = input_ids.size(0)
# Variables that keep track of the status of the search
hypotheses = [[] for _ in range(batch_size)]
batch_offset = torch.arange(batch_size, dtype=torch.long)
beam_offset = torch.arange(
0,
batch_size * self.beam_size,
step=self.beam_size,
dtype=torch.long,
)
growing_beam = torch.full(
(batch_size * self.beam_size, 1),
self.start_token_id,
dtype=torch.long,
)
topk_log_probabilities = torch.tensor(
[0.0] + [float("-inf")] * (self.beam_size - 1),
dtype=torch.float,
).repeat(batch_size)
# Forward pass on the encoder
encoder_outputs = self.encoder(input_ids, kwargs_encoder)
# forward pass on the encoder
encoder_outputs = self.model.encoder.forward(encoder_input_ids, kwargs_encoder)
kwargs_decoder["encoder_hidden_states"] = tile(
encoder_outputs, self.beam_size, dim=0
)
results = {}
results["predictions"] = [[] for _ in batch_size]
results["scores"] = [[] for _ in batch_size]
# grow the beam by generating sequences in an autoregressive way
self.growing_beam = torch.full(
(self.batch_size * self.beam_size, 1), self.start_token_id, dtype=torch.long
)
for step in range(self.max_length):
decoder_input = growing_beam[:, -1]
outputs = self.decoder(decoder_input, kwargs_decoder)
decoder_input = self.growing_beam[:, -1]
outputs = self.model.decoder(decoder_input, kwargs_decoder)
log_probabilities = torch.nn.functional.log_softmax(outputs[1])
vocab_size = log_probabilities.size(-1)
surviving_beams_rows = self.step(log_probabilities)
if self.is_done:
break
# The batch size changes as some beams finish so we define:
_B = log_probabilities.size(0) // self.beam_size
# Multiply each beam probability with the probability of the
# next token (conditioned on the words in the beam).
log_probabilities += topk_log_probabilities.view(-1, 1)
# if the beam has not attained the minimum required length we
# make the end token arbitrarily unlikely.
if step < self.min_length:
log_probabilities[self.end_token_id] = -1e20
# Remove repeating tri-grams
if(self.args.block_trigram):
if(step + 1 > 3):
for i in range(_B * self.beam_size):
tokens = [t for t in growing_beam[i]]
trigrams = [(tokens[i-1], tokens[i], tokens[i+1]) for i in range(1, len(words) - 1)]
last_trigram = tuple(trigrams[-1])
if last_trigram in trigrams[:-1]:
log_probabilities[i] = -1e20
# Find the `beam_size` (previous_beam + token) combinations with
# the highest score
topk_log_probabilities, topk_ids = log_probabilities.topk(
log_probabilities.view(_B, self.beam_size * vocab_size),
self.beam_size,
dim=1
)
# Apply the length penalty. The +1 accounts for the [EOS] token
# that will be added if the beam ends.
length_penalty = ((5.0 + (step + 1)) / 6.0) ** self.alpha
topk_scores = topk_log_probabilities / length_penalty
# Retrieve the corresponding respective beam and token id
# topk_token_ids[i] will be added to topk_beam_ids[i]
topk_beam_ids = topk_ids.div(vocab_size)
topk_token_ids = topk_ids.fmod(vocab_size)
# Retrieve the row index of the surviving beams in the original
# view of the log_probabilities tensor
surviving_beams_rows = (
topk_beam_ids + beam_offset[:_B].view(-1, 1)
).view(-1)
# Append the last predictions
growing_beam = torch.cat(
[
growing_beam.index_select(0, surviving_beams_rows),
topk_token_ids.view(-1, 1),
],
1,
)
# Check if any of the beam searches has ended during this
# growth step. Also if top beam (most probable) has ended
# for one element of the batch.
is_finished = topk_token_ids.eq(self.end_token_id)
if step + 1 == self.max_length:
is_finished.fill_(1)
is_top_beam_finished = is_finished[:, 0].eq(1)
# Save the finished searches
if is_finished.any():
predictions = growing_beam.view(-1, self.beam_size, growing_beam.size(1))
for i in range(is_finished.size(0)):
if is_top_beam_finished[i]:
is_finished[i].fill_(1)
finished_hyp = is_finished[i].nonzero().view(-1)
# Store finished hypotheses for this batch.
b = batch_offset[i]
for j in finished_hyp:
hypotheses[b].append((topk_scores[i, j], predictions[i, j, :]))
# If the batch reached the end, save the best hypotheses
# in terms of length-penalized score.
if is_top_beam_finished[i]:
best_hyp = sorted(
hypotheses[b], key=lambda x: x[0], reverse=True
)
best_score, best_prediction = best_hyp[0]
results["scores"][b].append(best_score)
results["predictions"][b].append(best_prediction)
non_finished = is_top_beam_finished.eq(0).nonzero().view(-1)
if len(non_finished) == 0:
break
# Remove finished batches for the next step.
topk_log_probabilities = topk_log_probabilities.index_select(0, non_finished)
batch_offset = batch_offset.index_select(0, non_finished)
growing_beam = predictions.index_select(0, non_finished).view(
-1, growing_beam.size(-1)
)
# Re-order the state for the next pass
surviving_beams_rows = surviving_beams_rows.index_select(0, non_finished)
kwargs_decoder["encoder_hidden_states"] = kwargs_decoder[
"encoder_hidden_states"
].index_select(0, surviving_beams_rows)
return results
return self.results
def remove_repeating_trigrams(self, log_probabilities, _B):
if(self._step + 1 > 3):
for i in range(_B * self.beam_size):
tokens = [t for t in self.growing_beam[i]]
trigrams = [(tokens[i-1], tokens[i], tokens[i+1]) for i in range(1, len(words) - 1)]
last_trigram = tuple(trigrams[-1])
if last_trigram in trigrams[:-1]:
log_probabilities[i] = -1e20
def enforce_min_length(self):
if self._step < self.min_length:
self.log_probabilities[self.end_token_id] = -1e20
def enforce_max_length(self):
if self._step + 1 == self.max_length:
self.is_finished.fill_(1)
def length_penalty(self):
return ((5.0 + (self._step + 1)) / 6.0) ** self.alpha
def tile(x, count, dim=0):

View File

@@ -632,6 +632,8 @@ class BertModel(BertPreTrainedModel):
"""
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
if encoder_attention_mask is None:
encoder_attention_mask = torch.ones_like(input_ids)
if token_type_ids is None:
token_type_ids = torch.zeros_like(input_ids)
@@ -660,12 +662,15 @@ class BertModel(BertPreTrainedModel):
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
# If a 2D encoder attention mask is provided for the cross-attention
# If a 2D ou 3D attention mask is provided for the cross-attention
# we need to make broadcastabe to [batch_size, num_heads, seq_length, seq_length]
if encoder_attention_mask is not None:
encoder_attention_mask = encoder_attention_mask[:, None, None, :]
encoder_attention_mask = encoder_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
encoder_attention_mask = (1.0 - encoder_attention_mask) * -10000.0
if encoder_attention_mask.dim() == 3:
encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
if encoder_attention_mask.dim() == 2:
encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
@@ -687,7 +692,7 @@ class BertModel(BertPreTrainedModel):
attention_mask=extended_attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask)
encoder_attention_mask=encoder_extended_attention_mask)
sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output)
@@ -788,8 +793,10 @@ class BertForMaskedLM(BertPreTrainedModel):
in ``[0, ..., config.vocab_size]``
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**loss**: (`optional`, returned when ``masked_lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
**masked_lm_loss**: (`optional`, returned when ``masked_lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Masked language modeling loss.
**next_token_loss**: (`optional`, returned when ``lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Next token prediction loss.
**prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
@@ -854,13 +861,13 @@ class BertForMaskedLM(BertPreTrainedModel):
if lm_labels is not None:
# we are doing next-token prediction; shift prediction scores and input ids by one
prediction_scores = prediction_scores[:, :-1, :]
lm_labels = lm_labels[:, 1:]
prediction_scores = prediction_scores[:, :-1, :].contiguous()
lm_labels = lm_labels[:, 1:].contiguous()
loss_fct = CrossEntropyLoss(ignore_index=-1)
seq2seq_loss = loss_fct(prediction_scores.reshape(-1, self.config.vocab_size), lm_labels.reshape(-1))
outputs = (seq2seq_loss,) + outputs
next_token_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), lm_labels.view(-1))
outputs = (next_token_loss,) + outputs
return outputs # (mlm_or_seq2seq_loss), prediction_scores, (hidden_states), (attentions)
return outputs # (masked_lm_loss), (next_token_loss), prediction_scores, (hidden_states), (attentions)
@add_start_docstrings("""Bert Model with a `next sentence prediction (classification)` head on top. """,

View File

@@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
class PreTrainedSeq2seq(nn.Module):
r"""
:class:`~transformers.Seq2seq` is a generic model class that will be
:class:`~transformers.PreTrainedSeq2seq` is a generic model class that will be
instantiated as a Seq2seq model with one of the base model classes of
the library as encoder and (optionally) as decoder when created with
the `AutoModel.from_pretrained(pretrained_model_name_or_path)` class
@@ -49,8 +49,7 @@ class PreTrainedSeq2seq(nn.Module):
*model_args,
**kwargs
):
r""" Instantiates an encoder and a decoder from one or two base classes
of the library from pre-trained model checkpoints.
r""" Instantiates an encoder and a decoder from one or two base classes of the library from pre-trained model checkpoints.
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
@@ -111,35 +110,44 @@ class PreTrainedSeq2seq(nn.Module):
model = PreTrainedSeq2seq.from_pretained('bert-base-uncased', 'bert-base-uncased') # initialize Bert2Bert
"""
# Separate the encoder- and decoder- specific kwargs. A kwarg is
# decoder-specific it the key starts with `decoder_`
# keyword arguments come in 3 flavors: encoder-specific (prefixed by
# `encoder_`), decoder-specific (prefixed by `decoder_`) and those
# that apply to the model as a whole.
# We let the specific kwargs override the common ones in case of conflict.
kwargs_encoder = {
argument: value
argument[len("encoder_"):]: value
for argument, value in kwargs.items()
if not argument.startswith("decoder_")
if argument.startswith("encoder_")
}
kwargs_decoder = {
argument[len("decoder_") :]: value
argument[len("decoder_"):]: value
for argument, value in kwargs.items()
if argument.startswith("decoder_")
}
kwargs_common = {
argument: value
for argument, value in kwargs.items()
if not (argument.startswith("encoder_") or argument.startswith("decoder_"))
}
kwargs_decoder = dict(kwargs_common, **kwargs_decoder)
kwargs_encoder = dict(kwargs_common, **kwargs_encoder)
# Load and initialize the encoder and decoder
# The distinction between encoder and decoder at the model level is made
# by the value of the flag `is_decoder` that we need to set correctly.
encoder = kwargs_encoder.pop("encoder_model", None)
encoder = kwargs_encoder.pop("model", None)
if encoder is None:
kwargs_encoder["is_decoder"] = False
encoder = AutoModel.from_pretrained(
encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder
)
encoder.config.is_decoder = False
decoder = kwargs_decoder.pop("model", None)
if decoder is None:
kwargs_decoder["is_decoder"] = True
decoder = AutoModelWithLMHead.from_pretrained(
decoder_pretrained_model_name_or_path, **kwargs_decoder
)
decoder.config.is_decoder = True
model = cls(encoder, decoder)
@@ -169,37 +177,60 @@ class PreTrainedSeq2seq(nn.Module):
decoder_input_ids: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``
Indices of decoder input sequence tokens in the vocabulary.
"""
# Separate the encoder- and decoder- specific kwargs. A kwarg is
# decoder-specific it the key starts with `decoder_`
# keyword arguments come in 3 flavors: encoder-specific (prefixed by
# `encoder_`), decoder-specific (prefixed by `decoder_`) and those
# that apply to the model as whole.
# We let the specific kwargs override the common ones in case of conflict.
kwargs_encoder = {
argument: value
argument[len("encoder_"):]: value
for argument, value in kwargs.items()
if not argument.startswith("decoder_")
if argument.startswith("encoder_")
}
kwargs_decoder = {
argument[len("decoder_") :]: value
argument[len("decoder_"):]: value
for argument, value in kwargs.items()
if argument.startswith("decoder_")
}
kwargs_common = {
argument: value
for argument, value in kwargs.items()
if not (argument.startswith("encoder_") or argument.startswith("decoder_"))
}
kwargs_decoder = dict(kwargs_common, **kwargs_decoder)
kwargs_encoder = dict(kwargs_common, **kwargs_encoder)
# Encode if needed (training, first prediction pass)
encoder_hidden_states = kwargs_encoder.pop("encoder_hidden_states", None)
encoder_hidden_states = kwargs_encoder.pop("hidden_states", None)
if encoder_hidden_states is None:
encoder_outputs = self.encoder(encoder_input_ids, **kwargs_encoder)
encoder_hidden_states = encoder_outputs[0][
-1
] # output of the encoder *stack*
encoder_hidden_states = encoder_outputs[0] # output the last layer hidden state
else:
encoder_outputs = ()
# Decode
kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states[None, :, :]
kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states
kwargs_decoder["encoder_attention_mask"] = kwargs_encoder.get("attention_mask", None)
decoder_outputs = self.decoder(decoder_input_ids, **kwargs_decoder)
return decoder_outputs + encoder_outputs
class Model2Model(PreTrainedSeq2seq):
r"""
:class:`~transformers.Model2Model` instantiates a Seq2Seq2 model
where both of the encoder and decoder are of the same family. If the
name of or that path to a pretrained model is specified the encoder and
the decoder will be initialized with the pretrained weight (the
cross-attention will be intialized randomly if its weights are not
present).
It is possible to override this behavior and initialize, say, the decoder randomly
by creating it beforehand as follows
config = BertConfig.from_pretrained()
decoder = BertForMaskedLM(config)
model = Model2Model.from_pretrained('bert-base-uncased', decoder_model=decoder)
"""
def __init__(self, *args, **kwargs):
super(Model2Model, self).__init__(*args, **kwargs)
self.tie_weights()
@@ -235,14 +266,10 @@ class Model2Model(PreTrainedSeq2seq):
model = super(Model2Model, cls).from_pretrained(
encoder_pretrained_model_name_or_path=pretrained_model_name_or_path,
decoder_pretrained_model_name_or_path=pretrained_model_name_or_path,
*args,
**kwargs
)
# Some architectures require for the decoder to be initialized randomly
# before fine-tuning.
if kwargs.get("decoder_initialize_randomly", False):
model.decoder.init_weights()
return model