[style] consistent nn. and nn.functional: part 3 tests (#12155)

* consistent nn. and nn.functional: p3 templates

* restore
This commit is contained in:
Stas Bekman
2021-06-14 12:18:22 -07:00
committed by GitHub
parent d9c0d08f9a
commit 372ab9cd6d
14 changed files with 93 additions and 81 deletions

View File

@@ -44,6 +44,7 @@ from transformers.testing_utils import (
if is_torch_available():
import numpy as np
import torch
from torch import nn
from transformers import (
BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
@@ -1150,10 +1151,10 @@ class ModelTesterMixin:
for model_class in self.all_model_classes:
model = model_class(config)
self.assertIsInstance(model.get_input_embeddings(), (torch.nn.Embedding, AdaptiveEmbedding))
model.set_input_embeddings(torch.nn.Embedding(10, 10))
self.assertIsInstance(model.get_input_embeddings(), (nn.Embedding, AdaptiveEmbedding))
model.set_input_embeddings(nn.Embedding(10, 10))
x = model.get_output_embeddings()
self.assertTrue(x is None or isinstance(x, torch.nn.Linear))
self.assertTrue(x is None or isinstance(x, nn.Linear))
def test_correct_missing_keys(self):
if not self.test_missing_keys:
@@ -1337,7 +1338,7 @@ class ModelTesterMixin:
model.eval()
# Wrap model in nn.DataParallel
model = torch.nn.DataParallel(model)
model = nn.DataParallel(model)
with torch.no_grad():
_ = model(**self._prepare_for_class(inputs_dict, model_class))