Funnel transformer (#6908)

* Initial model

* Fix upsampling

* Add special cls token id and test

* Formatting

* Test and fist FunnelTokenizerFast

* Common tests

* Fix the check_repo script and document Funnel

* Doc fixes

* Add all models

* Write doc

* Fix test

* Initial model

* Fix upsampling

* Add special cls token id and test

* Formatting

* Test and fist FunnelTokenizerFast

* Common tests

* Fix the check_repo script and document Funnel

* Doc fixes

* Add all models

* Write doc

* Fix test

* Fix copyright

* Forgot some layers can be repeated

* Apply suggestions from code review

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/transformers/modeling_funnel.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Address review comments

* Update src/transformers/modeling_funnel.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Address review comments

* Update src/transformers/modeling_funnel.py

Co-authored-by: Sam Shleifer <sshleifer@gmail.com>

* Slow integration test

* Make small integration test

* Formatting

* Add checkpoint and separate classification head

* Formatting

* Expand list, fix link and add in pretrained models

* Styling

* Add the model in all summaries

* Typo fixes

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
This commit is contained in:
Sylvain Gugger
2020-09-08 08:08:08 -04:00
committed by GitHub
parent 25afb4ea50
commit d155b38d6e
18 changed files with 3208 additions and 405 deletions

View File

@@ -15,6 +15,12 @@ def convert_command_factory(args: Namespace):
)
IMPORT_ERROR_MESSAGE = """transformers can only be used from the commandline to convert TensorFlow models in PyTorch,
In that case, it requires TensorFlow to be installed. Please see
https://www.tensorflow.org/install/ for installation instructions.
"""
class ConvertCommand(BaseTransformersCLICommand):
@staticmethod
def register_subcommand(parser: ArgumentParser):
@@ -69,12 +75,7 @@ class ConvertCommand(BaseTransformersCLICommand):
convert_tf_checkpoint_to_pytorch,
)
except ImportError:
msg = (
"transformers can only be used from the commandline to convert TensorFlow models in PyTorch, "
"In that case, it requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions."
)
raise ImportError(msg)
raise ImportError(IMPORT_ERROR_MESSAGE)
convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
elif self._model_type == "bert":
@@ -83,12 +84,16 @@ class ConvertCommand(BaseTransformersCLICommand):
convert_tf_checkpoint_to_pytorch,
)
except ImportError:
msg = (
"transformers can only be used from the commandline to convert TensorFlow models in PyTorch, "
"In that case, it requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions."
raise ImportError(IMPORT_ERROR_MESSAGE)
convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
elif self._model_type == "funnel":
try:
from transformers.convert_funnel_original_tf_checkpoint_to_pytorch import (
convert_tf_checkpoint_to_pytorch,
)
raise ImportError(msg)
except ImportError:
raise ImportError(IMPORT_ERROR_MESSAGE)
convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
elif self._model_type == "gpt":
@@ -103,12 +108,7 @@ class ConvertCommand(BaseTransformersCLICommand):
convert_transfo_xl_checkpoint_to_pytorch,
)
except ImportError:
msg = (
"transformers can only be used from the commandline to convert TensorFlow models in PyTorch, "
"In that case, it requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions."
)
raise ImportError(msg)
raise ImportError(IMPORT_ERROR_MESSAGE)
if "ckpt" in self._tf_checkpoint.lower():
TF_CHECKPOINT = self._tf_checkpoint
@@ -125,12 +125,7 @@ class ConvertCommand(BaseTransformersCLICommand):
convert_gpt2_checkpoint_to_pytorch,
)
except ImportError:
msg = (
"transformers can only be used from the commandline to convert TensorFlow models in PyTorch, "
"In that case, it requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions."
)
raise ImportError(msg)
raise ImportError(IMPORT_ERROR_MESSAGE)
convert_gpt2_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
elif self._model_type == "xlnet":
@@ -139,12 +134,7 @@ class ConvertCommand(BaseTransformersCLICommand):
convert_xlnet_checkpoint_to_pytorch,
)
except ImportError:
msg = (
"transformers can only be used from the commandline to convert TensorFlow models in PyTorch, "
"In that case, it requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions."
)
raise ImportError(msg)
raise ImportError(IMPORT_ERROR_MESSAGE)
convert_xlnet_checkpoint_to_pytorch(
self._tf_checkpoint, self._config, self._pytorch_dump_output, self._finetuning_task_name