resolve PR comments
This commit is contained in:
@@ -26,189 +26,220 @@ import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class ModelWithBeamSearch(nn.Module):
|
||||
class TransformerBeamSearch(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
tokenizer,
|
||||
batch_size,
|
||||
beam_size,
|
||||
start_token_id,
|
||||
end_token_id,
|
||||
pad_token_id,
|
||||
min_length,
|
||||
max_length,
|
||||
alpha,
|
||||
block_trigram=True,
|
||||
alpha=0,
|
||||
block_repeating_trigram=True,
|
||||
):
|
||||
"""
|
||||
Attributes:
|
||||
mask_word_id: token id that corresponds to the mask
|
||||
"""
|
||||
super(ModelWithBeamSearch, self).__init__()
|
||||
super(TransformerBeamSearch, self).__init__()
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
self.start_token_id = tokenizer.start_token_id
|
||||
self.end_token_id = tokenizer.end_token_id
|
||||
self.pad_token_id = tokenizer.pad_token_id
|
||||
|
||||
self.beam_size = beam_size
|
||||
self.start_token_id = start_token_id
|
||||
self.end_token_id = end_token_id
|
||||
self.pad_token_id = pad_token_id
|
||||
self.min_length = min_length
|
||||
self.max_length = max_length
|
||||
self.alpha = alpha
|
||||
self.block_trigram = block_trigram
|
||||
|
||||
def forward(self, input_ids, **kwargs):
|
||||
# Separate the encoder- and decoder- specific kwargs. A kwarg is
|
||||
# decoder-specific it the key starts with `decoder_`
|
||||
self.block_repeating_trigram = block_repeating_trigram
|
||||
self.apply_length_penalty = False if alpha == 0 else True
|
||||
self.alpha = alpha
|
||||
|
||||
# State of the beam
|
||||
self.hypotheses = [[] for _ in range(batch_size)]
|
||||
self.batch_offset = torch.arange(batch_size, dtype=torch.long)
|
||||
self.beam_offset = torch.arange(
|
||||
0, batch_size * self.beam_size, step=self.beam_size, dtype=torch.long
|
||||
)
|
||||
self.growing_beam = torch.full(
|
||||
(batch_size * self.beam_size, 1), self.start_token_id, dtype=torch.long
|
||||
)
|
||||
self.topk_log_probabilities = torch.tensor(
|
||||
[0.0] + [float("-inf")] * (self.beam_size - 1), dtype=torch.float
|
||||
).repeat(batch_size)
|
||||
self.results = {
|
||||
"prediction": [[] for _ in batch_size],
|
||||
"scores": [[] for _ in batch_size],
|
||||
}
|
||||
self._step = 0
|
||||
self.is_done = False
|
||||
|
||||
def step(self, log_probabilities):
|
||||
""" Grows the beam by one step. """
|
||||
self._step += 1
|
||||
|
||||
# The batch size changes as some beams finish so we define _B
|
||||
vocab_size = log_probabilities.size(-1)
|
||||
_B = log_probabilities.size(0) // self.beam_size
|
||||
|
||||
# Multiply each beam probability with the probability of the
|
||||
# next token (conditioned on the words in the beam).
|
||||
log_probabilities += self.topk_log_probabilities.view(-1, 1)
|
||||
|
||||
self.enforce_min_length(log_probabilities)
|
||||
if self.block_repeating_trigram:
|
||||
self.remove_repeating_trigrams(log_probabilities, _B)
|
||||
|
||||
# Find the `beam_size` (previous_beam + token) combinations with
|
||||
# the highest score
|
||||
topk_log_probabilities, topk_ids = log_probabilities.topk(
|
||||
log_probabilities.view(_B, self.beam_size * vocab_size),
|
||||
self.beam_size,
|
||||
dim=1,
|
||||
)
|
||||
|
||||
# Apply the length penalty. The +1 accounts for the [EOS] token
|
||||
# that will be added if the beam ends.
|
||||
topk_scores = topk_log_probabilities / self.length_penalty()
|
||||
|
||||
# Retrieve the corresponding respective beam and token id
|
||||
# topk_token_ids[i] will be added to topk_beam_ids[i]
|
||||
topk_beam_ids = topk_ids.div(vocab_size)
|
||||
topk_token_ids = topk_ids.fmod(vocab_size)
|
||||
|
||||
# Retrieve the row index of the surviving beams in the original
|
||||
# view of the log_probabilities tensor
|
||||
surviving_beams_rows = (topk_beam_ids + self.beam_offset[:_B].view(-1, 1)).view(
|
||||
-1
|
||||
)
|
||||
|
||||
# Append the last predictions
|
||||
self.growing_beam = torch.cat(
|
||||
[
|
||||
self.growing_beam.index_select(0, surviving_beams_rows),
|
||||
topk_token_ids.view(-1, 1),
|
||||
],
|
||||
1,
|
||||
)
|
||||
|
||||
# Check if any of the beam searches has ended during this
|
||||
# growth step. Also if top beam (most probable) has ended
|
||||
# for one element of the batch.
|
||||
is_finished = topk_token_ids.eq(self.end_token_id)
|
||||
self.enforce_max_length()
|
||||
is_top_beam_finished = is_finished[:, 0].eq(1)
|
||||
|
||||
# Save the finished searches
|
||||
if is_finished.any():
|
||||
predictions = self.growing_beam.view(
|
||||
-1, self.beam_size, self.growing_beam.size(1)
|
||||
)
|
||||
for i in range(is_finished.size(0)):
|
||||
if is_top_beam_finished[i]:
|
||||
is_finished[i].fill_(1)
|
||||
finished_hyp = is_finished[i].nonzero().view(-1)
|
||||
|
||||
# Store finished hypotheses for this batch.
|
||||
b = self.batch_offset[i]
|
||||
for j in finished_hyp:
|
||||
self.hypotheses[b].append((topk_scores[i, j], predictions[i, j, :]))
|
||||
|
||||
# If the batch reached the end, save the best hypotheses
|
||||
# in terms of length-penalized score.
|
||||
if is_top_beam_finished[i]:
|
||||
best_hyp = sorted(
|
||||
self.hypotheses[b], key=lambda x: x[0], reverse=True
|
||||
)
|
||||
best_score, best_prediction = best_hyp[0]
|
||||
self.results["scores"][b].append(best_score)
|
||||
self.results["predictions"][b].append(best_prediction)
|
||||
|
||||
non_finished = is_top_beam_finished.eq(0).nonzero().view(-1)
|
||||
if len(non_finished) == 0:
|
||||
self.is_done = True
|
||||
|
||||
# Remove finished batches for the next step.
|
||||
topk_log_probabilities = topk_log_probabilities.index_select(
|
||||
0, non_finished
|
||||
)
|
||||
self.batch_offset = self.batch_offset.index_select(0, non_finished)
|
||||
self.growing_beam = predictions.index_select(0, non_finished).view(
|
||||
-1, self.growing_beam.size(-1)
|
||||
)
|
||||
|
||||
surviving_beams_rows = surviving_beams_rows.index_select(0, non_finished)
|
||||
|
||||
return surviving_beams_rows
|
||||
|
||||
def forward(self, encoder_input_ids, **kwargs):
|
||||
# keyword arguments come in 3 flavors: encoder-specific (prefixed by
|
||||
# `encoder_`), decoder-specific (prefixed by `decoder_`) and those
|
||||
# that apply to the model as whole.
|
||||
# We let the specific kwargs override the common ones in case of conflict.
|
||||
kwargs_encoder = {
|
||||
argument: value
|
||||
argument[len("encoder_"):]: value
|
||||
for argument, value in kwargs.items()
|
||||
if not argument.startswith("decoder_")
|
||||
if argument.startswith("encoder_")
|
||||
}
|
||||
kwargs_decoder = {
|
||||
argument[len("decoder_"):]: value
|
||||
for argument, value in kwargs.items()
|
||||
if argument.startswith("decoder_")
|
||||
}
|
||||
kwargs_common = {
|
||||
argument: value
|
||||
for argument, value in kwargs.items()
|
||||
if not (argument.startswith("encoder_") or argument.startswith("decoder_"))
|
||||
}
|
||||
kwargs_decoder = dict(kwargs_common, **kwargs_decoder)
|
||||
kwargs_encoder = dict(kwargs_common, **kwargs_encoder)
|
||||
|
||||
batch_size, _ = input_ids.size(0)
|
||||
|
||||
# Variables that keep track of the status of the search
|
||||
hypotheses = [[] for _ in range(batch_size)]
|
||||
batch_offset = torch.arange(batch_size, dtype=torch.long)
|
||||
beam_offset = torch.arange(
|
||||
0,
|
||||
batch_size * self.beam_size,
|
||||
step=self.beam_size,
|
||||
dtype=torch.long,
|
||||
)
|
||||
growing_beam = torch.full(
|
||||
(batch_size * self.beam_size, 1),
|
||||
self.start_token_id,
|
||||
dtype=torch.long,
|
||||
)
|
||||
topk_log_probabilities = torch.tensor(
|
||||
[0.0] + [float("-inf")] * (self.beam_size - 1),
|
||||
dtype=torch.float,
|
||||
).repeat(batch_size)
|
||||
|
||||
# Forward pass on the encoder
|
||||
encoder_outputs = self.encoder(input_ids, kwargs_encoder)
|
||||
# forward pass on the encoder
|
||||
encoder_outputs = self.model.encoder.forward(encoder_input_ids, kwargs_encoder)
|
||||
kwargs_decoder["encoder_hidden_states"] = tile(
|
||||
encoder_outputs, self.beam_size, dim=0
|
||||
)
|
||||
|
||||
results = {}
|
||||
results["predictions"] = [[] for _ in batch_size]
|
||||
results["scores"] = [[] for _ in batch_size]
|
||||
|
||||
# grow the beam by generating sequences in an autoregressive way
|
||||
self.growing_beam = torch.full(
|
||||
(self.batch_size * self.beam_size, 1), self.start_token_id, dtype=torch.long
|
||||
)
|
||||
for step in range(self.max_length):
|
||||
decoder_input = growing_beam[:, -1]
|
||||
outputs = self.decoder(decoder_input, kwargs_decoder)
|
||||
decoder_input = self.growing_beam[:, -1]
|
||||
outputs = self.model.decoder(decoder_input, kwargs_decoder)
|
||||
log_probabilities = torch.nn.functional.log_softmax(outputs[1])
|
||||
vocab_size = log_probabilities.size(-1)
|
||||
surviving_beams_rows = self.step(log_probabilities)
|
||||
if self.is_done:
|
||||
break
|
||||
|
||||
# The batch size changes as some beams finish so we define:
|
||||
_B = log_probabilities.size(0) // self.beam_size
|
||||
|
||||
# Multiply each beam probability with the probability of the
|
||||
# next token (conditioned on the words in the beam).
|
||||
log_probabilities += topk_log_probabilities.view(-1, 1)
|
||||
|
||||
# if the beam has not attained the minimum required length we
|
||||
# make the end token arbitrarily unlikely.
|
||||
if step < self.min_length:
|
||||
log_probabilities[self.end_token_id] = -1e20
|
||||
|
||||
# Remove repeating tri-grams
|
||||
if(self.args.block_trigram):
|
||||
if(step + 1 > 3):
|
||||
for i in range(_B * self.beam_size):
|
||||
tokens = [t for t in growing_beam[i]]
|
||||
trigrams = [(tokens[i-1], tokens[i], tokens[i+1]) for i in range(1, len(words) - 1)]
|
||||
last_trigram = tuple(trigrams[-1])
|
||||
if last_trigram in trigrams[:-1]:
|
||||
log_probabilities[i] = -1e20
|
||||
|
||||
# Find the `beam_size` (previous_beam + token) combinations with
|
||||
# the highest score
|
||||
topk_log_probabilities, topk_ids = log_probabilities.topk(
|
||||
log_probabilities.view(_B, self.beam_size * vocab_size),
|
||||
self.beam_size,
|
||||
dim=1
|
||||
)
|
||||
|
||||
# Apply the length penalty. The +1 accounts for the [EOS] token
|
||||
# that will be added if the beam ends.
|
||||
length_penalty = ((5.0 + (step + 1)) / 6.0) ** self.alpha
|
||||
topk_scores = topk_log_probabilities / length_penalty
|
||||
|
||||
# Retrieve the corresponding respective beam and token id
|
||||
# topk_token_ids[i] will be added to topk_beam_ids[i]
|
||||
topk_beam_ids = topk_ids.div(vocab_size)
|
||||
topk_token_ids = topk_ids.fmod(vocab_size)
|
||||
|
||||
# Retrieve the row index of the surviving beams in the original
|
||||
# view of the log_probabilities tensor
|
||||
surviving_beams_rows = (
|
||||
topk_beam_ids + beam_offset[:_B].view(-1, 1)
|
||||
).view(-1)
|
||||
|
||||
# Append the last predictions
|
||||
growing_beam = torch.cat(
|
||||
[
|
||||
growing_beam.index_select(0, surviving_beams_rows),
|
||||
topk_token_ids.view(-1, 1),
|
||||
],
|
||||
1,
|
||||
)
|
||||
|
||||
# Check if any of the beam searches has ended during this
|
||||
# growth step. Also if top beam (most probable) has ended
|
||||
# for one element of the batch.
|
||||
is_finished = topk_token_ids.eq(self.end_token_id)
|
||||
if step + 1 == self.max_length:
|
||||
is_finished.fill_(1)
|
||||
is_top_beam_finished = is_finished[:, 0].eq(1)
|
||||
|
||||
# Save the finished searches
|
||||
if is_finished.any():
|
||||
predictions = growing_beam.view(-1, self.beam_size, growing_beam.size(1))
|
||||
for i in range(is_finished.size(0)):
|
||||
if is_top_beam_finished[i]:
|
||||
is_finished[i].fill_(1)
|
||||
finished_hyp = is_finished[i].nonzero().view(-1)
|
||||
|
||||
# Store finished hypotheses for this batch.
|
||||
b = batch_offset[i]
|
||||
for j in finished_hyp:
|
||||
hypotheses[b].append((topk_scores[i, j], predictions[i, j, :]))
|
||||
|
||||
# If the batch reached the end, save the best hypotheses
|
||||
# in terms of length-penalized score.
|
||||
if is_top_beam_finished[i]:
|
||||
best_hyp = sorted(
|
||||
hypotheses[b], key=lambda x: x[0], reverse=True
|
||||
)
|
||||
best_score, best_prediction = best_hyp[0]
|
||||
results["scores"][b].append(best_score)
|
||||
results["predictions"][b].append(best_prediction)
|
||||
|
||||
non_finished = is_top_beam_finished.eq(0).nonzero().view(-1)
|
||||
if len(non_finished) == 0:
|
||||
break
|
||||
|
||||
# Remove finished batches for the next step.
|
||||
topk_log_probabilities = topk_log_probabilities.index_select(0, non_finished)
|
||||
batch_offset = batch_offset.index_select(0, non_finished)
|
||||
growing_beam = predictions.index_select(0, non_finished).view(
|
||||
-1, growing_beam.size(-1)
|
||||
)
|
||||
|
||||
# Re-order the state for the next pass
|
||||
surviving_beams_rows = surviving_beams_rows.index_select(0, non_finished)
|
||||
kwargs_decoder["encoder_hidden_states"] = kwargs_decoder[
|
||||
"encoder_hidden_states"
|
||||
].index_select(0, surviving_beams_rows)
|
||||
|
||||
return results
|
||||
return self.results
|
||||
|
||||
def remove_repeating_trigrams(self, log_probabilities, _B):
|
||||
if(self._step + 1 > 3):
|
||||
for i in range(_B * self.beam_size):
|
||||
tokens = [t for t in self.growing_beam[i]]
|
||||
trigrams = [(tokens[i-1], tokens[i], tokens[i+1]) for i in range(1, len(words) - 1)]
|
||||
last_trigram = tuple(trigrams[-1])
|
||||
if last_trigram in trigrams[:-1]:
|
||||
log_probabilities[i] = -1e20
|
||||
|
||||
def enforce_min_length(self):
|
||||
if self._step < self.min_length:
|
||||
self.log_probabilities[self.end_token_id] = -1e20
|
||||
|
||||
def enforce_max_length(self):
|
||||
if self._step + 1 == self.max_length:
|
||||
self.is_finished.fill_(1)
|
||||
|
||||
def length_penalty(self):
|
||||
return ((5.0 + (self._step + 1)) / 6.0) ** self.alpha
|
||||
|
||||
|
||||
def tile(x, count, dim=0):
|
||||
|
||||
@@ -632,6 +632,8 @@ class BertModel(BertPreTrainedModel):
|
||||
"""
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
if encoder_attention_mask is None:
|
||||
encoder_attention_mask = torch.ones_like(input_ids)
|
||||
if token_type_ids is None:
|
||||
token_type_ids = torch.zeros_like(input_ids)
|
||||
|
||||
@@ -660,12 +662,15 @@ class BertModel(BertPreTrainedModel):
|
||||
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
|
||||
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
||||
|
||||
# If a 2D encoder attention mask is provided for the cross-attention
|
||||
# If a 2D ou 3D attention mask is provided for the cross-attention
|
||||
# we need to make broadcastabe to [batch_size, num_heads, seq_length, seq_length]
|
||||
if encoder_attention_mask is not None:
|
||||
encoder_attention_mask = encoder_attention_mask[:, None, None, :]
|
||||
encoder_attention_mask = encoder_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
|
||||
encoder_attention_mask = (1.0 - encoder_attention_mask) * -10000.0
|
||||
if encoder_attention_mask.dim() == 3:
|
||||
encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
|
||||
if encoder_attention_mask.dim() == 2:
|
||||
encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
|
||||
|
||||
encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
|
||||
encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
@@ -687,7 +692,7 @@ class BertModel(BertPreTrainedModel):
|
||||
attention_mask=extended_attention_mask,
|
||||
head_mask=head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask)
|
||||
encoder_attention_mask=encoder_extended_attention_mask)
|
||||
sequence_output = encoder_outputs[0]
|
||||
pooled_output = self.pooler(sequence_output)
|
||||
|
||||
@@ -788,8 +793,10 @@ class BertForMaskedLM(BertPreTrainedModel):
|
||||
in ``[0, ..., config.vocab_size]``
|
||||
|
||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||
**loss**: (`optional`, returned when ``masked_lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
||||
**masked_lm_loss**: (`optional`, returned when ``masked_lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
||||
Masked language modeling loss.
|
||||
**next_token_loss**: (`optional`, returned when ``lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
||||
Next token prediction loss.
|
||||
**prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
|
||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
|
||||
@@ -854,13 +861,13 @@ class BertForMaskedLM(BertPreTrainedModel):
|
||||
|
||||
if lm_labels is not None:
|
||||
# we are doing next-token prediction; shift prediction scores and input ids by one
|
||||
prediction_scores = prediction_scores[:, :-1, :]
|
||||
lm_labels = lm_labels[:, 1:]
|
||||
prediction_scores = prediction_scores[:, :-1, :].contiguous()
|
||||
lm_labels = lm_labels[:, 1:].contiguous()
|
||||
loss_fct = CrossEntropyLoss(ignore_index=-1)
|
||||
seq2seq_loss = loss_fct(prediction_scores.reshape(-1, self.config.vocab_size), lm_labels.reshape(-1))
|
||||
outputs = (seq2seq_loss,) + outputs
|
||||
next_token_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), lm_labels.view(-1))
|
||||
outputs = (next_token_loss,) + outputs
|
||||
|
||||
return outputs # (mlm_or_seq2seq_loss), prediction_scores, (hidden_states), (attentions)
|
||||
return outputs # (masked_lm_loss), (next_token_loss), prediction_scores, (hidden_states), (attentions)
|
||||
|
||||
|
||||
@add_start_docstrings("""Bert Model with a `next sentence prediction (classification)` head on top. """,
|
||||
|
||||
@@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class PreTrainedSeq2seq(nn.Module):
|
||||
r"""
|
||||
:class:`~transformers.Seq2seq` is a generic model class that will be
|
||||
:class:`~transformers.PreTrainedSeq2seq` is a generic model class that will be
|
||||
instantiated as a Seq2seq model with one of the base model classes of
|
||||
the library as encoder and (optionally) as decoder when created with
|
||||
the `AutoModel.from_pretrained(pretrained_model_name_or_path)` class
|
||||
@@ -49,8 +49,7 @@ class PreTrainedSeq2seq(nn.Module):
|
||||
*model_args,
|
||||
**kwargs
|
||||
):
|
||||
r""" Instantiates an encoder and a decoder from one or two base classes
|
||||
of the library from pre-trained model checkpoints.
|
||||
r""" Instantiates an encoder and a decoder from one or two base classes of the library from pre-trained model checkpoints.
|
||||
|
||||
|
||||
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
|
||||
@@ -111,35 +110,44 @@ class PreTrainedSeq2seq(nn.Module):
|
||||
model = PreTrainedSeq2seq.from_pretained('bert-base-uncased', 'bert-base-uncased') # initialize Bert2Bert
|
||||
"""
|
||||
|
||||
# Separate the encoder- and decoder- specific kwargs. A kwarg is
|
||||
# decoder-specific it the key starts with `decoder_`
|
||||
# keyword arguments come in 3 flavors: encoder-specific (prefixed by
|
||||
# `encoder_`), decoder-specific (prefixed by `decoder_`) and those
|
||||
# that apply to the model as a whole.
|
||||
# We let the specific kwargs override the common ones in case of conflict.
|
||||
kwargs_encoder = {
|
||||
argument: value
|
||||
argument[len("encoder_"):]: value
|
||||
for argument, value in kwargs.items()
|
||||
if not argument.startswith("decoder_")
|
||||
if argument.startswith("encoder_")
|
||||
}
|
||||
kwargs_decoder = {
|
||||
argument[len("decoder_") :]: value
|
||||
argument[len("decoder_"):]: value
|
||||
for argument, value in kwargs.items()
|
||||
if argument.startswith("decoder_")
|
||||
}
|
||||
kwargs_common = {
|
||||
argument: value
|
||||
for argument, value in kwargs.items()
|
||||
if not (argument.startswith("encoder_") or argument.startswith("decoder_"))
|
||||
}
|
||||
kwargs_decoder = dict(kwargs_common, **kwargs_decoder)
|
||||
kwargs_encoder = dict(kwargs_common, **kwargs_encoder)
|
||||
|
||||
# Load and initialize the encoder and decoder
|
||||
# The distinction between encoder and decoder at the model level is made
|
||||
# by the value of the flag `is_decoder` that we need to set correctly.
|
||||
encoder = kwargs_encoder.pop("encoder_model", None)
|
||||
encoder = kwargs_encoder.pop("model", None)
|
||||
if encoder is None:
|
||||
kwargs_encoder["is_decoder"] = False
|
||||
encoder = AutoModel.from_pretrained(
|
||||
encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder
|
||||
)
|
||||
encoder.config.is_decoder = False
|
||||
|
||||
decoder = kwargs_decoder.pop("model", None)
|
||||
if decoder is None:
|
||||
kwargs_decoder["is_decoder"] = True
|
||||
decoder = AutoModelWithLMHead.from_pretrained(
|
||||
decoder_pretrained_model_name_or_path, **kwargs_decoder
|
||||
)
|
||||
decoder.config.is_decoder = True
|
||||
|
||||
model = cls(encoder, decoder)
|
||||
|
||||
@@ -169,37 +177,60 @@ class PreTrainedSeq2seq(nn.Module):
|
||||
decoder_input_ids: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``
|
||||
Indices of decoder input sequence tokens in the vocabulary.
|
||||
"""
|
||||
# Separate the encoder- and decoder- specific kwargs. A kwarg is
|
||||
# decoder-specific it the key starts with `decoder_`
|
||||
# keyword arguments come in 3 flavors: encoder-specific (prefixed by
|
||||
# `encoder_`), decoder-specific (prefixed by `decoder_`) and those
|
||||
# that apply to the model as whole.
|
||||
# We let the specific kwargs override the common ones in case of conflict.
|
||||
kwargs_encoder = {
|
||||
argument: value
|
||||
argument[len("encoder_"):]: value
|
||||
for argument, value in kwargs.items()
|
||||
if not argument.startswith("decoder_")
|
||||
if argument.startswith("encoder_")
|
||||
}
|
||||
kwargs_decoder = {
|
||||
argument[len("decoder_") :]: value
|
||||
argument[len("decoder_"):]: value
|
||||
for argument, value in kwargs.items()
|
||||
if argument.startswith("decoder_")
|
||||
}
|
||||
kwargs_common = {
|
||||
argument: value
|
||||
for argument, value in kwargs.items()
|
||||
if not (argument.startswith("encoder_") or argument.startswith("decoder_"))
|
||||
}
|
||||
kwargs_decoder = dict(kwargs_common, **kwargs_decoder)
|
||||
kwargs_encoder = dict(kwargs_common, **kwargs_encoder)
|
||||
|
||||
# Encode if needed (training, first prediction pass)
|
||||
encoder_hidden_states = kwargs_encoder.pop("encoder_hidden_states", None)
|
||||
encoder_hidden_states = kwargs_encoder.pop("hidden_states", None)
|
||||
if encoder_hidden_states is None:
|
||||
encoder_outputs = self.encoder(encoder_input_ids, **kwargs_encoder)
|
||||
encoder_hidden_states = encoder_outputs[0][
|
||||
-1
|
||||
] # output of the encoder *stack*
|
||||
encoder_hidden_states = encoder_outputs[0] # output the last layer hidden state
|
||||
else:
|
||||
encoder_outputs = ()
|
||||
|
||||
# Decode
|
||||
kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states[None, :, :]
|
||||
kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states
|
||||
kwargs_decoder["encoder_attention_mask"] = kwargs_encoder.get("attention_mask", None)
|
||||
decoder_outputs = self.decoder(decoder_input_ids, **kwargs_decoder)
|
||||
|
||||
return decoder_outputs + encoder_outputs
|
||||
|
||||
|
||||
class Model2Model(PreTrainedSeq2seq):
|
||||
r"""
|
||||
:class:`~transformers.Model2Model` instantiates a Seq2Seq2 model
|
||||
where both of the encoder and decoder are of the same family. If the
|
||||
name of or that path to a pretrained model is specified the encoder and
|
||||
the decoder will be initialized with the pretrained weight (the
|
||||
cross-attention will be intialized randomly if its weights are not
|
||||
present).
|
||||
|
||||
It is possible to override this behavior and initialize, say, the decoder randomly
|
||||
by creating it beforehand as follows
|
||||
|
||||
config = BertConfig.from_pretrained()
|
||||
decoder = BertForMaskedLM(config)
|
||||
model = Model2Model.from_pretrained('bert-base-uncased', decoder_model=decoder)
|
||||
"""
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(Model2Model, self).__init__(*args, **kwargs)
|
||||
self.tie_weights()
|
||||
@@ -235,14 +266,10 @@ class Model2Model(PreTrainedSeq2seq):
|
||||
model = super(Model2Model, cls).from_pretrained(
|
||||
encoder_pretrained_model_name_or_path=pretrained_model_name_or_path,
|
||||
decoder_pretrained_model_name_or_path=pretrained_model_name_or_path,
|
||||
*args,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
# Some architectures require for the decoder to be initialized randomly
|
||||
# before fine-tuning.
|
||||
if kwargs.get("decoder_initialize_randomly", False):
|
||||
model.decoder.init_weights()
|
||||
|
||||
return model
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user