Bart: fix layerdrop and cached decoder_input_ids for generation (#2969)
This commit is contained in:
@@ -470,10 +470,6 @@ class BartDecoder(nn.Module):
|
||||
"""
|
||||
# embed positions
|
||||
positions = self.embed_positions(input_ids)
|
||||
|
||||
- if decoder_cached_states is not None:
|
||||
- input_ids = input_ids[:, -1:]
|
||||
- positions = positions[:, -1:]
|
||||
x = self.embed_tokens(input_ids)
|
||||
|
||||
if positions is not None:
|
||||
@@ -491,7 +487,7 @@ class BartDecoder(nn.Module):
|
||||
decoder_layer # type: DecoderLayer
|
||||
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
||||
dropout_probability = random.uniform(0, 1)
|
||||
- if self.training and (dropout_probability > self.layerdrop):
|
||||
+ if self.training and (dropout_probability < self.layerdrop):
|
||||
continue
|
||||
layer_state = decoder_cached_states[i] if decoder_cached_states is not None else None
|
||||
x, layer_self_attn, layer_past = decoder_layer.forward(
|
||||
@@ -940,7 +936,7 @@ class BartForMaskedLM(PretrainedBartModel):
|
||||
|
||||
@staticmethod
|
||||
def prepare_inputs_for_generation(input_ids, past, **kwargs):
|
||||
- return {"input_ids": input_ids, "decoder_cached_states": past, "decoder_input_ids": input_ids}
|
||||
+ return {"input_ids": input_ids, "decoder_cached_states": past, "decoder_input_ids": input_ids[:, -1:]}
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head
|
||||
|
||||
@@ -251,6 +251,7 @@ class BartHeadTests(unittest.TestCase):
|
||||
output_past=True,
|
||||
)
|
||||
lm_model = BartForMaskedLM(config)
|
||||
lm_model.eval()
|
||||
new_input_ids = lm_model.generate(input_ids)
|
||||
self.assertEqual(new_input_ids.shape, (input_ids.shape[0], 20))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user