[Deepspeed] ZeRO-Infinity integration plus config revamp (#11418)
* adding Z-inf
* revamp config process
* up version requirement
* wip
* massive rewrite
* cleanup
* cleanup
* Apply suggestions from code review

  Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* consistent json commas
* act on suggestions
* leave this feature for 0.3.16
* style

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -213,16 +213,21 @@ if is_torch_available():
|
||||
label_names = kwargs.get("label_names", None)
|
||||
train_dataset = RegressionDataset(length=train_len, label_names=label_names)
|
||||
eval_dataset = RegressionDataset(length=eval_len, label_names=label_names)
|
||||
if pretrained:
|
||||
config = RegressionModelConfig(a=a, b=b, double_output=double_output)
|
||||
model = RegressionPreTrainedModel(config)
|
||||
|
||||
model_init = kwargs.pop("model_init", None)
|
||||
if model_init is not None:
|
||||
model = None
|
||||
else:
|
||||
model = RegressionModel(a=a, b=b, double_output=double_output)
|
||||
if pretrained:
|
||||
config = RegressionModelConfig(a=a, b=b, double_output=double_output)
|
||||
model = RegressionPreTrainedModel(config)
|
||||
else:
|
||||
model = RegressionModel(a=a, b=b, double_output=double_output)
|
||||
|
||||
compute_metrics = kwargs.pop("compute_metrics", None)
|
||||
data_collator = kwargs.pop("data_collator", None)
|
||||
optimizers = kwargs.pop("optimizers", (None, None))
|
||||
output_dir = kwargs.pop("output_dir", "./regression")
|
||||
model_init = kwargs.pop("model_init", None)
|
||||
|
||||
args = RegressionTrainingArguments(output_dir, a=a, b=b, **kwargs)
|
||||
return Trainer(
|
||||
|
||||
Reference in New Issue
Block a user