Switch from using sum for flattening lists of lists in group_texts (#14472)

* remove sum for list flattening

* change to chain(*)

* make chain object a list

* delete empty lines

per sgugger's suggestions

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Co-authored-by: Nicholas Broad <nicholas@nmbroad.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Nicholas Broad
2021-11-22 16:17:26 -05:00
committed by GitHub
parent 0b7d053c13
commit 69e16abf98
15 changed files with 35 additions and 20 deletions

View File

@@ -25,6 +25,7 @@ import os
import sys
import time
from dataclasses import dataclass, field
from itertools import chain
# You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
from pathlib import Path
@@ -453,7 +454,7 @@ if __name__ == "__main__":
# max_seq_length.
def group_texts(examples):
# Concatenate all texts.
concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
total_length = len(concatenated_examples[list(examples.keys())[0]])
# We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
# customize this part to your needs.