load the pretrained weights for encoder-decoder
We currently save the pretrained_weights of the encoder and decoder in two separate directories `encoder` and `decoder`. However, for the `from_pretrained` function to operate with automodels we need to specify the type of model in the path to the weights. The path to the encoder/decoder weights is handled by the `PreTrainedEncoderDecoder` class in the `save_pretrained` function. Sice there is no easy way to infer the type of model that was initialized for the encoder and decoder we add a parameter `model_type` to the function. This is not an ideal solution as it is error prone, and the model type should be carried by the Model classes somehow. This is a temporary fix that should be changed before merging.
This commit is contained in:
committed by
Julien Chaumond
parent
07f4cd73f6
commit
1c71ecc880
@@ -328,6 +328,22 @@ def evaluate(args, model, tokenizer, prefix=""):
|
||||
return result
|
||||
|
||||
|
||||
def save_model_checkpoints(args, model, tokenizer):
|
||||
if not os.path.exists(args.output_dir):
|
||||
os.makedirs(args.output_dir)
|
||||
|
||||
logger.info("Saving model checkpoint to %s", args.output_dir)
|
||||
|
||||
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
|
||||
# They can then be reloaded using `from_pretrained()`
|
||||
model_to_save = (
|
||||
model.module if hasattr(model, "module") else model
|
||||
) # Take care of distributed/parallel training
|
||||
model_to_save.save_pretrained(args.output_dir, model_type='bert')
|
||||
tokenizer.save_pretrained(args.output_dir)
|
||||
torch.save(args, os.path.join(args.output_dir, "training_arguments.bin"))
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
@@ -454,36 +470,30 @@ def main():
|
||||
# Train the model
|
||||
model.to(args.device)
|
||||
if args.do_train:
|
||||
global_step, tr_loss = train(args, model, tokenizer)
|
||||
try:
|
||||
global_step, tr_loss = train(args, model, tokenizer)
|
||||
except KeyboardInterrupt:
|
||||
response = input("You interrupted the training. Do you want to save the model checkpoints? [Y/n]")
|
||||
if response.lower() in ["", "y", "yes"]:
|
||||
save_model_checkpoints(args, model, tokenizer)
|
||||
sys.exit(0)
|
||||
|
||||
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
|
||||
|
||||
if not os.path.exists(args.output_dir):
|
||||
os.makedirs(args.output_dir)
|
||||
|
||||
logger.info("Saving model checkpoint to %s", args.output_dir)
|
||||
|
||||
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
|
||||
# They can then be reloaded using `from_pretrained()`
|
||||
model_to_save = (
|
||||
model.module if hasattr(model, "module") else model
|
||||
) # Take care of distributed/parallel training
|
||||
model_to_save.save_pretrained(args.output_dir)
|
||||
tokenizer.save_pretrained(args.output_dir)
|
||||
torch.save(args, os.path.join(args.output_dir, "training_arguments.bin"))
|
||||
save_model_checkpoints(args, model, tokenizer)
|
||||
|
||||
# Evaluate the model
|
||||
results = {}
|
||||
if args.do_evaluate:
|
||||
checkpoints = []
|
||||
checkpoints = [args.output_dir]
|
||||
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
||||
for checkpoint in checkpoints:
|
||||
encoder_checkpoint = os.path.join(checkpoint, "encoder")
|
||||
decoder_checkpoint = os.path.join(checkpoint, "decoder")
|
||||
encoder_checkpoint = os.path.join(checkpoint, "bert_encoder")
|
||||
decoder_checkpoint = os.path.join(checkpoint, "bert_decoder")
|
||||
model = PreTrainedEncoderDecoder.from_pretrained(
|
||||
encoder_checkpoint, decoder_checkpoint
|
||||
)
|
||||
model.to(args.device)
|
||||
results = "placeholder"
|
||||
print("model loaded")
|
||||
|
||||
return results
|
||||
|
||||
|
||||
Reference in New Issue
Block a user