refactored code a bit and made more generic
This commit is contained in:
@@ -621,7 +621,7 @@ class ModelTesterMixin:
|
||||
with torch.no_grad():
|
||||
model(**inputs_dict)
|
||||
|
||||
def _A_test_lm_head_model_random_generate(self):
|
||||
def test_lm_head_model_random_generate(self):
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
input_ids = inputs_dict.get(
|
||||
|
||||
Reference in New Issue
Block a user