adding jupyter, updating extract features adding simple test file

This commit is contained in:
thomwolf
2018-11-02 14:25:21 +01:00
parent 844b2f0e6f
commit c9690e57f8
3 changed files with 318 additions and 29 deletions

View File

@@ -37,35 +37,6 @@ logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(messa
level = logging.INFO)
logger = logging.getLogger(__name__)
parser = argparse.ArgumentParser()
## Required parameters
parser.add_argument("--input_file", default=None, type=str, required=True)
parser.add_argument("--vocab_file", default=None, type=str, required=True,
help="The vocabulary file that the BERT model was trained on.")
parser.add_argument("--output_file", default=None, type=str, required=True)
parser.add_argument("--bert_config_file", default=None, type=str, required=True,
help="The config json file corresponding to the pre-trained BERT model. "
"This specifies the model architecture.")
parser.add_argument("--init_checkpoint", default=None, type=str, required=True,
help="Initial checkpoint (usually from a pre-trained BERT model).")
## Other parameters
parser.add_argument("--layers", default="-1,-2,-3,-4", type=str)
parser.add_argument("--max_seq_length", default=128, type=int,
help="The maximum total input sequence length after WordPiece tokenization. Sequences longer "
"than this will be truncated, and sequences shorter than this will be padded.")
parser.add_argument("--do_lower_case", default=True, action='store_true',
help="Whether to lower case the input text. Should be True for uncased "
"models and False for cased models.")
parser.add_argument("--batch_size", default=32, type=int, help="Batch size for predictions.")
parser.add_argument("--local_rank",
type=int,
default=-1,
help = "local_rank for distributed training on gpus")
args = parser.parse_args()
class InputExample(object):
@@ -219,6 +190,35 @@ def read_examples(input_file):
def main():
parser = argparse.ArgumentParser()
## Required parameters
parser.add_argument("--input_file", default=None, type=str, required=True)
parser.add_argument("--vocab_file", default=None, type=str, required=True,
help="The vocabulary file that the BERT model was trained on.")
parser.add_argument("--output_file", default=None, type=str, required=True)
parser.add_argument("--bert_config_file", default=None, type=str, required=True,
help="The config json file corresponding to the pre-trained BERT model. "
"This specifies the model architecture.")
parser.add_argument("--init_checkpoint", default=None, type=str, required=True,
help="Initial checkpoint (usually from a pre-trained BERT model).")
## Other parameters
parser.add_argument("--layers", default="-1,-2,-3,-4", type=str)
parser.add_argument("--max_seq_length", default=128, type=int,
help="The maximum total input sequence length after WordPiece tokenization. Sequences longer "
"than this will be truncated, and sequences shorter than this will be padded.")
parser.add_argument("--do_lower_case", default=True, action='store_true',
help="Whether to lower case the input text. Should be True for uncased "
"models and False for cased models.")
parser.add_argument("--batch_size", default=32, type=int, help="Batch size for predictions.")
parser.add_argument("--local_rank",
type=int,
default=-1,
help = "local_rank for distributed training on gpus")
args = parser.parse_args()
if args.local_rank == -1 or args.no_cuda:
device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
n_gpu = torch.cuda.device_count()