@@ -4484,6 +4484,7 @@ class ModelTesterMixin:
|
||||
),
|
||||
"position_ids": torch.arange(0, 10, device=torch_device).unsqueeze(0),
|
||||
"labels": torch.randint(low=1, high=model.config.vocab_size, size=(2, 10), device=torch_device),
|
||||
"use_cache": False,
|
||||
}
|
||||
|
||||
# eager backward
|
||||
|
||||
Reference in New Issue
Block a user