[deepspeed] fix load_best_model test (#17550)
This commit is contained in:
@@ -752,6 +752,8 @@ class TrainerIntegrationDeepSpeed(TrainerIntegrationDeepSpeedWithCustomConfig, T
|
|||||||
# must use this setting to get the reload path exercised
|
# must use this setting to get the reload path exercised
|
||||||
ds_config_dict["zero_optimization"]["stage3_gather_16bit_weights_on_model_save"] = True
|
ds_config_dict["zero_optimization"]["stage3_gather_16bit_weights_on_model_save"] = True
|
||||||
|
|
||||||
|
with mockenv_context(**self.dist_env_1_gpu):
|
||||||
|
|
||||||
tokenizer = T5Tokenizer.from_pretrained(T5_TINY)
|
tokenizer = T5Tokenizer.from_pretrained(T5_TINY)
|
||||||
model = T5ForConditionalGeneration.from_pretrained(T5_TINY)
|
model = T5ForConditionalGeneration.from_pretrained(T5_TINY)
|
||||||
|
|
||||||
@@ -804,8 +806,6 @@ class TrainerIntegrationDeepSpeed(TrainerIntegrationDeepSpeedWithCustomConfig, T
|
|||||||
"deepspeed": ds_config_dict,
|
"deepspeed": ds_config_dict,
|
||||||
}
|
}
|
||||||
|
|
||||||
with mockenv_context(**self.dist_env_1_gpu):
|
|
||||||
|
|
||||||
training_args = TrainingArguments(output_dir, **args_dict)
|
training_args = TrainingArguments(output_dir, **args_dict)
|
||||||
|
|
||||||
trainer = Trainer(
|
trainer = Trainer(
|
||||||
|
|||||||
Reference in New Issue
Block a user