Fix (non-slow) tests on GPU (torch) (#3024)
* Fix tests on GPU (torch) * Fix bart slow tests Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
This commit is contained in:
@@ -86,7 +86,7 @@ def _prepare_bart_decoder_inputs(
|
||||
causal_lm_mask = None
|
||||
new_shape = (bsz, tgt_len, tgt_len)
|
||||
# make it broadcastable so can just be added to the attention coefficients
|
||||
decoder_attn_mask = _combine_masks(decoder_padding_mask, causal_lm_mask, new_shape)
|
||||
decoder_attn_mask = _combine_masks(decoder_padding_mask, causal_lm_mask, new_shape).to(device=input_ids.device)
|
||||
assert decoder_attn_mask is None or decoder_attn_mask.shape == (bsz, 1, tgt_len, tgt_len)
|
||||
return decoder_input_ids, decoder_attn_mask
|
||||
|
||||
|
||||
Reference in New Issue
Block a user