CodeParrot data pretokenization (#16932)
* add pretokenization arguments * add pretokenization script * add support for pretokenized data * reformat code * fix run command for training * fix model call from config * remove a package * add comments on pretokenization in the readme * remove explicit parallelization Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com> * update readme Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com> * update readme -remove username Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com> * update readme -remove username Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com> * keep data parallelization * reformat code * reformat code * update readme * reformat code * Update examples/research_projects/codeparrot/README.md Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com> Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com> Co-authored-by: Loubna ben allal <loubnabenallal@gmail.com>
This commit is contained in:
@@ -27,19 +27,34 @@ class ConstantLengthDataset(IterableDataset):
|
||||
seq_length (int): Length of token sequences to return.
|
||||
num_of_sequences: Number of token sequences to keep in buffer.
|
||||
chars_per_token: Number of characters per token used to estimate number of tokens in text buffer.
|
||||
tokenized: If true we use a pretokenized dataset.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, tokenizer, dataset, infinite=False, seq_length=1024, num_of_sequences=1024, chars_per_token=3.6
|
||||
self,
|
||||
tokenizer,
|
||||
dataset,
|
||||
infinite=False,
|
||||
seq_length=1024,
|
||||
num_of_sequences=1024,
|
||||
chars_per_token=3.6,
|
||||
tokenized=False,
|
||||
):
|
||||
self.tokenizer = tokenizer
|
||||
self.concat_token_id = tokenizer.bos_token_id
|
||||
self.dataset = dataset
|
||||
self.seq_length = seq_length
|
||||
self.input_characters = seq_length * chars_per_token * num_of_sequences
|
||||
self.epoch = 0
|
||||
self.infinite = infinite
|
||||
self.current_size = 0
|
||||
self.tokenized = tokenized
|
||||
|
||||
if self.tokenized:
|
||||
self.max_buffer_size = seq_length * num_of_sequences
|
||||
self.content_field = "input_ids"
|
||||
else:
|
||||
self.max_buffer_size = seq_length * chars_per_token * num_of_sequences
|
||||
self.content_field = "content"
|
||||
|
||||
def __iter__(self):
|
||||
iterator = iter(self.dataset)
|
||||
@@ -47,10 +62,10 @@ class ConstantLengthDataset(IterableDataset):
|
||||
while more_examples:
|
||||
buffer, buffer_len = [], 0
|
||||
while True:
|
||||
if buffer_len >= self.input_characters:
|
||||
if buffer_len >= self.max_buffer_size:
|
||||
break
|
||||
try:
|
||||
buffer.append(next(iterator)["content"])
|
||||
buffer.append(next(iterator)[self.content_field])
|
||||
buffer_len += len(buffer[-1])
|
||||
except StopIteration:
|
||||
if self.infinite:
|
||||
@@ -60,7 +75,10 @@ class ConstantLengthDataset(IterableDataset):
|
||||
else:
|
||||
more_examples = False
|
||||
break
|
||||
tokenized_inputs = self.tokenizer(buffer, truncation=False)["input_ids"]
|
||||
if self.tokenized:
|
||||
tokenized_inputs = buffer
|
||||
else:
|
||||
tokenized_inputs = self.tokenizer(buffer, truncation=False)["input_ids"]
|
||||
all_token_ids = []
|
||||
for tokenized_input in tokenized_inputs:
|
||||
all_token_ids.extend(tokenized_input + [self.concat_token_id])
|
||||
@@ -102,8 +120,12 @@ def create_dataloaders(args):
|
||||
train_data = load_dataset(args.dataset_name_train, split="train", **ds_kwargs)
|
||||
train_data = train_data.shuffle(buffer_size=args.shuffle_buffer, seed=args.seed)
|
||||
valid_data = load_dataset(args.dataset_name_valid, split="train", **ds_kwargs)
|
||||
train_dataset = ConstantLengthDataset(tokenizer, train_data, infinite=True, seq_length=args.seq_length)
|
||||
valid_dataset = ConstantLengthDataset(tokenizer, valid_data, infinite=False, seq_length=args.seq_length)
|
||||
train_dataset = ConstantLengthDataset(
|
||||
tokenizer, train_data, infinite=True, seq_length=args.seq_length, tokenized=args.tokenized
|
||||
)
|
||||
valid_dataset = ConstantLengthDataset(
|
||||
tokenizer, valid_data, infinite=False, seq_length=args.seq_length, tokenized=args.tokenized
|
||||
)
|
||||
train_dataloader = DataLoader(train_dataset, batch_size=args.train_batch_size)
|
||||
eval_dataloader = DataLoader(valid_dataset, batch_size=args.valid_batch_size)
|
||||
return train_dataloader, eval_dataloader
|
||||
|
||||
Reference in New Issue
Block a user