add do_lower_case arg and adjust model saving for lm finetuning.
This commit is contained in:
@@ -461,6 +461,9 @@ def main():
|
|||||||
parser.add_argument("--on_memory",
|
parser.add_argument("--on_memory",
|
||||||
action='store_true',
|
action='store_true',
|
||||||
help="Whether to load train samples into memory or use disk")
|
help="Whether to load train samples into memory or use disk")
|
||||||
|
parser.add_argument("--do_lower_case",
|
||||||
|
action='store_true',
|
||||||
|
help="Whether to lower case the input text. True for uncased models, False for cased models.")
|
||||||
parser.add_argument("--local_rank",
|
parser.add_argument("--local_rank",
|
||||||
type=int,
|
type=int,
|
||||||
default=-1,
|
default=-1,
|
||||||
@@ -612,12 +615,12 @@ def main():
|
|||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
global_step += 1
|
global_step += 1
|
||||||
|
|
||||||
|
# Save a trained model
|
||||||
logger.info("** ** * Saving fine - tuned model ** ** * ")
|
logger.info("** ** * Saving fine - tuned model ** ** * ")
|
||||||
|
model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self
|
||||||
output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
|
output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
|
||||||
if n_gpu > 1:
|
if args.do_train:
|
||||||
torch.save(model.module.bert.state_dict(), output_model_file)
|
torch.save(model_to_save.state_dict(), output_model_file)
|
||||||
else:
|
|
||||||
torch.save(model.bert.state_dict(), output_model_file)
|
|
||||||
|
|
||||||
|
|
||||||
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
|
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
|
||||||
|
|||||||
Reference in New Issue
Block a user