diff --git a/examples/pytorch/image-pretraining/run_mae.py b/examples/pytorch/image-pretraining/run_mae.py index 80d55c3025..f3448a7753 100644 --- a/examples/pytorch/image-pretraining/run_mae.py +++ b/examples/pytorch/image-pretraining/run_mae.py @@ -384,5 +384,10 @@ def main(): trainer.create_model_card(**kwargs) +def _mp_fn(index): + # For xla_spawn (TPUs) + main() + + if __name__ == "__main__": main()