Funnel transformer (#6908)

* Initial model

* Fix upsampling

* Add special cls token id and test

* Formatting

* Test and fist FunnelTokenizerFast

* Common tests

* Fix the check_repo script and document Funnel

* Doc fixes

* Add all models

* Write doc

* Fix test

* Initial model

* Fix upsampling

* Add special cls token id and test

* Formatting

* Test and fist FunnelTokenizerFast

* Common tests

* Fix the check_repo script and document Funnel

* Doc fixes

* Add all models

* Write doc

* Fix test

* Fix copyright

* Forgot some layers can be repeated

* Apply suggestions from code review

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/transformers/modeling_funnel.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Address review comments

* Update src/transformers/modeling_funnel.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Address review comments

* Update src/transformers/modeling_funnel.py

Co-authored-by: Sam Shleifer <sshleifer@gmail.com>

* Slow integration test

* Make small integration test

* Formatting

* Add checkpoint and separate classification head

* Formatting

* Expand list, fix link and add in pretrained models

* Styling

* Add the model in all summaries

* Typo fixes

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
This commit is contained in:
Sylvain Gugger
2020-09-08 08:08:08 -04:00
committed by GitHub
parent 25afb4ea50
commit d155b38d6e
18 changed files with 3208 additions and 405 deletions

View File

@@ -141,18 +141,20 @@ def get_model_doc_files():
# for the all_model_classes variable.
def find_tested_models(test_file):
""" Parse the content of test_file to detect what's in all_model_classes"""
# This is a bit hacky but I didn't find a way to import the test_file as a module and read inside the class
with open(os.path.join(PATH_TO_TESTS, test_file)) as f:
content = f.read()
all_models = re.search(r"all_model_classes\s+=\s+\(\s*\(([^\)]*)\)", content)
all_models = re.findall(r"all_model_classes\s+=\s+\(\s*\(([^\)]*)\)", content)
# Check with one less parenthesis
if all_models is None:
all_models = re.search(r"all_model_classes\s+=\s+\(([^\)]*)\)", content)
if all_models is not None:
if len(all_models) == 0:
all_models = re.findall(r"all_model_classes\s+=\s+\(([^\)]*)\)", content)
if len(all_models) > 0:
model_tested = []
for line in all_models.groups()[0].split(","):
name = line.strip()
if len(name) > 0:
model_tested.append(name)
for entry in all_models:
for line in entry.split(","):
name = line.strip()
if len(name) > 0:
model_tested.append(name)
return model_tested