[Examples/TensorFlow] minor refactoring to allow compatible datasets to work (#22879)
minor refactoring to allow compatible datasets to work.
This commit is contained in:
@@ -69,16 +69,16 @@ def parse_args():
|
||||
|
||||
|
||||
def main(args):
|
||||
wikitext = datasets.load_dataset(args.dataset_name, args.dataset_config, split="train")
|
||||
dataset = datasets.load_dataset(args.dataset_name, args.dataset_config, split="train")
|
||||
|
||||
if args.limit is not None:
|
||||
max_train_samples = min(len(wikitext), args.limit)
|
||||
wikitext = wikitext.select(range(max_train_samples))
|
||||
max_train_samples = min(len(dataset), args.limit)
|
||||
dataset = dataset.select(range(max_train_samples))
|
||||
logger.info(f"Limiting the dataset to {args.limit} entries.")
|
||||
|
||||
def batch_iterator():
|
||||
for i in range(0, len(wikitext), args.batch_size):
|
||||
yield wikitext[i : i + args.batch_size]["text"]
|
||||
for i in range(0, len(dataset), args.batch_size):
|
||||
yield dataset[i : i + args.batch_size]["text"]
|
||||
|
||||
# Prepare the tokenizer.
|
||||
tokenizer = Tokenizer(Unigram())
|
||||
@@ -111,7 +111,7 @@ def main(args):
|
||||
if args.export_to_hub:
|
||||
logger.info("Exporting the trained tokenzier to Hub.")
|
||||
new_tokenizer = AlbertTokenizerFast(tokenizer_object=tokenizer)
|
||||
new_tokenizer.push_to_hub("unigram-tokenizer-wikitext")
|
||||
new_tokenizer.push_to_hub("unigram-tokenizer-dataset")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user