finalize generation merge

This commit is contained in:
Patrick von Platen
2020-03-11 11:53:36 +01:00
parent 1ba21f96ca
commit a332cc9f7f
4 changed files with 10 additions and 13 deletions

View File

@@ -40,8 +40,9 @@ class BartConfig(PretrainedConfig):
self,
activation_dropout=0.0,
vocab_size=50265,
bos_token_id=0,
pad_token_id=1,
eos_token_id=2,
eos_token_ids=[2],
d_model=1024,
encoder_ffn_dim=4096,
encoder_layers=12,
@@ -58,7 +59,6 @@ class BartConfig(PretrainedConfig):
classifier_dropout=0.0,
output_past=False,
num_labels=3,
bos_token_id=0,
is_encoder_decoder=True,
**common_kwargs
):
@@ -73,12 +73,12 @@ class BartConfig(PretrainedConfig):
output_past=output_past,
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_ids=eos_token_ids,
is_encoder_decoder=is_encoder_decoder,
**common_kwargs,
)
self.vocab_size = vocab_size
self.d_model = d_model # encoder_embed_dim and decoder_embed_dim
self.eos_token_id = eos_token_id
self.encoder_ffn_dim = encoder_ffn_dim
self.encoder_layers = self.num_hidden_layers = encoder_layers
self.encoder_attention_heads = encoder_attention_heads

View File

@@ -962,8 +962,8 @@ class BartForConditionalGeneration(PretrainedBartModel):
def prepare_scores_for_generation(self, scores, cur_len, max_length):
if cur_len == 1:
self._force_token_ids_generation(scores, self.config.bos_token_id)
if cur_len == max_length - 1:
self._force_token_ids_generation(scores, self.config.eos_token_ids)
if cur_len == max_length - 1 and self.config.eos_token_ids[0] is not None:
self._force_token_ids_generation(scores, self.config.eos_token_ids[0])
return scores
@staticmethod
@@ -1056,7 +1056,7 @@ class BartForSequenceClassification(PretrainedBartModel):
encoder_outputs=encoder_outputs,
)
x = outputs[0] # last hidden state
eos_mask = input_ids.eq(self.config.eos_token_id)
eos_mask = input_ids.eq(self.config.eos_token_ids[0])
if len(torch.unique(eos_mask.sum(1))) > 1:
raise ValueError("All examples must have the same number of <eos> tokens.")
sentence_representation = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[:, -1, :]

View File

@@ -840,14 +840,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
if self.config.is_encoder_decoder:
eos_token_id = eos_token_ids[0]
assert bos_token_id is not None, "Encoder Decoder Models need to have a bos_token_id"
assert eos_token_id is not None, "Encoder Decoder Models need to have a eos_token_id"
# encoder decoder need to start with empty input_ids and copy the input_ids to encoder_inputs
encoder_inputs = input_ids
input_ids = torch.full(
(effective_batch_size * num_beams, 1),
# eos_token_id, # TODO (PVP): to check if this is the only solution -> quite hacky to do this
bos_token_id,
dtype=torch.long,
device=next(self.parameters()).device,