Bart: fix layerdrop and cached decoder_input_ids for generation (#2969)

This commit is contained in:
Sam Shleifer
2020-02-22 16:25:04 -05:00
committed by GitHub
parent c36416e53c
commit 92487a1dc0
2 changed files with 3 additions and 6 deletions

View File

@@ -470,10 +470,6 @@ class BartDecoder(nn.Module):
"""
# embed positions
positions = self.embed_positions(input_ids)
if decoder_cached_states is not None:
input_ids = input_ids[:, -1:]
positions = positions[:, -1:]
x = self.embed_tokens(input_ids)
if positions is not None:
@@ -491,7 +487,7 @@ class BartDecoder(nn.Module):
decoder_layer # type: DecoderLayer
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
dropout_probability = random.uniform(0, 1)
if self.training and (dropout_probability > self.layerdrop):
if self.training and (dropout_probability < self.layerdrop):
continue
layer_state = decoder_cached_states[i] if decoder_cached_states is not None else None
x, layer_self_attn, layer_past = decoder_layer.forward(
@@ -940,7 +936,7 @@ class BartForMaskedLM(PretrainedBartModel):
@staticmethod
def prepare_inputs_for_generation(input_ids, past, **kwargs):
# Returns the dict of model inputs used during generation: `input_ids` is
# passed through unchanged, `past` is passed under the key
# "decoder_cached_states", and "decoder_input_ids" is derived from `input_ids`.
# Extra **kwargs are accepted but not read in the visible code.
#
# NOTE(review): this text is a flattened diff view (no +/- markers), so two
# `return` statements appear back to back. Presumably the first is the line
# removed by the commit and the second is its replacement -- confirm against
# the repository before relying on either.
#
# Variant 1: the whole `input_ids` value is forwarded as "decoder_input_ids".
return {"input_ids": input_ids, "decoder_cached_states": past, "decoder_input_ids": input_ids}
# Variant 2: only the last slice along the second axis (`input_ids[:, -1:]`)
# is forwarded as "decoder_input_ids".
return {"input_ids": input_ids, "decoder_cached_states": past, "decoder_input_ids": input_ids[:, -1:]}
def get_output_embeddings(self):
# Returns the object stored on `self.lm_head`.
# NOTE(review): shown as trailing diff context; the visible body is this
# single return -- confirm nothing follows in the full file.
return self.lm_head