From c88b11c591f2ad5371cad4ca2da7d6ee16bcf0e8 Mon Sep 17 00:00:00 2001 From: steventk-g <107513673+steventk-g@users.noreply.github.com> Date: Fri, 10 Feb 2023 06:53:55 -0800 Subject: [PATCH] Add _mp_fn to run_mae.py for XLA testing (#21551) Update run_mae.py --- examples/pytorch/image-pretraining/run_mae.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/examples/pytorch/image-pretraining/run_mae.py b/examples/pytorch/image-pretraining/run_mae.py index 80d55c3025..f3448a7753 100644 --- a/examples/pytorch/image-pretraining/run_mae.py +++ b/examples/pytorch/image-pretraining/run_mae.py @@ -384,5 +384,10 @@ def main(): trainer.create_model_card(**kwargs) +def _mp_fn(index): + # For xla_spawn (TPUs) + main() + + if __name__ == "__main__": main()