Re-format: reorder imports, adjust spacing around the `/` path operator, and join wrapped statements

This commit is contained in:
jamin
2019-08-30 14:05:28 +09:00
parent c8731b9583
commit 2fb9a934b4

View File

@@ -1,20 +1,22 @@
import json
import logging
import random
from argparse import ArgumentParser from argparse import ArgumentParser
from collections import namedtuple
from pathlib import Path from pathlib import Path
import os
import torch
import logging
import json
import random
import numpy as np
from collections import namedtuple
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset, RandomSampler from torch.utils.data import DataLoader, Dataset, RandomSampler
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm from tqdm import tqdm
from pytorch_transformers import WEIGHTS_NAME, CONFIG_NAME
from pytorch_transformers.modeling_bert import BertForPreTraining from pytorch_transformers.modeling_bert import BertForPreTraining
from pytorch_transformers.optimization import AdamW, WarmupLinearSchedule
from pytorch_transformers.tokenization_bert import BertTokenizer from pytorch_transformers.tokenization_bert import BertTokenizer
from pytorch_transformers.optimization import AdamW, WarmupLinearSchedule
InputFeatures = namedtuple("InputFeatures", "input_ids input_mask segment_ids lm_label_ids is_next") InputFeatures = namedtuple("InputFeatures", "input_ids input_mask segment_ids lm_label_ids is_next")
@@ -70,16 +72,16 @@ class PregeneratedDataset(Dataset):
if reduce_memory: if reduce_memory:
self.temp_dir = TemporaryDirectory() self.temp_dir = TemporaryDirectory()
self.working_dir = Path(self.temp_dir.name) self.working_dir = Path(self.temp_dir.name)
input_ids = np.memmap(filename=self.working_dir / 'input_ids.memmap', input_ids = np.memmap(filename=self.working_dir/'input_ids.memmap',
mode='w+', dtype=np.int32, shape=(num_samples, seq_len)) mode='w+', dtype=np.int32, shape=(num_samples, seq_len))
input_masks = np.memmap(filename=self.working_dir / 'input_masks.memmap', input_masks = np.memmap(filename=self.working_dir/'input_masks.memmap',
shape=(num_samples, seq_len), mode='w+', dtype=np.bool) shape=(num_samples, seq_len), mode='w+', dtype=np.bool)
segment_ids = np.memmap(filename=self.working_dir / 'segment_ids.memmap', segment_ids = np.memmap(filename=self.working_dir/'segment_ids.memmap',
shape=(num_samples, seq_len), mode='w+', dtype=np.bool) shape=(num_samples, seq_len), mode='w+', dtype=np.bool)
lm_label_ids = np.memmap(filename=self.working_dir / 'lm_label_ids.memmap', lm_label_ids = np.memmap(filename=self.working_dir/'lm_label_ids.memmap',
shape=(num_samples, seq_len), mode='w+', dtype=np.int32) shape=(num_samples, seq_len), mode='w+', dtype=np.int32)
lm_label_ids[:] = -1 lm_label_ids[:] = -1
is_nexts = np.memmap(filename=self.working_dir / 'is_nexts.memmap', is_nexts = np.memmap(filename=self.working_dir/'is_nexts.memmap',
shape=(num_samples,), mode='w+', dtype=np.bool) shape=(num_samples,), mode='w+', dtype=np.bool)
else: else:
input_ids = np.zeros(shape=(num_samples, seq_len), dtype=np.int32) input_ids = np.zeros(shape=(num_samples, seq_len), dtype=np.int32)
@@ -123,8 +125,7 @@ def main():
parser = ArgumentParser() parser = ArgumentParser()
parser.add_argument('--pregenerated_data', type=Path, required=True) parser.add_argument('--pregenerated_data', type=Path, required=True)
parser.add_argument('--output_dir', type=Path, required=True) parser.add_argument('--output_dir', type=Path, required=True)
parser.add_argument("--bert_model", type=str, required=True, parser.add_argument("--bert_model", type=str, required=True, help="Bert pre-trained model selected in the list: bert-base-uncased, "
help="Bert pre-trained model selected in the list: bert-base-uncased, "
"bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.") "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.")
parser.add_argument("--do_lower_case", action="store_true") parser.add_argument("--do_lower_case", action="store_true")
parser.add_argument("--reduce_memory", action="store_true", parser.add_argument("--reduce_memory", action="store_true",
@@ -336,8 +337,7 @@ def main():
# Save a trained model # Save a trained model
if args.local_rank == -1 or torch.distributed.get_rank() == 0: if args.local_rank == -1 or torch.distributed.get_rank() == 0:
logging.info("** ** * Saving fine-tuned model ** ** * ") logging.info("** ** * Saving fine-tuned model ** ** * ")
model_to_save = model.module if hasattr(model, model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
'module') else model # Take care of distributed/parallel training
model_to_save.save_pretrained(args.output_dir) model_to_save.save_pretrained(args.output_dir)
tokenizer.save_pretrained(args.output_dir) tokenizer.save_pretrained(args.output_dir)