From 33bf42649884e11873322b3e01446bec67773a2f Mon Sep 17 00:00:00 2001 From: Prajjwal Bhargava Date: Thu, 20 Aug 2020 17:53:35 +0530 Subject: [PATCH] removed redundant arg in prepare_inputs (#6614) * removed redundant arg in prepare_inputs * made same change in prediction_loop --- src/transformers/trainer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index baaf77ade3..bcba5c89c2 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -705,7 +705,7 @@ class Trainer: print(output) def _prepare_inputs( - self, inputs: Dict[str, Union[torch.Tensor, Any]], model: nn.Module + self, inputs: Dict[str, Union[torch.Tensor, Any]] ) -> Dict[str, Union[torch.Tensor, Any]]: """ Prepare :obj:`inputs` before feeding them to the model, converting them to tensors if they are not already and @@ -746,7 +746,7 @@ class Trainer: return self._training_step(model, inputs, self.optimizer) model.train() - inputs = self._prepare_inputs(inputs, model) + inputs = self._prepare_inputs(inputs) if self.args.fp16 and _use_native_amp: with autocast(): @@ -1071,7 +1071,7 @@ class Trainer: """ has_labels = any(inputs.get(k) is not None for k in ["labels", "lm_labels", "masked_lm_labels"]) - inputs = self._prepare_inputs(inputs, model) + inputs = self._prepare_inputs(inputs) with torch.no_grad(): outputs = model(**inputs)