[Bart] Replace config.output_past with use_cache kwarg (#3632)

This commit is contained in:
Sam Shleifer
2020-04-07 19:08:26 -04:00
committed by GitHub
parent e344e3d402
commit 715aa5b135
4 changed files with 25 additions and 26 deletions

View File

@@ -20,7 +20,7 @@ def generate_summaries(
examples: list, out_file: str, model_name: str, batch_size: int = 8, device: str = DEFAULT_DEVICE examples: list, out_file: str, model_name: str, batch_size: int = 8, device: str = DEFAULT_DEVICE
): ):
fout = Path(out_file).open("w") fout = Path(out_file).open("w")
model = BartForConditionalGeneration.from_pretrained(model_name, output_past=True,).to(device) model = BartForConditionalGeneration.from_pretrained(model_name).to(device)
tokenizer = BartTokenizer.from_pretrained("bart-large") tokenizer = BartTokenizer.from_pretrained("bart-large")
max_length = 140 max_length = 140

View File

@@ -56,7 +56,6 @@ class BartConfig(PretrainedConfig):
max_position_embeddings=1024, max_position_embeddings=1024,
init_std=0.02, init_std=0.02,
classifier_dropout=0.0, classifier_dropout=0.0,
output_past=False,
num_labels=3, num_labels=3,
is_encoder_decoder=True, is_encoder_decoder=True,
pad_token_id=1, pad_token_id=1,
@@ -72,7 +71,6 @@ class BartConfig(PretrainedConfig):
""" """
super().__init__( super().__init__(
num_labels=num_labels, num_labels=num_labels,
output_past=output_past,
pad_token_id=pad_token_id, pad_token_id=pad_token_id,
bos_token_id=bos_token_id, bos_token_id=bos_token_id,
eos_token_id=eos_token_id, eos_token_id=eos_token_id,

View File

@@ -388,7 +388,6 @@ class BartDecoder(nn.Module):
def __init__(self, config: BartConfig, embed_tokens: nn.Embedding): def __init__(self, config: BartConfig, embed_tokens: nn.Embedding):
super().__init__() super().__init__()
self.output_past = config.output_past
self.output_attentions = config.output_attentions self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states self.output_hidden_states = config.output_hidden_states
self.dropout = config.dropout self.dropout = config.dropout
@@ -412,7 +411,7 @@ class BartDecoder(nn.Module):
decoder_padding_mask, decoder_padding_mask,
decoder_causal_mask, decoder_causal_mask,
decoder_cached_states=None, decoder_cached_states=None,
generation_mode=False, use_cache=False,
**unused **unused
): ):
""" """
@@ -438,9 +437,9 @@ class BartDecoder(nn.Module):
encoder_padding_mask = invert_mask(encoder_padding_mask) encoder_padding_mask = invert_mask(encoder_padding_mask)
# embed positions # embed positions
positions = self.embed_positions(input_ids, generation_mode=generation_mode) positions = self.embed_positions(input_ids, use_cache=use_cache)
if generation_mode: if use_cache:
input_ids = input_ids[:, -1:] input_ids = input_ids[:, -1:]
positions = positions[:, -1:] # happens after we embed them positions = positions[:, -1:] # happens after we embed them
assert input_ids.ne(self.padding_idx).any() assert input_ids.ne(self.padding_idx).any()
@@ -476,7 +475,7 @@ class BartDecoder(nn.Module):
causal_mask=decoder_causal_mask, causal_mask=decoder_causal_mask,
) )
if self.output_past: if use_cache:
next_decoder_cache.append(layer_past.copy()) next_decoder_cache.append(layer_past.copy())
if self.output_hidden_states: if self.output_hidden_states:
all_hidden_states += (x,) all_hidden_states += (x,)
@@ -488,7 +487,7 @@ class BartDecoder(nn.Module):
x = x.transpose(0, 1) x = x.transpose(0, 1)
encoder_hidden_states = encoder_hidden_states.transpose(0, 1) encoder_hidden_states = encoder_hidden_states.transpose(0, 1)
if self.output_past: if use_cache:
next_cache = ((encoder_hidden_states, encoder_padding_mask), next_decoder_cache) next_cache = ((encoder_hidden_states, encoder_padding_mask), next_decoder_cache)
else: else:
next_cache = None next_cache = None
@@ -710,9 +709,9 @@ class LearnedPositionalEmbedding(nn.Embedding):
num_embeddings += padding_idx + 1 # WHY? num_embeddings += padding_idx + 1 # WHY?
super().__init__(num_embeddings, embedding_dim, padding_idx=padding_idx) super().__init__(num_embeddings, embedding_dim, padding_idx=padding_idx)
def forward(self, input, generation_mode=False): def forward(self, input, use_cache=False):
"""Input is expected to be of size [bsz x seqlen].""" """Input is expected to be of size [bsz x seqlen]."""
if generation_mode: # the position is our current step in the decoded sequence if use_cache: # the position is our current step in the decoded sequence
pos = int(self.padding_idx + input.size(1)) pos = int(self.padding_idx + input.size(1))
positions = input.data.new(1, 1).fill_(pos) positions = input.data.new(1, 1).fill_(pos)
else: else:
@@ -772,11 +771,11 @@ class BartModel(PretrainedBartModel):
encoder_outputs=None, # type: Tuple encoder_outputs=None, # type: Tuple
decoder_attention_mask=None, decoder_attention_mask=None,
decoder_cached_states=None, decoder_cached_states=None,
generation_mode=False, use_cache=False,
): ):
# make masks if user doesn't supply # make masks if user doesn't supply
if not generation_mode: if not use_cache:
decoder_input_ids, decoder_padding_mask, causal_mask = _prepare_bart_decoder_inputs( decoder_input_ids, decoder_padding_mask, causal_mask = _prepare_bart_decoder_inputs(
self.config, self.config,
input_ids, input_ids,
@@ -799,7 +798,7 @@ class BartModel(PretrainedBartModel):
decoder_padding_mask, decoder_padding_mask,
decoder_causal_mask=causal_mask, decoder_causal_mask=causal_mask,
decoder_cached_states=decoder_cached_states, decoder_cached_states=decoder_cached_states,
generation_mode=generation_mode, use_cache=use_cache,
) )
# Attention and hidden_states will be [] or None if they aren't needed # Attention and hidden_states will be [] or None if they aren't needed
decoder_outputs = _filter_out_falsey_values(decoder_outputs) # type: tuple decoder_outputs = _filter_out_falsey_values(decoder_outputs) # type: tuple
@@ -841,7 +840,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
decoder_attention_mask=None, decoder_attention_mask=None,
decoder_cached_states=None, decoder_cached_states=None,
lm_labels=None, lm_labels=None,
generation_mode=False, use_cache=False,
**unused **unused
): ):
r""" r"""
@@ -892,7 +891,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
encoder_outputs=encoder_outputs, encoder_outputs=encoder_outputs,
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
decoder_cached_states=decoder_cached_states, decoder_cached_states=decoder_cached_states,
generation_mode=generation_mode, use_cache=use_cache,
) )
lm_logits = F.linear(outputs[0], self.model.shared.weight) lm_logits = F.linear(outputs[0], self.model.shared.weight)
outputs = (lm_logits,) + outputs[1:] # Add hidden states and attention if they are here outputs = (lm_logits,) + outputs[1:] # Add hidden states and attention if they are here
@@ -918,7 +917,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
"decoder_cached_states": decoder_cached_states, "decoder_cached_states": decoder_cached_states,
"decoder_input_ids": decoder_input_ids, "decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask, "attention_mask": attention_mask,
"generation_mode": True, "use_cache": True, # change this to avoid caching (presumably for debugging)
} }
def prepare_scores_for_generation(self, scores, cur_len, max_length): def prepare_scores_for_generation(self, scores, cur_len, max_length):
@@ -951,6 +950,10 @@ class BartForConditionalGeneration(PretrainedBartModel):
def get_output_embeddings(self): def get_output_embeddings(self):
return _make_linear_from_emb(self.model.shared) # make it on the fly return _make_linear_from_emb(self.model.shared) # make it on the fly
def _do_output_past(self, *args, **kwargs):
    """Return ``True`` unconditionally so the cache is always used in generate.

    All positional and keyword arguments are accepted and ignored.
    NOTE(review): presumably this overrides a base-class hook that the
    generation loop consults to decide whether cached decoder states are
    passed along; confirm against the caller, which is not visible here.
    """
    return True
@add_start_docstrings( @add_start_docstrings(
"""Bart model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE tasks. """, """Bart model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE tasks. """,

File diff suppressed because one or more lines are too long