diff --git a/tests/sagemaker/test_multi_node_model_parallel.py b/tests/sagemaker/test_multi_node_model_parallel.py index bca402bcba..3135573653 100644 --- a/tests/sagemaker/test_multi_node_model_parallel.py +++ b/tests/sagemaker/test_multi_node_model_parallel.py @@ -1,4 +1,5 @@ import os +import subprocess import unittest from ast import literal_eval @@ -28,10 +29,23 @@ if is_sagemaker_available(): "instance_type": "ml.p3dn.24xlarge", "results": {"train_runtime": 700, "eval_accuracy": 0.3, "eval_loss": 1.2}, }, + { + "framework": "pytorch", + "script": "run_glue.py", + "model_name_or_path": "roberta-large", + "instance_type": "ml.p3dn.24xlarge", + "results": {"train_runtime": 700, "eval_accuracy": 0.3, "eval_loss": 1.2}, + }, ] ) class MultiNodeTest(unittest.TestCase): def setUp(self): + if self.framework == "pytorch": + subprocess.run( + f"cp ./examples/text-classification/run_glue.py {self.env.test_path}/run_glue.py".split(), + encoding="utf-8", + check=True, + ) assert hasattr(self, "env") def create_estimator(self, instance_count):