[Longformer] Major Refactor (#5219)

* refactor naming

* add small slow test

* refactor

* refactor naming

* rename selected to extra

* big global attention refactor

* make style

* refactor naming

* save intermed

* refactor functions

* finish function refactor

* fix tests

* fix longformer

* fix longformer

* fix longformer

* fix all tests but one

* finish longformer

* address sams and izs comments

* fix transpose
This commit is contained in:
Patrick von Platen
2020-07-01 17:43:32 +02:00
committed by GitHub
parent e0d58ddb65
commit d697b6ca75
3 changed files with 697 additions and 293 deletions

View File

@@ -811,7 +811,7 @@ class ModelTesterMixin:
# Wrap model in nn.DataParallel
model = torch.nn.DataParallel(model)
with torch.no_grad():
_ = model(**inputs_dict)
_ = model(**self._prepare_for_class(inputs_dict, model_class))
global_rng = random.Random()