fix vocab size in binarized_data (distil): int16 vs int32
This commit is contained in:
@@ -75,13 +75,17 @@ def main():
|
|||||||
iter += 1
|
iter += 1
|
||||||
if iter % interval == 0:
|
if iter % interval == 0:
|
||||||
end = time.time()
|
end = time.time()
|
||||||
logger.info(f"{iter} examples processed. - {(end-start)/interval:.2f}s/expl")
|
logger.info(f"{iter} examples processed. - {(end-start):.2f}s/{interval}expl")
|
||||||
start = time.time()
|
start = time.time()
|
||||||
logger.info("Finished binarization")
|
logger.info("Finished binarization")
|
||||||
logger.info(f"{len(data)} examples processed.")
|
logger.info(f"{len(data)} examples processed.")
|
||||||
|
|
||||||
dp_file = f"{args.dump_file}.{args.tokenizer_name}.pickle"
|
dp_file = f"{args.dump_file}.{args.tokenizer_name}.pickle"
|
||||||
|
vocab_size = tokenizer.vocab_size
|
||||||
|
if vocab_size < (1 << 16):
|
||||||
rslt_ = [np.uint16(d) for d in rslt]
|
rslt_ = [np.uint16(d) for d in rslt]
|
||||||
|
else:
|
||||||
|
rslt_ = [np.int32(d) for d in rslt]
|
||||||
random.shuffle(rslt_)
|
random.shuffle(rslt_)
|
||||||
logger.info(f"Dump to {dp_file}")
|
logger.info(f"Dump to {dp_file}")
|
||||||
with open(dp_file, "wb") as handle:
|
with open(dp_file, "wb") as handle:
|
||||||
|
|||||||
Reference in New Issue
Block a user