add do_lower_case arg and adjust model saving for lm finetuning.
This commit is contained in:
@@ -461,6 +461,9 @@ def main():
|
||||
parser.add_argument("--on_memory",
|
||||
action='store_true',
|
||||
help="Whether to load train samples into memory or use disk")
|
||||
parser.add_argument("--do_lower_case",
|
||||
action='store_true',
|
||||
help="Whether to lower case the input text. True for uncased models, False for cased models.")
|
||||
parser.add_argument("--local_rank",
|
||||
type=int,
|
||||
default=-1,
|
||||
@@ -612,12 +615,12 @@ def main():
|
||||
optimizer.zero_grad()
|
||||
global_step += 1
|
||||
|
||||
# Save a trained model
|
||||
logger.info("** ** * Saving fine - tuned model ** ** * ")
|
||||
model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self
|
||||
output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
|
||||
if n_gpu > 1:
|
||||
torch.save(model.module.bert.state_dict(), output_model_file)
|
||||
else:
|
||||
torch.save(model.bert.state_dict(), output_model_file)
|
||||
if args.do_train:
|
||||
torch.save(model_to_save.state_dict(), output_model_file)
|
||||
|
||||
|
||||
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
|
||||
|
||||
Reference in New Issue
Block a user