Refactor BartModel so that input checks are handled within BartEncoder and BartDecoder
This commit is contained in:
BIN
src/.DS_Store
vendored
Normal file
BIN
src/.DS_Store
vendored
Normal file
Binary file not shown.
@@ -271,6 +271,12 @@ class BartEncoder(nn.Module):
|
|||||||
- **all_attentions** (List[Tensor]): Attention weights for each layer.
|
- **all_attentions** (List[Tensor]): Attention weights for each layer.
|
||||||
During training might not be of length n_layers because of layer dropout.
|
During training might not be of length n_layers because of layer dropout.
|
||||||
"""
|
"""
|
||||||
|
# check attention mask and invert
|
||||||
|
if attention_mask is not None:
|
||||||
|
assert attention_mask.dim() == 2
|
||||||
|
|
||||||
|
attention_mask = (1.0 - attention_mask.long()) * -10000.0
|
||||||
|
assert attention_mask.max() <= 0
|
||||||
inputs_embeds = self.embed_tokens(input_ids)
|
inputs_embeds = self.embed_tokens(input_ids)
|
||||||
embed_pos = self.embed_positions(input_ids)
|
embed_pos = self.embed_positions(input_ids)
|
||||||
x = inputs_embeds + embed_pos
|
x = inputs_embeds + embed_pos
|
||||||
@@ -448,6 +454,13 @@ class BartDecoder(nn.Module):
|
|||||||
- hidden states
|
- hidden states
|
||||||
- attentions
|
- attentions
|
||||||
"""
|
"""
|
||||||
|
# check encoder padding mask and invert
|
||||||
|
if encoder_padding_mask is not None:
|
||||||
|
assert encoder_padding_mask.dim() == 2
|
||||||
|
|
||||||
|
encoder_padding_mask = (1.0 - encoder_padding_mask.long()) * -10000.0
|
||||||
|
assert encoder_padding_mask.max() <= 0
|
||||||
|
|
||||||
# embed positions
|
# embed positions
|
||||||
positions = self.embed_positions(input_ids, generation_mode=self.generation_mode)
|
positions = self.embed_positions(input_ids, generation_mode=self.generation_mode)
|
||||||
|
|
||||||
@@ -823,11 +836,6 @@ class BartModel(PretrainedBartModel):
|
|||||||
decoder_attention_mask=None,
|
decoder_attention_mask=None,
|
||||||
decoder_cached_states=None,
|
decoder_cached_states=None,
|
||||||
):
|
):
|
||||||
if attention_mask is not None:
|
|
||||||
assert attention_mask.dim() == 2
|
|
||||||
|
|
||||||
attention_mask = (1.0 - attention_mask.long()) * -10000.0
|
|
||||||
assert attention_mask.max() <= 0
|
|
||||||
|
|
||||||
# make masks if user doesn't supply
|
# make masks if user doesn't supply
|
||||||
if not self.decoder.generation_mode:
|
if not self.decoder.generation_mode:
|
||||||
|
|||||||
Reference in New Issue
Block a user