From c6bf1a400df220ddbe6f74ffd6456d0728d51e4f Mon Sep 17 00:00:00 2001 From: thomwolf Date: Thu, 11 Jul 2019 22:29:08 +0200 Subject: [PATCH] fix test examples et model pretrained --- examples/test_examples.py | 3 ++- pytorch_transformers/modeling_utils.py | 2 +- pytorch_transformers/tests/modeling_utils_test.py | 1 + 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/test_examples.py b/examples/test_examples.py index dec59358b8..2e6ed45063 100644 --- a/examples/test_examples.py +++ b/examples/test_examples.py @@ -56,7 +56,8 @@ class ExamplesTests(unittest.TestCase): "--learning_rate=1e-4", "--max_steps=10", "--warmup_steps=2", - "--overwrite_output_dir"] + "--overwrite_output_dir", + "--seed=42"] model_name = "--model_name=bert-base-uncased" with patch.object(sys, 'argv', testargs + [model_name]): result = run_glue.main() diff --git a/pytorch_transformers/modeling_utils.py b/pytorch_transformers/modeling_utils.py index c304e7fdf0..a9445ecad5 100644 --- a/pytorch_transformers/modeling_utils.py +++ b/pytorch_transformers/modeling_utils.py @@ -49,7 +49,7 @@ class PretrainedConfig(object): self.torchscript = kwargs.pop('torchscript', False) @classmethod - def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): + def from_pretrained(cls, pretrained_model_name_or_path, *input, **kwargs): """ Instantiate a PretrainedConfig from a pre-trained model configuration. diff --git a/pytorch_transformers/tests/modeling_utils_test.py b/pytorch_transformers/tests/modeling_utils_test.py index a168c24611..4944f41228 100644 --- a/pytorch_transformers/tests/modeling_utils_test.py +++ b/pytorch_transformers/tests/modeling_utils_test.py @@ -30,6 +30,7 @@ class ModelUtilsTest(unittest.TestCase): self.assertIsNotNone(config) self.assertIsInstance(config, PretrainedConfig) + model = BertModel.from_pretrained(model_name) model, loading_info = BertModel.from_pretrained(model_name, output_loading_info=True) self.assertIsNotNone(model) self.assertIsInstance(model, PreTrainedModel)