Better filtering of the model outputs in Trainer (#8633)
* Better filtering of the model outputs in Trainer * Fix examples tests * Add test for Lysandre
This commit is contained in:
@@ -44,6 +44,8 @@ if is_torch_available():
|
||||
DataCollatorForLanguageModeling,
|
||||
GlueDataset,
|
||||
GlueDataTrainingArguments,
|
||||
GPT2Config,
|
||||
GPT2LMHeadModel,
|
||||
LineByLineTextDataset,
|
||||
PreTrainedModel,
|
||||
TextDataset,
|
||||
@@ -73,6 +75,18 @@ class RegressionDataset:
|
||||
return result
|
||||
|
||||
|
||||
class RepeatDataset:
    """Map-style dataset that returns the same sample for every valid index.

    Each item is a dict whose ``"input_ids"`` and ``"labels"`` entries both
    refer to the ``x`` object given at construction time.

    Args:
        x: The sample to repeat.
        length: Number of items the dataset reports (defaults to 64).
    """

    def __init__(self, x, length=64):
        self.x = x
        self.length = length

    def __len__(self):
        """Return the number of items the dataset reports."""
        return self.length

    def __getitem__(self, i):
        """Return the repeated sample for index ``i``.

        Raises:
            IndexError: If ``i`` lies outside ``[-length, length)``.
        """
        # There is no __iter__, so ``for item in dataset`` uses the legacy
        # sequence protocol, which only stops on IndexError. Without this
        # bound check, plain iteration over the dataset would never end.
        if not -self.length <= i < self.length:
            raise IndexError(f"index {i} is out of range for dataset of length {self.length}")
        return {"input_ids": self.x, "labels": self.x}
|
||||
|
||||
|
||||
class DynamicShapesDataset:
|
||||
def __init__(self, length=64, seed=42, batch_size=8):
|
||||
self.length = length
|
||||
@@ -136,6 +150,20 @@ if is_torch_available():
|
||||
loss = torch.nn.functional.mse_loss(y, labels)
|
||||
return (loss, y, y) if self.double_output else (loss, y)
|
||||
|
||||
class RegressionDictModel(torch.nn.Module):
    """Affine regression module whose forward pass returns a dict.

    The prediction is ``input_x * a + b``, with ``a`` and ``b`` stored as
    scalar float parameters initialised from the constructor arguments.
    """

    def __init__(self, a=0, b=0):
        super().__init__()
        slope = torch.tensor(a).float()
        intercept = torch.tensor(b).float()
        self.a = torch.nn.Parameter(slope)
        self.b = torch.nn.Parameter(intercept)
        # No configuration object is attached to this model.
        self.config = None

    def forward(self, input_x=None, labels=None, **kwargs):
        """Compute the prediction and, when ``labels`` is given, the MSE loss.

        Returns:
            A dict with key ``"output"`` and, only if ``labels`` is not
            ``None``, an additional ``"loss"`` key (inserted after
            ``"output"``).
        """
        prediction = input_x * self.a + self.b
        if labels is None:
            return {"output": prediction}
        mse = torch.nn.functional.mse_loss(prediction, labels)
        return {"output": prediction, "loss": mse}
|
||||
|
||||
class RegressionPreTrainedModel(PreTrainedModel):
|
||||
config_class = RegressionModelConfig
|
||||
base_model_prefix = "regression"
|
||||
@@ -236,6 +264,33 @@ class TrainerIntegrationTest(unittest.TestCase):
|
||||
metrics = trainer.evaluate()
|
||||
self.assertEqual(metrics[metric], best_value)
|
||||
|
||||
def test_trainer_works_with_dict(self):
    """Smoke-test train/evaluate/predict with a model that returns a dict."""
    # Edge case: Apex with mode O2 will change our models to return dicts.
    # The test only checks that none of the Trainer entry points raise.
    train_data = RegressionDataset()
    eval_data = RegressionDataset()
    dict_model = RegressionDictModel()
    training_args = TrainingArguments("./regression")
    trainer = Trainer(dict_model, training_args, train_dataset=train_data, eval_dataset=eval_data)
    trainer.train()
    # Return values are deliberately discarded; only successful completion matters.
    trainer.evaluate()
    trainer.predict(eval_data)
|
||||
|
||||
def test_evaluation_with_keys_to_drop(self):
    """Check that ``predict`` filters model outputs unless ``ignore_keys=[]`` is passed."""
    config = GPT2Config(vocab_size=100, n_positions=128, n_ctx=128, n_embd=32, n_layer=3, n_head=4)
    tiny_gpt2 = GPT2LMHeadModel(config)
    # One random token sequence, repeated for every item of the dataset.
    x = torch.randint(0, 100, (128,))
    eval_dataset = RepeatDataset(x)
    args = TrainingArguments("./test")
    trainer = Trainer(tiny_gpt2, args, eval_dataset=eval_dataset)
    # By default the past_key_values are removed
    result = trainer.predict(eval_dataset)
    # assertIsInstance reports the actual type on failure, unlike assertTrue(isinstance(...)).
    self.assertIsInstance(result.predictions, np.ndarray)
    # We can still get them by setting ignore_keys to []
    result = trainer.predict(eval_dataset, ignore_keys=[])
    self.assertIsInstance(result.predictions, tuple)
    self.assertEqual(len(result.predictions), 2)
|
||||
|
||||
def test_training_arguments_are_left_untouched(self):
|
||||
trainer = get_regression_trainer()
|
||||
trainer.train()
|
||||
|
||||
Reference in New Issue
Block a user