Switch from using sum for flattening lists of lists in group_texts (#14472)
* remove sum for list flattening * change to chain(*) * make chain object a list * delete empty lines per sgugger's suggestions Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Nicholas Broad <nicholas@nmbroad.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -30,6 +30,7 @@ import random
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from functools import partial
|
||||
from itertools import chain
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
@@ -406,7 +407,7 @@ def main():
|
||||
# Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
|
||||
def group_texts(examples):
|
||||
# Concatenate all texts.
|
||||
concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
|
||||
concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
|
||||
total_length = len(concatenated_examples[list(examples.keys())[0]])
|
||||
# We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
|
||||
# customize this part to your needs.
|
||||
|
||||
Reference in New Issue
Block a user