[DPR] Correct init (#13796)
* update * add to docs and init * make fix-copies
This commit is contained in:
committed by
GitHub
parent
44eb8bdeea
commit
41436d3dfb
@@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from transformers import DPRConfig, is_torch_available
|
||||
@@ -213,6 +214,19 @@ class DPRModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_reader(*config_and_inputs)
|
||||
|
||||
def test_init_changed_config(self):
|
||||
config = self.model_tester.prepare_config_and_inputs()[0]
|
||||
|
||||
model = DPRQuestionEncoder(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dirname:
|
||||
model.save_pretrained(tmp_dirname)
|
||||
model = DPRQuestionEncoder.from_pretrained(tmp_dirname, projection_dim=512)
|
||||
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_name in DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||
|
||||
Reference in New Issue
Block a user