@@ -115,6 +115,12 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
with io.open(self.ds_config_file[ZERO3], "r", encoding="utf-8") as f:
|
with io.open(self.ds_config_file[ZERO3], "r", encoding="utf-8") as f:
|
||||||
self.ds_config_dict[ZERO3] = json.load(f)
|
self.ds_config_dict[ZERO3] = json.load(f)
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
# XXX: Fixme - this is a temporary band-aid since this global variable impacts other tests
|
||||||
|
import transformers
|
||||||
|
|
||||||
|
transformers.integrations._is_deepspeed_zero3_enabled = None
|
||||||
|
|
||||||
def get_config_dict(self, stage):
|
def get_config_dict(self, stage):
|
||||||
""" As the tests modify the dict, always make a copy """
|
""" As the tests modify the dict, always make a copy """
|
||||||
config = deepcopy(self.ds_config_dict[stage])
|
config = deepcopy(self.ds_config_dict[stage])
|
||||||
|
|||||||
Reference in New Issue
Block a user