Fix Trainer tests in a multiGPU env (#7458)
This commit is contained in:
@@ -109,12 +109,15 @@ if is_torch_available():
|
||||
loss = torch.nn.functional.mse_loss(y, labels)
|
||||
return (loss, y, y) if self.double_output else (loss, y)
|
||||
|
||||
def get_regression_trainer(a=0, b=0, double_output=False, train_len=64, eval_len=64, **kwargs):
|
||||
def get_regression_trainer(a=0, b=0, double_output=False, train_len=64, eval_len=64, pretrained=True, **kwargs):
|
||||
label_names = kwargs.get("label_names", None)
|
||||
train_dataset = RegressionDataset(length=train_len, label_names=label_names)
|
||||
eval_dataset = RegressionDataset(length=eval_len, label_names=label_names)
|
||||
config = RegressionModelConfig(a=a, b=b, double_output=double_output)
|
||||
model = RegressionPreTrainedModel(config)
|
||||
if pretrained:
|
||||
config = RegressionModelConfig(a=a, b=b, double_output=double_output)
|
||||
model = RegressionPreTrainedModel(config)
|
||||
else:
|
||||
model = RegressionModel(a=a, b=b, double_output=double_output)
|
||||
compute_metrics = kwargs.pop("compute_metrics", None)
|
||||
data_collator = kwargs.pop("data_collator", None)
|
||||
optimizers = kwargs.pop("optimizers", (None, None))
|
||||
@@ -178,6 +181,7 @@ class TrainerIntegrationTest(unittest.TestCase):
|
||||
best_model = RegressionModel()
|
||||
state_dict = torch.load(os.path.join(checkpoint, WEIGHTS_NAME))
|
||||
best_model.load_state_dict(state_dict)
|
||||
best_model.to(trainer.args.device)
|
||||
self.assertTrue(torch.allclose(best_model.a, trainer.model.a))
|
||||
self.assertTrue(torch.allclose(best_model.b, trainer.model.b))
|
||||
|
||||
@@ -360,8 +364,7 @@ class TrainerIntegrationTest(unittest.TestCase):
|
||||
|
||||
# With a regular model that is not a PreTrainedModel
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
trainer = get_regression_trainer(output_dir=tmpdir, save_steps=5)
|
||||
trainer.model = RegressionModel()
|
||||
trainer = get_regression_trainer(output_dir=tmpdir, save_steps=5, pretrained=False)
|
||||
trainer.train()
|
||||
self.check_saved_checkpoints(tmpdir, 5, int(self.n_epochs * 64 / self.batch_size), False)
|
||||
|
||||
@@ -426,8 +429,8 @@ class TrainerIntegrationTest(unittest.TestCase):
|
||||
eval_steps=5,
|
||||
evaluation_strategy="steps",
|
||||
load_best_model_at_end=True,
|
||||
pretrained=False,
|
||||
)
|
||||
trainer.model = RegressionModel(a=1.5, b=2.5)
|
||||
self.assertFalse(trainer.args.greater_is_better)
|
||||
trainer.train()
|
||||
self.check_saved_checkpoints(tmpdir, 5, total, is_pretrained=False)
|
||||
|
||||
Reference in New Issue
Block a user