[Examples/TensorFlow] minor refactoring to allow compatible datasets to work (#22879)

Minor refactoring so that any compatible dataset works, not only WikiText (for example, the `wikitext` variable is renamed to `dataset`).
This commit is contained in:
Sayak Paul
2023-04-20 18:21:01 +05:30
committed by GitHub
parent 10dd3a7d1c
commit 4116d1ec75
2 changed files with 21 additions and 12 deletions

View File

@@ -69,16 +69,16 @@ def parse_args():
def main(args):
wikitext = datasets.load_dataset(args.dataset_name, args.dataset_config, split="train")
dataset = datasets.load_dataset(args.dataset_name, args.dataset_config, split="train")
if args.limit is not None:
max_train_samples = min(len(wikitext), args.limit)
wikitext = wikitext.select(range(max_train_samples))
max_train_samples = min(len(dataset), args.limit)
dataset = dataset.select(range(max_train_samples))
logger.info(f"Limiting the dataset to {args.limit} entries.")
def batch_iterator():
for i in range(0, len(wikitext), args.batch_size):
yield wikitext[i : i + args.batch_size]["text"]
for i in range(0, len(dataset), args.batch_size):
yield dataset[i : i + args.batch_size]["text"]
# Prepare the tokenizer.
tokenizer = Tokenizer(Unigram())
@@ -111,7 +111,7 @@ def main(args):
if args.export_to_hub:
logger.info("Exporting the trained tokenzier to Hub.")
new_tokenizer = AlbertTokenizerFast(tokenizer_object=tokenizer)
new_tokenizer.push_to_hub("unigram-tokenizer-wikitext")
new_tokenizer.push_to_hub("unigram-tokenizer-dataset")
if __name__ == "__main__":