Return correct Bart hidden state tensors (#8747)

* bart output hidden states upstream

* same w/ decoder

* add tests

* fix prophetnet

* fix gpt2 and ctrl

* fix fsmt and skip test for reformer and longformer

* fix all models

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Joe Davison
2020-11-25 16:06:04 -05:00
committed by GitHub
parent 138f45c184
commit 369f1d77b4
14 changed files with 199 additions and 54 deletions

View File

@@ -441,13 +441,12 @@ class CTRLModel(CTRLPreTrainedModel):
hidden_states = self.dropout(hidden_states)
output_shape = input_shape + (inputs_embeds.size(-1),)
presents = () if use_cache else None
all_hidden_states = () if output_hidden_states else None
all_attentions = [] if output_attentions else None
all_attentions = () if output_attentions else None
for i, (h, layer_past) in enumerate(zip(self.h, past_key_values)):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)
all_hidden_states = all_hidden_states + (hidden_states,)
outputs = h(
hidden_states,
mask,
@@ -462,18 +461,12 @@ class CTRLModel(CTRLPreTrainedModel):
presents = presents + (present,)
if output_attentions:
all_attentions.append(outputs[2])
all_attentions += (outputs[2],)
hidden_states = self.layernorm(hidden_states)
hidden_states = hidden_states.view(*output_shape)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if output_attentions:
# let the number of heads free (-1) so we can extract attention even after head pruning
attention_output_shape = input_shape[:-1] + (-1,) + all_attentions[0].shape[-2:]
all_attentions = tuple(t.view(*attention_output_shape) for t in all_attentions)
if not return_dict:
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_attentions] if v is not None)