Remove head mask in generative models (#35786)

* just squash into one commit

* delete print
This commit is contained in:
Raushan Turganbay
2025-05-15 10:44:19 +02:00
committed by GitHub
parent 0173a99e73
commit 955e61b0da
47 changed files with 103 additions and 294 deletions

View File

@@ -52,15 +52,12 @@ def prepare_opt_inputs_dict(
decoder_input_ids=None,
attention_mask=None,
decoder_attention_mask=None,
-head_mask=None,
-decoder_head_mask=None,
):
if attention_mask is None:
attention_mask = input_ids.ne(config.pad_token_id)
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
-"head_mask": head_mask,
}
@@ -156,10 +153,9 @@ class OPTModelTester:
input_ids = inputs_dict["input_ids"]
attention_mask = inputs_dict["attention_mask"]
-head_mask = inputs_dict["head_mask"]
# first forward pass
-outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
+outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
output, past_key_values = outputs.to_tuple()
@@ -187,7 +183,7 @@ class OPTModelTester:
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
# test no attention_mask works
-outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
+outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
_, past_key_values = outputs.to_tuple()
output_from_no_past = model(next_input_ids)["last_hidden_state"]