refactored code a bit and made more generic
This commit is contained in:
@@ -442,7 +442,7 @@ class BartModelIntegrationTest(unittest.TestCase):
|
||||
tokens = tok.encode(text, return_tensors="pt").to(torch_device)
|
||||
extra_len = 20
|
||||
gen_tokens_1 = hf.generate_1(tokens, num_beams=4, max_length=extra_len,) # repetition_penalty=10.,
|
||||
gen_tokens = hf.generate(tokens, num_beams=4, max_length=extra_len, do_sample=False) # repetition_penalty=10.,
|
||||
gen_tokens = hf.generate(tokens, num_beams=4, max_length=extra_len + 2, do_sample=False) # repetition_penalty=10.,
|
||||
print("1: {}".format(gen_tokens_1))
|
||||
print("2: {}".format(gen_tokens))
|
||||
ipdb.set_trace()
|
||||
|
||||
Reference in New Issue
Block a user