fix errors for lm_finetuning examples
This commit is contained in:
@@ -325,15 +325,16 @@ def main():
|
|||||||
global_step += 1
|
global_step += 1
|
||||||
|
|
||||||
# Save a trained model
|
# Save a trained model
|
||||||
logging.info("** ** * Saving fine-tuned model ** ** * ")
|
if torch.distributed.get_rank() == 0:
|
||||||
model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self
|
logging.info("** ** * Saving fine-tuned model ** ** * ")
|
||||||
|
model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self
|
||||||
|
|
||||||
output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
|
output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
|
||||||
output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
|
output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
|
||||||
|
|
||||||
torch.save(model_to_save.state_dict(), output_model_file)
|
torch.save(model_to_save.state_dict(), output_model_file)
|
||||||
model_to_save.config.to_json_file(output_config_file)
|
model_to_save.config.to_json_file(output_config_file)
|
||||||
tokenizer.save_vocabulary(args.output_dir)
|
tokenizer.save_vocabulary(args.output_dir)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|||||||
Reference in New Issue
Block a user