Add _mp_fn to run_mae.py for XLA testing (#21551)

Update run_mae.py
This commit is contained in:
steventk-g
2023-02-10 06:53:55 -08:00
committed by GitHub
parent b20147a3c8
commit c88b11c591

View File

@@ -384,5 +384,10 @@ def main():
trainer.create_model_card(**kwargs)
def _mp_fn(index):
# For xla_spawn (TPUs)
main()
if __name__ == "__main__":
main()