[Longformer] Major Refactor (#5219)
* refactor naming * add small slow test * refactor * refactor naming * rename selected to extra * big global attention refactor * make style * refactor naming * save intermed * refactor functions * finish function refactor * fix tests * fix longformer * fix longformer * fix longformer * fix all tests but one * finish longformer * address sams and izs comments * fix transpose
This commit is contained in:
committed by
GitHub
parent
e0d58ddb65
commit
d697b6ca75
@@ -811,7 +811,7 @@ class ModelTesterMixin:
|
||||
# Wrap model in nn.DataParallel
|
||||
model = torch.nn.DataParallel(model)
|
||||
with torch.no_grad():
|
||||
_ = model(**inputs_dict)
|
||||
_ = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
|
||||
global_rng = random.Random()
|
||||
|
||||
Reference in New Issue
Block a user