refactored code a bit and made more generic
This commit is contained in:
@@ -442,7 +442,7 @@ class BartModelIntegrationTest(unittest.TestCase):
|
||||
tokens = tok.encode(text, return_tensors="pt").to(torch_device)
|
||||
extra_len = 20
|
||||
gen_tokens_1 = hf.generate_1(tokens, num_beams=4, max_length=extra_len,) # repetition_penalty=10.,
|
||||
gen_tokens = hf.generate(tokens, num_beams=4, max_length=extra_len, do_sample=False) # repetition_penalty=10.,
|
||||
gen_tokens = hf.generate(tokens, num_beams=4, max_length=extra_len + 2, do_sample=False) # repetition_penalty=10.,
|
||||
print("1: {}".format(gen_tokens_1))
|
||||
print("2: {}".format(gen_tokens))
|
||||
ipdb.set_trace()
|
||||
|
||||
@@ -621,7 +621,7 @@ class ModelTesterMixin:
|
||||
with torch.no_grad():
|
||||
model(**inputs_dict)
|
||||
|
||||
def _A_test_lm_head_model_random_generate(self):
|
||||
def test_lm_head_model_random_generate(self):
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
input_ids = inputs_dict.get(
|
||||
|
||||
Reference in New Issue
Block a user