Funnel transformer (#6908)
* Initial model * Fix upsampling * Add special cls token id and test * Formatting * Test and fist FunnelTokenizerFast * Common tests * Fix the check_repo script and document Funnel * Doc fixes * Add all models * Write doc * Fix test * Initial model * Fix upsampling * Add special cls token id and test * Formatting * Test and fist FunnelTokenizerFast * Common tests * Fix the check_repo script and document Funnel * Doc fixes * Add all models * Write doc * Fix test * Fix copyright * Forgot some layers can be repeated * Apply suggestions from code review Co-authored-by: Lysandre Debut <lysandre@huggingface.co> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Update src/transformers/modeling_funnel.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Address review comments * Update src/transformers/modeling_funnel.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Address review comments * Update src/transformers/modeling_funnel.py Co-authored-by: Sam Shleifer <sshleifer@gmail.com> * Slow integration test * Make small integration test * Formatting * Add checkpoint and separate classification head * Formatting * Expand list, fix link and add in pretrained models * Styling * Add the model in all summaries * Typo fixes Co-authored-by: Lysandre Debut <lysandre@huggingface.co> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
This commit is contained in:
@@ -141,18 +141,20 @@ def get_model_doc_files():
|
||||
# for the all_model_classes variable.
|
||||
def find_tested_models(test_file):
|
||||
""" Parse the content of test_file to detect what's in all_model_classes"""
|
||||
# This is a bit hacky but I didn't find a way to import the test_file as a module and read inside the class
|
||||
with open(os.path.join(PATH_TO_TESTS, test_file)) as f:
|
||||
content = f.read()
|
||||
all_models = re.search(r"all_model_classes\s+=\s+\(\s*\(([^\)]*)\)", content)
|
||||
all_models = re.findall(r"all_model_classes\s+=\s+\(\s*\(([^\)]*)\)", content)
|
||||
# Check with one less parenthesis
|
||||
if all_models is None:
|
||||
all_models = re.search(r"all_model_classes\s+=\s+\(([^\)]*)\)", content)
|
||||
if all_models is not None:
|
||||
if len(all_models) == 0:
|
||||
all_models = re.findall(r"all_model_classes\s+=\s+\(([^\)]*)\)", content)
|
||||
if len(all_models) > 0:
|
||||
model_tested = []
|
||||
for line in all_models.groups()[0].split(","):
|
||||
name = line.strip()
|
||||
if len(name) > 0:
|
||||
model_tested.append(name)
|
||||
for entry in all_models:
|
||||
for line in entry.split(","):
|
||||
name = line.strip()
|
||||
if len(name) > 0:
|
||||
model_tested.append(name)
|
||||
return model_tested
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user