[DeepSpeed] ZeRO-Infinity integration plus config revamp (#11418)

* adding ZeRO-Infinity (Z-inf)

* revamp config process

* bump version requirement

* wip

* massive rewrite

* cleanup

* cleanup

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* consistent json commas

* act on suggestions

* leave this feature for 0.3.16

* style

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Stas Bekman
2021-04-26 10:40:32 -07:00
committed by GitHub
parent 0661abc545
commit bc2571e61c
10 changed files with 896 additions and 503 deletions

View File

@@ -213,16 +213,21 @@ if is_torch_available():
label_names = kwargs.get("label_names", None)
train_dataset = RegressionDataset(length=train_len, label_names=label_names)
eval_dataset = RegressionDataset(length=eval_len, label_names=label_names)
if pretrained:
config = RegressionModelConfig(a=a, b=b, double_output=double_output)
model = RegressionPreTrainedModel(config)
model_init = kwargs.pop("model_init", None)
if model_init is not None:
model = None
else:
model = RegressionModel(a=a, b=b, double_output=double_output)
if pretrained:
config = RegressionModelConfig(a=a, b=b, double_output=double_output)
model = RegressionPreTrainedModel(config)
else:
model = RegressionModel(a=a, b=b, double_output=double_output)
compute_metrics = kwargs.pop("compute_metrics", None)
data_collator = kwargs.pop("data_collator", None)
optimizers = kwargs.pop("optimizers", (None, None))
output_dir = kwargs.pop("output_dir", "./regression")
model_init = kwargs.pop("model_init", None)
args = RegressionTrainingArguments(output_dir, a=a, b=b, **kwargs)
return Trainer(