Switch from using sum for flattening lists of lists in group_texts (#14472)

* remove sum for list flattening

* change to chain(*)

* make chain object a list

* delete empty lines

per sgugger's suggestions

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Co-authored-by: Nicholas Broad <nicholas@nmbroad.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Nicholas Broad
2021-11-22 16:17:26 -05:00
committed by GitHub
parent 0b7d053c13
commit 69e16abf98
15 changed files with 35 additions and 20 deletions

View File

@@ -26,6 +26,7 @@ import math
import os
import sys
from dataclasses import dataclass, field
from itertools import chain
from typing import Optional
import datasets
@@ -408,7 +409,7 @@ def main():
# Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
def group_texts(examples):
# Concatenate all texts.
concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
total_length = len(concatenated_examples[list(examples.keys())[0]])
# We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
# customize this part to your needs.