diff --git a/README.md b/README.md index afb351cdd8..4807edc8b9 100644 --- a/README.md +++ b/README.md @@ -102,10 +102,6 @@ Let's see how we can use `BertModel` to encode our inputs in hidden-states: # Load pre-trained model (weights) model = BertModel.from_pretrained('bert-base-uncased') -# Set the model in evaluation mode to desactivate the DropOut modules -# This is IMPORTANT to have reproductible results during evaluation! -model.eval() - # If you have a GPU, put everything on cuda tokens_tensor = tokens_tensor.to('cuda') segments_tensors = segments_tensors.to('cuda') @@ -129,7 +125,6 @@ And how to use `BertForMaskedLM` to predict a masked token: ```python # Load pre-trained model (weights) model = BertForMaskedLM.from_pretrained('bert-base-uncased') -model.eval() # If you have a GPU, put everything on cuda tokens_tensor = tokens_tensor.to('cuda') @@ -178,10 +173,6 @@ Let's see how to use `GPT2LMHeadModel` to generate the next token following our # Load pre-trained model (weights) model = GPT2LMHeadModel.from_pretrained('gpt2') -# Set the model in evaluation mode to desactivate the DropOut modules -# This is IMPORTANT to have reproductible results during evaluation! -model.eval() - # If you have a GPU, put everything on cuda tokens_tensor = tokens_tensor.to('cuda') model.to('cuda') diff --git a/pytorch_transformers/modeling_utils.py b/pytorch_transformers/modeling_utils.py index 542c70b223..324cdc17c9 100644 --- a/pytorch_transformers/modeling_utils.py +++ b/pytorch_transformers/modeling_utils.py @@ -306,7 +306,10 @@ class PreTrainedModel(nn.Module): @classmethod def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs): - r""" Instantiate a PretrainedConfig from a pre-trained model configuration. + r"""Instantiate a pretrained pytorch model from a pre-trained model configuration. + + The model is set in evaluation mode by default using `model.eval()` (Dropout modules are desactivated) + To train the model, you should first set it back in training mode with `model.train()` Params: **pretrained_model_name_or_path**: either: @@ -460,6 +463,9 @@ class PreTrainedModel(nn.Module): if hasattr(model, 'tie_weights'): model.tie_weights() # make sure word embedding weights are still tied + # Set model in evaluation mode to desactivate DropOut modules by default + model.eval() + if output_loading_info: loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys, "error_msgs": error_msgs} return model, loading_info