Remove head mask in generative models (#35786)
* just squash into one commit * delete print
This commit is contained in:
committed by
GitHub
parent
0173a99e73
commit
955e61b0da
@@ -52,15 +52,12 @@ def prepare_opt_inputs_dict(
|
||||
decoder_input_ids=None,
|
||||
attention_mask=None,
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
):
|
||||
if attention_mask is None:
|
||||
attention_mask = input_ids.ne(config.pad_token_id)
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"head_mask": head_mask,
|
||||
}
|
||||
|
||||
|
||||
@@ -156,10 +153,9 @@ class OPTModelTester:
|
||||
|
||||
input_ids = inputs_dict["input_ids"]
|
||||
attention_mask = inputs_dict["attention_mask"]
|
||||
head_mask = inputs_dict["head_mask"]
|
||||
|
||||
# first forward pass
|
||||
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
|
||||
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
|
||||
|
||||
output, past_key_values = outputs.to_tuple()
|
||||
|
||||
@@ -187,7 +183,7 @@ class OPTModelTester:
|
||||
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
|
||||
|
||||
# test no attention_mask works
|
||||
outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
|
||||
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
|
||||
_, past_key_values = outputs.to_tuple()
|
||||
output_from_no_past = model(next_input_ids)["last_hidden_state"]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user