re-format
This commit is contained in:
@@ -1,20 +1,22 @@
|
|||||||
import json
|
|
||||||
import logging
|
|
||||||
import random
|
|
||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
from collections import namedtuple
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
import logging
|
||||||
|
import json
|
||||||
|
import random
|
||||||
|
import numpy as np
|
||||||
|
from collections import namedtuple
|
||||||
from tempfile import TemporaryDirectory
|
from tempfile import TemporaryDirectory
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from torch.utils.data import DataLoader, Dataset, RandomSampler
|
from torch.utils.data import DataLoader, Dataset, RandomSampler
|
||||||
from torch.utils.data.distributed import DistributedSampler
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from pytorch_transformers import WEIGHTS_NAME, CONFIG_NAME
|
||||||
from pytorch_transformers.modeling_bert import BertForPreTraining
|
from pytorch_transformers.modeling_bert import BertForPreTraining
|
||||||
from pytorch_transformers.optimization import AdamW, WarmupLinearSchedule
|
|
||||||
from pytorch_transformers.tokenization_bert import BertTokenizer
|
from pytorch_transformers.tokenization_bert import BertTokenizer
|
||||||
|
from pytorch_transformers.optimization import AdamW, WarmupLinearSchedule
|
||||||
|
|
||||||
InputFeatures = namedtuple("InputFeatures", "input_ids input_mask segment_ids lm_label_ids is_next")
|
InputFeatures = namedtuple("InputFeatures", "input_ids input_mask segment_ids lm_label_ids is_next")
|
||||||
|
|
||||||
@@ -70,16 +72,16 @@ class PregeneratedDataset(Dataset):
|
|||||||
if reduce_memory:
|
if reduce_memory:
|
||||||
self.temp_dir = TemporaryDirectory()
|
self.temp_dir = TemporaryDirectory()
|
||||||
self.working_dir = Path(self.temp_dir.name)
|
self.working_dir = Path(self.temp_dir.name)
|
||||||
input_ids = np.memmap(filename=self.working_dir / 'input_ids.memmap',
|
input_ids = np.memmap(filename=self.working_dir/'input_ids.memmap',
|
||||||
mode='w+', dtype=np.int32, shape=(num_samples, seq_len))
|
mode='w+', dtype=np.int32, shape=(num_samples, seq_len))
|
||||||
input_masks = np.memmap(filename=self.working_dir / 'input_masks.memmap',
|
input_masks = np.memmap(filename=self.working_dir/'input_masks.memmap',
|
||||||
shape=(num_samples, seq_len), mode='w+', dtype=np.bool)
|
shape=(num_samples, seq_len), mode='w+', dtype=np.bool)
|
||||||
segment_ids = np.memmap(filename=self.working_dir / 'segment_ids.memmap',
|
segment_ids = np.memmap(filename=self.working_dir/'segment_ids.memmap',
|
||||||
shape=(num_samples, seq_len), mode='w+', dtype=np.bool)
|
shape=(num_samples, seq_len), mode='w+', dtype=np.bool)
|
||||||
lm_label_ids = np.memmap(filename=self.working_dir / 'lm_label_ids.memmap',
|
lm_label_ids = np.memmap(filename=self.working_dir/'lm_label_ids.memmap',
|
||||||
shape=(num_samples, seq_len), mode='w+', dtype=np.int32)
|
shape=(num_samples, seq_len), mode='w+', dtype=np.int32)
|
||||||
lm_label_ids[:] = -1
|
lm_label_ids[:] = -1
|
||||||
is_nexts = np.memmap(filename=self.working_dir / 'is_nexts.memmap',
|
is_nexts = np.memmap(filename=self.working_dir/'is_nexts.memmap',
|
||||||
shape=(num_samples,), mode='w+', dtype=np.bool)
|
shape=(num_samples,), mode='w+', dtype=np.bool)
|
||||||
else:
|
else:
|
||||||
input_ids = np.zeros(shape=(num_samples, seq_len), dtype=np.int32)
|
input_ids = np.zeros(shape=(num_samples, seq_len), dtype=np.int32)
|
||||||
@@ -123,8 +125,7 @@ def main():
|
|||||||
parser = ArgumentParser()
|
parser = ArgumentParser()
|
||||||
parser.add_argument('--pregenerated_data', type=Path, required=True)
|
parser.add_argument('--pregenerated_data', type=Path, required=True)
|
||||||
parser.add_argument('--output_dir', type=Path, required=True)
|
parser.add_argument('--output_dir', type=Path, required=True)
|
||||||
parser.add_argument("--bert_model", type=str, required=True,
|
parser.add_argument("--bert_model", type=str, required=True, help="Bert pre-trained model selected in the list: bert-base-uncased, "
|
||||||
help="Bert pre-trained model selected in the list: bert-base-uncased, "
|
|
||||||
"bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.")
|
"bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.")
|
||||||
parser.add_argument("--do_lower_case", action="store_true")
|
parser.add_argument("--do_lower_case", action="store_true")
|
||||||
parser.add_argument("--reduce_memory", action="store_true",
|
parser.add_argument("--reduce_memory", action="store_true",
|
||||||
@@ -336,8 +337,7 @@ def main():
|
|||||||
# Save a trained model
|
# Save a trained model
|
||||||
if args.local_rank == -1 or torch.distributed.get_rank() == 0:
|
if args.local_rank == -1 or torch.distributed.get_rank() == 0:
|
||||||
logging.info("** ** * Saving fine-tuned model ** ** * ")
|
logging.info("** ** * Saving fine-tuned model ** ** * ")
|
||||||
model_to_save = model.module if hasattr(model,
|
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
|
||||||
'module') else model # Take care of distributed/parallel training
|
|
||||||
model_to_save.save_pretrained(args.output_dir)
|
model_to_save.save_pretrained(args.output_dir)
|
||||||
tokenizer.save_pretrained(args.output_dir)
|
tokenizer.save_pretrained(args.output_dir)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user