From a12ab0a8dba640730b5353d07ca8893c2d64688f Mon Sep 17 00:00:00 2001 From: VictorSanh Date: Wed, 2 Oct 2019 11:01:55 -0400 Subject: [PATCH] update binarized_data --- examples/distillation/scripts/binarized_data.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/examples/distillation/scripts/binarized_data.py b/examples/distillation/scripts/binarized_data.py index eb4af08b0f..43824e9964 100644 --- a/examples/distillation/scripts/binarized_data.py +++ b/examples/distillation/scripts/binarized_data.py @@ -13,14 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -Preprocessing script before training DistilBERT. +Preprocessing script before distillation. """ import argparse import pickle import random import time import numpy as np -from transformers import BertTokenizer, RobertaTokenizer +from transformers import BertTokenizer, RobertaTokenizer, GPT2Tokenizer import logging logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', @@ -32,7 +32,7 @@ def main(): parser = argparse.ArgumentParser(description="Preprocess the data to avoid re-doing it several times by (tokenization + token_to_ids).") parser.add_argument('--file_path', type=str, default='data/dump.txt', help='The path to the data.') - parser.add_argument('--tokenizer_type', type=str, default='bert', choices=['bert', 'roberta']) + parser.add_argument('--tokenizer_type', type=str, default='bert', choices=['bert', 'roberta', 'gpt2']) parser.add_argument('--tokenizer_name', type=str, default='bert-base-uncased', help="The tokenizer to use.") parser.add_argument('--dump_file', type=str, default='data/dump', @@ -43,10 +43,16 @@ def main(): logger.info(f'Loading Tokenizer ({args.tokenizer_name})') if args.tokenizer_type == 'bert': tokenizer = BertTokenizer.from_pretrained(args.tokenizer_name) + bos = tokenizer.special_tokens_map['cls_token'] # `[CLS]` + sep = tokenizer.special_tokens_map['sep_token'] # `[SEP]` elif args.tokenizer_type == 'roberta': tokenizer = RobertaTokenizer.from_pretrained(args.tokenizer_name) - bos = tokenizer.special_tokens_map['bos_token'] # `[CLS]` for bert, `` for roberta - sep = tokenizer.special_tokens_map['sep_token'] # `[SEP]` for bert, `` for roberta + bos = tokenizer.special_tokens_map['cls_token'] # `` + sep = tokenizer.special_tokens_map['sep_token'] # `` + elif args.tokenizer_type == 'gpt2': + tokenizer = GPT2Tokenizer.from_pretrained(args.tokenizer_name) + bos = tokenizer.special_tokens_map['bos_token'] # `<|endoftext|>` + sep = tokenizer.special_tokens_map['eos_token'] # `<|endoftext|>` logger.info(f'Loading text from {args.file_path}') with open(args.file_path, 'r', encoding='utf8') as fp: