@@ -496,7 +496,7 @@ class PersimmonIntegrationTest(unittest.TestCase):
|
||||
model = PersimmonForCausalLM.from_pretrained(
|
||||
"adept/persimmon-8b-chat", load_in_8bit=True, device_map={"": 0}, torch_dtype=torch.float16
|
||||
)
|
||||
out = model(torch.tensor([input_ids], device=torch_device)).logits
|
||||
out = model(torch.tensor([input_ids], device=torch_device)).logits.float()
|
||||
|
||||
EXPECTED_MEAN = torch.tensor(
|
||||
[[-11.4726, -11.1495, -11.2694, -11.2223, -10.9452, -11.0663, -11.0031, -11.1028]]
|
||||
|
||||
Reference in New Issue
Block a user