removed redundant arg in prepare_inputs (#6614)
* removed redundant arg in prepare_inputs * made same change in prediction_loop
This commit is contained in:
committed by
GitHub
parent
cabfdfafc0
commit
33bf426498
@@ -705,7 +705,7 @@ class Trainer:
|
|||||||
print(output)
|
print(output)
|
||||||
|
|
||||||
def _prepare_inputs(
|
def _prepare_inputs(
|
||||||
self, inputs: Dict[str, Union[torch.Tensor, Any]], model: nn.Module
|
self, inputs: Dict[str, Union[torch.Tensor, Any]]
|
||||||
) -> Dict[str, Union[torch.Tensor, Any]]:
|
) -> Dict[str, Union[torch.Tensor, Any]]:
|
||||||
"""
|
"""
|
||||||
Prepare :obj:`inputs` before feeding them to the model, converting them to tensors if they are not already and
|
Prepare :obj:`inputs` before feeding them to the model, converting them to tensors if they are not already and
|
||||||
@@ -746,7 +746,7 @@ class Trainer:
|
|||||||
return self._training_step(model, inputs, self.optimizer)
|
return self._training_step(model, inputs, self.optimizer)
|
||||||
|
|
||||||
model.train()
|
model.train()
|
||||||
inputs = self._prepare_inputs(inputs, model)
|
inputs = self._prepare_inputs(inputs)
|
||||||
|
|
||||||
if self.args.fp16 and _use_native_amp:
|
if self.args.fp16 and _use_native_amp:
|
||||||
with autocast():
|
with autocast():
|
||||||
@@ -1071,7 +1071,7 @@ class Trainer:
|
|||||||
"""
|
"""
|
||||||
has_labels = any(inputs.get(k) is not None for k in ["labels", "lm_labels", "masked_lm_labels"])
|
has_labels = any(inputs.get(k) is not None for k in ["labels", "lm_labels", "masked_lm_labels"])
|
||||||
|
|
||||||
inputs = self._prepare_inputs(inputs, model)
|
inputs = self._prepare_inputs(inputs)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
outputs = model(**inputs)
|
outputs = model(**inputs)
|
||||||
|
|||||||
Reference in New Issue
Block a user