Funnel transformer (#6908)
* Initial model * Fix upsampling * Add special cls token id and test * Formatting * Test and fist FunnelTokenizerFast * Common tests * Fix the check_repo script and document Funnel * Doc fixes * Add all models * Write doc * Fix test * Initial model * Fix upsampling * Add special cls token id and test * Formatting * Test and fist FunnelTokenizerFast * Common tests * Fix the check_repo script and document Funnel * Doc fixes * Add all models * Write doc * Fix test * Fix copyright * Forgot some layers can be repeated * Apply suggestions from code review Co-authored-by: Lysandre Debut <lysandre@huggingface.co> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Update src/transformers/modeling_funnel.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Address review comments * Update src/transformers/modeling_funnel.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Address review comments * Update src/transformers/modeling_funnel.py Co-authored-by: Sam Shleifer <sshleifer@gmail.com> * Slow integration test * Make small integration test * Formatting * Add checkpoint and separate classification head * Formatting * Expand list, fix link and add in pretrained models * Styling * Add the model in all summaries * Typo fixes Co-authored-by: Lysandre Debut <lysandre@huggingface.co> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
This commit is contained in:
@@ -15,6 +15,12 @@ def convert_command_factory(args: Namespace):
|
||||
)
|
||||
|
||||
|
||||
IMPORT_ERROR_MESSAGE = """transformers can only be used from the commandline to convert TensorFlow models in PyTorch,
|
||||
In that case, it requires TensorFlow to be installed. Please see
|
||||
https://www.tensorflow.org/install/ for installation instructions.
|
||||
"""
|
||||
|
||||
|
||||
class ConvertCommand(BaseTransformersCLICommand):
|
||||
@staticmethod
|
||||
def register_subcommand(parser: ArgumentParser):
|
||||
@@ -69,12 +75,7 @@ class ConvertCommand(BaseTransformersCLICommand):
|
||||
convert_tf_checkpoint_to_pytorch,
|
||||
)
|
||||
except ImportError:
|
||||
msg = (
|
||||
"transformers can only be used from the commandline to convert TensorFlow models in PyTorch, "
|
||||
"In that case, it requires TensorFlow to be installed. Please see "
|
||||
"https://www.tensorflow.org/install/ for installation instructions."
|
||||
)
|
||||
raise ImportError(msg)
|
||||
raise ImportError(IMPORT_ERROR_MESSAGE)
|
||||
|
||||
convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
|
||||
elif self._model_type == "bert":
|
||||
@@ -83,12 +84,16 @@ class ConvertCommand(BaseTransformersCLICommand):
|
||||
convert_tf_checkpoint_to_pytorch,
|
||||
)
|
||||
except ImportError:
|
||||
msg = (
|
||||
"transformers can only be used from the commandline to convert TensorFlow models in PyTorch, "
|
||||
"In that case, it requires TensorFlow to be installed. Please see "
|
||||
"https://www.tensorflow.org/install/ for installation instructions."
|
||||
raise ImportError(IMPORT_ERROR_MESSAGE)
|
||||
|
||||
convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
|
||||
elif self._model_type == "funnel":
|
||||
try:
|
||||
from transformers.convert_funnel_original_tf_checkpoint_to_pytorch import (
|
||||
convert_tf_checkpoint_to_pytorch,
|
||||
)
|
||||
raise ImportError(msg)
|
||||
except ImportError:
|
||||
raise ImportError(IMPORT_ERROR_MESSAGE)
|
||||
|
||||
convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
|
||||
elif self._model_type == "gpt":
|
||||
@@ -103,12 +108,7 @@ class ConvertCommand(BaseTransformersCLICommand):
|
||||
convert_transfo_xl_checkpoint_to_pytorch,
|
||||
)
|
||||
except ImportError:
|
||||
msg = (
|
||||
"transformers can only be used from the commandline to convert TensorFlow models in PyTorch, "
|
||||
"In that case, it requires TensorFlow to be installed. Please see "
|
||||
"https://www.tensorflow.org/install/ for installation instructions."
|
||||
)
|
||||
raise ImportError(msg)
|
||||
raise ImportError(IMPORT_ERROR_MESSAGE)
|
||||
|
||||
if "ckpt" in self._tf_checkpoint.lower():
|
||||
TF_CHECKPOINT = self._tf_checkpoint
|
||||
@@ -125,12 +125,7 @@ class ConvertCommand(BaseTransformersCLICommand):
|
||||
convert_gpt2_checkpoint_to_pytorch,
|
||||
)
|
||||
except ImportError:
|
||||
msg = (
|
||||
"transformers can only be used from the commandline to convert TensorFlow models in PyTorch, "
|
||||
"In that case, it requires TensorFlow to be installed. Please see "
|
||||
"https://www.tensorflow.org/install/ for installation instructions."
|
||||
)
|
||||
raise ImportError(msg)
|
||||
raise ImportError(IMPORT_ERROR_MESSAGE)
|
||||
|
||||
convert_gpt2_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
|
||||
elif self._model_type == "xlnet":
|
||||
@@ -139,12 +134,7 @@ class ConvertCommand(BaseTransformersCLICommand):
|
||||
convert_xlnet_checkpoint_to_pytorch,
|
||||
)
|
||||
except ImportError:
|
||||
msg = (
|
||||
"transformers can only be used from the commandline to convert TensorFlow models in PyTorch, "
|
||||
"In that case, it requires TensorFlow to be installed. Please see "
|
||||
"https://www.tensorflow.org/install/ for installation instructions."
|
||||
)
|
||||
raise ImportError(msg)
|
||||
raise ImportError(IMPORT_ERROR_MESSAGE)
|
||||
|
||||
convert_xlnet_checkpoint_to_pytorch(
|
||||
self._tf_checkpoint, self._config, self._pytorch_dump_output, self._finetuning_task_name
|
||||
|
||||
Reference in New Issue
Block a user