Black 20 release
This commit is contained in:
@@ -63,7 +63,9 @@ def calculate_bleu(output_lns, refs_lns, **kwargs) -> dict:
|
||||
|
||||
|
||||
def trim_batch(
|
||||
input_ids, pad_token_id, attention_mask=None,
|
||||
input_ids,
|
||||
pad_token_id,
|
||||
attention_mask=None,
|
||||
):
|
||||
"""Remove columns that are populated exclusively by pad_token_id"""
|
||||
keep_column_mask = input_ids.ne(pad_token_id).any(dim=0)
|
||||
|
||||
Reference in New Issue
Block a user