27 lines
787 B
Python
27 lines
787 B
Python
from torch.utils.data import Dataset, DataLoader
|
|
import os
|
|
import random
|
|
import torch
|
|
import torch.nn.functional as F
|
|
|
|
|
|
class WikiTextDataset(Dataset):
    """Fixed-length token windows cut from a raw WikiText split.

    The file ``<directory>/wiki.<file>.raw`` is read in full, tokenized with
    ``tokenizer.tokenize`` and mapped to ids with
    ``tokenizer.convert_tokens_to_ids``. The resulting id sequence is cut into
    consecutive, non-overlapping windows of exactly ``max_context_length``
    ids. A trailing remainder shorter than ``max_context_length`` is
    discarded, so every example has the same length.
    """

    def __init__(self, tokenizer, file='train', directory='wikitext', max_context_length=1024):
        """Load, tokenize and chunk one split.

        Args:
            tokenizer: Object providing ``tokenize(text)`` and
                ``convert_tokens_to_ids(tokens)``; the latter is expected to
                return a sliceable sequence of ints.
            file: Split name; selects the file ``wiki.<file>.raw``.
            directory: Directory that contains the raw split files.
            max_context_length: Number of token ids per example; must be >= 1.

        Raises:
            ValueError: If ``max_context_length`` is smaller than 1.
            OSError: If the split file cannot be opened.
        """
        # A window size of 0 would never consume any tokens (the previous
        # slicing loop spun forever); negative sizes are meaningless.
        if max_context_length < 1:
            raise ValueError(
                f"max_context_length must be >= 1, got {max_context_length}"
            )

        super().__init__()
        self.max_context_length = max_context_length

        with open(os.path.join(directory, f"wiki.{file}.raw"), encoding="utf-8") as f:
            text = f.read()
        tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))

        # Slice by index rather than repeatedly re-copying the remaining tail
        # (which is quadratic in corpus size). Integer division keeps every
        # full window, including a final one that is exactly
        # max_context_length long, and drops only a shorter remainder.
        full_windows = len(tokenized_text) // max_context_length
        self.examples = [
            tokenized_text[i * max_context_length:(i + 1) * max_context_length]
            for i in range(full_windows)
        ]

    def __len__(self) -> int:
        """Return the number of fixed-length examples."""
        return len(self.examples)

    def __getitem__(self, item) -> torch.Tensor:
        """Return example ``item`` as a tensor built from its token ids."""
        return torch.tensor(self.examples[item])
|