load the pretrained weights for encoder-decoder

We currently save the pretrained_weights of the encoder and decoder in
two separate directories `encoder` and `decoder`. However, for the
`from_pretrained` function to operate with automodels we need to
specify the type of model in the path to the weights.

The path to the encoder/decoder weights is handled by the
`PreTrainedEncoderDecoder` class in the `save_pretrained` function. Sice
there is no easy way to infer the type of model that was initialized for
the encoder and decoder we add a parameter `model_type` to the function.
This is not an ideal solution as it is error prone, and the model type
should be carried by the Model classes somehow.

This is a temporary fix that should be changed before merging.
This commit is contained in:
Rémi Louf
2019-10-31 10:16:08 +01:00
committed by Julien Chaumond
parent 07f4cd73f6
commit 1c71ecc880
2 changed files with 49 additions and 30 deletions

View File

@@ -328,6 +328,22 @@ def evaluate(args, model, tokenizer, prefix=""):
return result
def save_model_checkpoints(args, model, tokenizer):
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
logger.info("Saving model checkpoint to %s", args.output_dir)
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()`
model_to_save = (
model.module if hasattr(model, "module") else model
) # Take care of distributed/parallel training
model_to_save.save_pretrained(args.output_dir, model_type='bert')
tokenizer.save_pretrained(args.output_dir)
torch.save(args, os.path.join(args.output_dir, "training_arguments.bin"))
def main():
parser = argparse.ArgumentParser()
@@ -454,36 +470,30 @@ def main():
# Train the model
model.to(args.device)
if args.do_train:
global_step, tr_loss = train(args, model, tokenizer)
try:
global_step, tr_loss = train(args, model, tokenizer)
except KeyboardInterrupt:
response = input("You interrupted the training. Do you want to save the model checkpoints? [Y/n]")
if response.lower() in ["", "y", "yes"]:
save_model_checkpoints(args, model, tokenizer)
sys.exit(0)
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
logger.info("Saving model checkpoint to %s", args.output_dir)
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()`
model_to_save = (
model.module if hasattr(model, "module") else model
) # Take care of distributed/parallel training
model_to_save.save_pretrained(args.output_dir)
tokenizer.save_pretrained(args.output_dir)
torch.save(args, os.path.join(args.output_dir, "training_arguments.bin"))
save_model_checkpoints(args, model, tokenizer)
# Evaluate the model
results = {}
if args.do_evaluate:
checkpoints = []
checkpoints = [args.output_dir]
logger.info("Evaluate the following checkpoints: %s", checkpoints)
for checkpoint in checkpoints:
encoder_checkpoint = os.path.join(checkpoint, "encoder")
decoder_checkpoint = os.path.join(checkpoint, "decoder")
encoder_checkpoint = os.path.join(checkpoint, "bert_encoder")
decoder_checkpoint = os.path.join(checkpoint, "bert_decoder")
model = PreTrainedEncoderDecoder.from_pretrained(
encoder_checkpoint, decoder_checkpoint
)
model.to(args.device)
results = "placeholder"
print("model loaded")
return results