share pretrained embeddings
This commit is contained in:
committed by
Julien Chaumond
parent
9660ba1cbd
commit
ba089c780b
@@ -136,18 +136,11 @@ def encode_for_summarization(story_lines, summary_lines, tokenizer):
|
|||||||
as specified in [1] by using `[SEP] [CLS]` tokens to separate
|
as specified in [1] by using `[SEP] [CLS]` tokens to separate
|
||||||
sentences.
|
sentences.
|
||||||
"""
|
"""
|
||||||
story_lines_token_ids = [
|
story_lines_token_ids = [tokenizer.encode(line) for line in story_lines]
|
||||||
tokenizer.build_inputs_with_special_tokens(tokenizer.encode(line))
|
|
||||||
for line in story_lines
|
|
||||||
]
|
|
||||||
summary_lines_token_ids = [
|
|
||||||
tokenizer.build_inputs_with_special_tokens(tokenizer.encode(line))
|
|
||||||
for line in summary_lines
|
|
||||||
]
|
|
||||||
|
|
||||||
story_token_ids = [
|
story_token_ids = [
|
||||||
token for sentence in story_lines_token_ids for token in sentence
|
token for sentence in story_lines_token_ids for token in sentence
|
||||||
]
|
]
|
||||||
|
summary_lines_token_ids = [tokenizer.encode(line) for line in summary_lines]
|
||||||
summary_token_ids = [
|
summary_token_ids = [
|
||||||
token for sentence in summary_lines_token_ids for token in sentence
|
token for sentence in summary_lines_token_ids for token in sentence
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -10,3 +10,5 @@ regex
|
|||||||
sentencepiece
|
sentencepiece
|
||||||
# For XLM
|
# For XLM
|
||||||
sacremoses
|
sacremoses
|
||||||
|
# For ROUGE
|
||||||
|
pyrouge
|
||||||
|
|||||||
@@ -26,27 +26,31 @@ Use Beam Search to generate sequences using encoder-decoder models.
|
|||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class BeamSearch(nn.Module):
|
class BeamSearch(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model,
|
model,
|
||||||
tokenizer,
|
bos_token_id,
|
||||||
|
pad_token_id,
|
||||||
|
eos_token_id,
|
||||||
|
batch_size,
|
||||||
beam_size,
|
beam_size,
|
||||||
min_length,
|
min_length,
|
||||||
max_length,
|
max_length,
|
||||||
batch_size=1,
|
|
||||||
alpha=0,
|
alpha=0,
|
||||||
block_repeating_trigrams=True,
|
block_repeating_trigrams=True,
|
||||||
|
device=torch.device("cpu"),
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
Inputs:
|
Inputs:
|
||||||
**model**: instance of ``transformers.PreTrainedEncoderDecoder``
|
**model**: instance of ``transformers.PreTrainedEncoderDecoder``
|
||||||
The pretrained encoder-decoder model that will be used to generate the sequences.
|
The pretrained encoder-decoder model that will be used to generate the sequences.
|
||||||
**tokenizer**: instance of ``transformers.PreTrainedTokenizer``
|
|
||||||
The pretrained tokenizer associated with the model used in the encoder-decoder. We only
|
|
||||||
support encoder-decoder that use the same tokenizer for encoder and decoder. The tokenizer
|
|
||||||
needs to be initialized or this function will raise an exception.
|
|
||||||
**batch_size**: (`optional`) int
|
**batch_size**: (`optional`) int
|
||||||
Batch size of the inputs. The value is set automatically when calling `forward`.
|
Batch size of the inputs. The value is set automatically when calling `forward`.
|
||||||
**beam_size**: int
|
**beam_size**: int
|
||||||
@@ -64,11 +68,11 @@ class BeamSearch(nn.Module):
|
|||||||
"""
|
"""
|
||||||
super(BeamSearch, self).__init__()
|
super(BeamSearch, self).__init__()
|
||||||
self.model = model
|
self.model = model
|
||||||
self.tokenizer = tokenizer
|
self.device = device
|
||||||
|
|
||||||
self.bos_token_id = tokenizer.bos_token_id
|
self.bos_token_id = bos_token_id
|
||||||
self.eos_token_id = tokenizer.eos_token_id
|
self.eos_token_id = eos_token_id
|
||||||
self.pad_token_id = tokenizer.pad_token_id
|
self.pad_token_id = pad_token_id
|
||||||
|
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
self.beam_size = beam_size
|
self.beam_size = beam_size
|
||||||
@@ -90,15 +94,24 @@ class BeamSearch(nn.Module):
|
|||||||
def _init_beam_state(self, batch_size):
|
def _init_beam_state(self, batch_size):
|
||||||
""" (re-)Initialize the state of the beams. """
|
""" (re-)Initialize the state of the beams. """
|
||||||
self.hypotheses = [[] for _ in range(batch_size)]
|
self.hypotheses = [[] for _ in range(batch_size)]
|
||||||
self.batch_offset = torch.arange(batch_size, dtype=torch.long)
|
self.batch_offset = torch.arange(batch_size, dtype=torch.long, device=self.device)
|
||||||
self.beam_offset = torch.arange(
|
self.beam_offset = torch.arange(
|
||||||
0, batch_size * self.beam_size, step=self.beam_size, dtype=torch.long
|
0,
|
||||||
|
batch_size * self.beam_size,
|
||||||
|
step=self.beam_size,
|
||||||
|
dtype=torch.long,
|
||||||
|
device=self.device,
|
||||||
)
|
)
|
||||||
self.growing_beams = torch.full(
|
self.growing_beams = torch.full(
|
||||||
(batch_size * self.beam_size, 1), self.bos_token_id, dtype=torch.long
|
(batch_size * self.beam_size, 1),
|
||||||
|
self.bos_token_id,
|
||||||
|
dtype=torch.long,
|
||||||
|
device=self.device,
|
||||||
)
|
)
|
||||||
self.topk_log_probabilities = torch.tensor(
|
self.topk_log_probabilities = torch.tensor(
|
||||||
[0.0] + [float("-inf")] * (self.beam_size - 1), dtype=torch.float
|
[0.0] + [float("-inf")] * (self.beam_size - 1),
|
||||||
|
dtype=torch.float,
|
||||||
|
device=self.device,
|
||||||
).repeat(batch_size)
|
).repeat(batch_size)
|
||||||
self.results = {
|
self.results = {
|
||||||
"predictions": [[] for _ in range(batch_size)],
|
"predictions": [[] for _ in range(batch_size)],
|
||||||
@@ -136,28 +149,37 @@ class BeamSearch(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# forward pass on the encoder
|
# forward pass on the encoder
|
||||||
encoder_outputs = self.model.encoder.forward(encoder_input_ids, kwargs_encoder)
|
encoder_outputs = self.model.encoder(encoder_input_ids, **kwargs_encoder)
|
||||||
|
encoder_hidden_states = encoder_outputs[0]
|
||||||
kwargs_decoder["encoder_hidden_states"] = tile(
|
kwargs_decoder["encoder_hidden_states"] = tile(
|
||||||
encoder_outputs, self.beam_size, dim=0
|
encoder_hidden_states, self.beam_size, dim=0
|
||||||
|
)
|
||||||
|
kwargs_decoder["encoder_attention_mask"] = tile(
|
||||||
|
kwargs_encoder["attention_mask"], self.beam_size, dim=0
|
||||||
)
|
)
|
||||||
|
|
||||||
# grow the beam by generating sequences in an autoregressive way
|
# grow the beam by generating sequences in an autoregressive way
|
||||||
batch_size = encoder_input_ids.size(0)
|
batch_size, block_size = encoder_input_ids.size()
|
||||||
self._init_beam_state(batch_size)
|
self._init_beam_state(batch_size)
|
||||||
for step in range(self.max_length):
|
for step in range(self.max_length):
|
||||||
# prepare the decoder input
|
# Add padding tokens
|
||||||
decoder_input = fit_to_block_size(
|
decoder_input = torch.full(
|
||||||
self.growing_beams, self.tokenizer.pad_token_id
|
(self.growing_beams.size(0), block_size),
|
||||||
)
|
self.pad_token_id,
|
||||||
kwargs_decoder["decoder_lm_labels"] = build_lm_labels(
|
dtype=torch.long,
|
||||||
decoder_input, self.tokenizer.pad_token_id
|
device=self.growing_beams.device,
|
||||||
)
|
|
||||||
kwargs_decoder["decoder_attention_mask"] = build_mask(
|
|
||||||
decoder_input, self.tokenizer.pad_token_id
|
|
||||||
)
|
)
|
||||||
|
decoder_input[:, : self.growing_beams.size(1)] = self.growing_beams
|
||||||
|
|
||||||
outputs = self.model.decoder(decoder_input, kwargs_decoder)
|
# compute decoder_attention_mask
|
||||||
log_probabilities = torch.nn.functional.log_softmax(outputs[1])
|
decoder_mask = torch.ones_like(decoder_input)
|
||||||
|
idx_pad_tokens = decoder_input == self.pad_token_id
|
||||||
|
decoder_mask[idx_pad_tokens] = 0
|
||||||
|
kwargs_decoder["attention_mask"] = decoder_mask
|
||||||
|
|
||||||
|
outputs = self.model.decoder(decoder_input, **kwargs_decoder)
|
||||||
|
last_token_scores = outputs[0][:, -1, :].squeeze(1)
|
||||||
|
log_probabilities = torch.nn.functional.log_softmax(last_token_scores, dim=0)
|
||||||
surviving_beams_rows = self.grow(log_probabilities)
|
surviving_beams_rows = self.grow(log_probabilities)
|
||||||
if self.is_done:
|
if self.is_done:
|
||||||
break
|
break
|
||||||
@@ -189,13 +211,13 @@ class BeamSearch(nn.Module):
|
|||||||
|
|
||||||
# Find the `beam_size` (previous_beam + token) combinations with
|
# Find the `beam_size` (previous_beam + token) combinations with
|
||||||
# the highest score
|
# the highest score
|
||||||
topk_log_probabilities, topk_ids = torch.topk(
|
self.topk_log_probabilities, topk_ids = torch.topk(
|
||||||
log_probabilities.view(_B, self.beam_size * vocab_size), self.beam_size, dim=1
|
log_probabilities.view(_B, self.beam_size * vocab_size), self.beam_size, dim=1
|
||||||
)
|
)
|
||||||
|
|
||||||
# Apply the length penalty. The +1 accounts for the [EOS] token
|
# Apply the length penalty. The +1 accounts for the [EOS] token
|
||||||
# that will be added if the beam ends.
|
# that will be added if the beam ends.
|
||||||
topk_scores = topk_log_probabilities
|
topk_scores = self.topk_log_probabilities
|
||||||
if self.apply_length_penalty:
|
if self.apply_length_penalty:
|
||||||
topk_scores /= self._length_penalty()
|
topk_scores /= self._length_penalty()
|
||||||
|
|
||||||
@@ -337,8 +359,9 @@ def fit_to_block_size(sequence, block_size, pad_token_id):
|
|||||||
if len(sequence) > block_size:
|
if len(sequence) > block_size:
|
||||||
return sequence[:block_size]
|
return sequence[:block_size]
|
||||||
else:
|
else:
|
||||||
sequence.extend([pad_token_id] * (block_size - len(sequence)))
|
return torch.cat(
|
||||||
return sequence
|
(sequence, torch.tensor([pad_token_id] * (block_size - len(sequence)))), dim=0
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def build_lm_labels(sequence, pad_token_id):
|
def build_lm_labels(sequence, pad_token_id):
|
||||||
|
|||||||
Reference in New Issue
Block a user