adding jupyter, updating extract features adding simple test file
This commit is contained in:
@@ -37,35 +37,6 @@ logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(messa
|
||||
level = logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
## Required parameters
|
||||
parser.add_argument("--input_file", default=None, type=str, required=True)
|
||||
parser.add_argument("--vocab_file", default=None, type=str, required=True,
|
||||
help="The vocabulary file that the BERT model was trained on.")
|
||||
parser.add_argument("--output_file", default=None, type=str, required=True)
|
||||
parser.add_argument("--bert_config_file", default=None, type=str, required=True,
|
||||
help="The config json file corresponding to the pre-trained BERT model. "
|
||||
"This specifies the model architecture.")
|
||||
parser.add_argument("--init_checkpoint", default=None, type=str, required=True,
|
||||
help="Initial checkpoint (usually from a pre-trained BERT model).")
|
||||
|
||||
## Other parameters
|
||||
parser.add_argument("--layers", default="-1,-2,-3,-4", type=str)
|
||||
parser.add_argument("--max_seq_length", default=128, type=int,
|
||||
help="The maximum total input sequence length after WordPiece tokenization. Sequences longer "
|
||||
"than this will be truncated, and sequences shorter than this will be padded.")
|
||||
parser.add_argument("--do_lower_case", default=True, action='store_true',
|
||||
help="Whether to lower case the input text. Should be True for uncased "
|
||||
"models and False for cased models.")
|
||||
parser.add_argument("--batch_size", default=32, type=int, help="Batch size for predictions.")
|
||||
parser.add_argument("--local_rank",
|
||||
type=int,
|
||||
default=-1,
|
||||
help = "local_rank for distributed training on gpus")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
class InputExample(object):
|
||||
|
||||
@@ -219,6 +190,35 @@ def read_examples(input_file):
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
## Required parameters
|
||||
parser.add_argument("--input_file", default=None, type=str, required=True)
|
||||
parser.add_argument("--vocab_file", default=None, type=str, required=True,
|
||||
help="The vocabulary file that the BERT model was trained on.")
|
||||
parser.add_argument("--output_file", default=None, type=str, required=True)
|
||||
parser.add_argument("--bert_config_file", default=None, type=str, required=True,
|
||||
help="The config json file corresponding to the pre-trained BERT model. "
|
||||
"This specifies the model architecture.")
|
||||
parser.add_argument("--init_checkpoint", default=None, type=str, required=True,
|
||||
help="Initial checkpoint (usually from a pre-trained BERT model).")
|
||||
|
||||
## Other parameters
|
||||
parser.add_argument("--layers", default="-1,-2,-3,-4", type=str)
|
||||
parser.add_argument("--max_seq_length", default=128, type=int,
|
||||
help="The maximum total input sequence length after WordPiece tokenization. Sequences longer "
|
||||
"than this will be truncated, and sequences shorter than this will be padded.")
|
||||
parser.add_argument("--do_lower_case", default=True, action='store_true',
|
||||
help="Whether to lower case the input text. Should be True for uncased "
|
||||
"models and False for cased models.")
|
||||
parser.add_argument("--batch_size", default=32, type=int, help="Batch size for predictions.")
|
||||
parser.add_argument("--local_rank",
|
||||
type=int,
|
||||
default=-1,
|
||||
help = "local_rank for distributed training on gpus")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.local_rank == -1 or args.no_cuda:
|
||||
device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
|
||||
n_gpu = torch.cuda.device_count()
|
||||
|
||||
Reference in New Issue
Block a user