Bart: fix layerdrop and cached decoder_input_ids for generation (#2969)
This commit is contained in:
@@ -470,10 +470,6 @@ class BartDecoder(nn.Module):
|
|||||||
"""
|
"""
|
||||||
# embed positions
|
# embed positions
|
||||||
positions = self.embed_positions(input_ids)
|
positions = self.embed_positions(input_ids)
|
||||||
|
|
||||||
if decoder_cached_states is not None:
|
|
||||||
input_ids = input_ids[:, -1:]
|
|
||||||
positions = positions[:, -1:]
|
|
||||||
x = self.embed_tokens(input_ids)
|
x = self.embed_tokens(input_ids)
|
||||||
|
|
||||||
if positions is not None:
|
if positions is not None:
|
||||||
@@ -491,7 +487,7 @@ class BartDecoder(nn.Module):
|
|||||||
decoder_layer # type: DecoderLayer
|
decoder_layer # type: DecoderLayer
|
||||||
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
||||||
dropout_probability = random.uniform(0, 1)
|
dropout_probability = random.uniform(0, 1)
|
||||||
if self.training and (dropout_probability > self.layerdrop):
|
if self.training and (dropout_probability < self.layerdrop):
|
||||||
continue
|
continue
|
||||||
layer_state = decoder_cached_states[i] if decoder_cached_states is not None else None
|
layer_state = decoder_cached_states[i] if decoder_cached_states is not None else None
|
||||||
x, layer_self_attn, layer_past = decoder_layer.forward(
|
x, layer_self_attn, layer_past = decoder_layer.forward(
|
||||||
@@ -940,7 +936,7 @@ class BartForMaskedLM(PretrainedBartModel):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def prepare_inputs_for_generation(input_ids, past, **kwargs):
|
def prepare_inputs_for_generation(input_ids, past, **kwargs):
|
||||||
return {"input_ids": input_ids, "decoder_cached_states": past, "decoder_input_ids": input_ids}
|
return {"input_ids": input_ids, "decoder_cached_states": past, "decoder_input_ids": input_ids[:, -1:]}
|
||||||
|
|
||||||
def get_output_embeddings(self):
|
def get_output_embeddings(self):
|
||||||
return self.lm_head
|
return self.lm_head
|
||||||
|
|||||||
@@ -251,6 +251,7 @@ class BartHeadTests(unittest.TestCase):
|
|||||||
output_past=True,
|
output_past=True,
|
||||||
)
|
)
|
||||||
lm_model = BartForMaskedLM(config)
|
lm_model = BartForMaskedLM(config)
|
||||||
|
lm_model.eval()
|
||||||
new_input_ids = lm_model.generate(input_ids)
|
new_input_ids = lm_model.generate(input_ids)
|
||||||
self.assertEqual(new_input_ids.shape, (input_ids.shape[0], 20))
|
self.assertEqual(new_input_ids.shape, (input_ids.shape[0], 20))
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user