refactored code a bit and made more generic

This commit is contained in:
Patrick von Platen
2020-03-05 14:56:56 +01:00
parent d8e2b3c547
commit c0d9dd3ba9
4 changed files with 28 additions and 16 deletions

View File

@@ -621,7 +621,7 @@ class ModelTesterMixin:
with torch.no_grad():
model(**inputs_dict)
def _A_test_lm_head_model_random_generate(self):
def test_lm_head_model_random_generate(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
input_ids = inputs_dict.get(