[WIP] Tapas v4 (tres) (#9117)
* First commit: adding all files from tapas_v3 * Fix multiple bugs including soft dependency and new structure of the library * Improve testing by adding torch_device to inputs and adding dependency on scatter * Use Python 3 inheritance rather than Python 2 * First draft model cards of base sized models * Remove model cards as they are already on the hub * Fix multiple bugs with integration tests * All model integration tests pass * Remove print statement * Add test for convert_logits_to_predictions method of TapasTokenizer * Incorporate suggestions by Google authors * Fix remaining tests * Change position embeddings sizes to 512 instead of 1024 * Comment out positional embedding sizes * Update PRETRAINED_VOCAB_FILES_MAP and PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES * Added more model names * Fix truncation when no max length is specified * Disable torchscript test * Make style & make quality * Quality * Address CI needs * Test the Masked LM model * Fix the masked LM model * Truncate when overflowing * More much needed docs improvements * Fix some URLs * Some more docs improvements * Test PyTorch scatter * Set to slow + minify * Calm flake8 down * First commit: adding all files from tapas_v3 * Fix multiple bugs including soft dependency and new structure of the library * Improve testing by adding torch_device to inputs and adding dependency on scatter * Use Python 3 inheritance rather than Python 2 * First draft model cards of base sized models * Remove model cards as they are already on the hub * Fix multiple bugs with integration tests * All model integration tests pass * Remove print statement * Add test for convert_logits_to_predictions method of TapasTokenizer * Incorporate suggestions by Google authors * Fix remaining tests * Change position embeddings sizes to 512 instead of 1024 * Comment out positional embedding sizes * Update PRETRAINED_VOCAB_FILES_MAP and PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES * Added more model names * Fix truncation when no max length is specified * Disable torchscript test * Make style & make quality * Quality * Address CI needs * Test the Masked LM model * Fix the masked LM model * Truncate when overflowing * More much needed docs improvements * Fix some URLs * Some more docs improvements * Add add_pooling_layer argument to TapasModel Fix comments by @sgugger and @patrickvonplaten * Fix issue in docs + fix style and quality * Clean up conversion script and add task parameter to TapasConfig * Revert the task parameter of TapasConfig Some minor fixes * Improve conversion script and add test for absolute position embeddings * Improve conversion script and add test for absolute position embeddings * Fix bug with reset_position_index_per_cell arg of the conversion cli * Add notebooks to the examples directory and fix style and quality * Apply suggestions from code review * Move from `nielsr/` to `google/` namespace * Apply Sylvain's comments Co-authored-by: sgugger <sylvain.gugger@gmail.com> Co-authored-by: Rogge Niels <niels.rogge@howest.be> Co-authored-by: LysandreJik <lysandre.debut@reseau.eseo.fr> Co-authored-by: Lysandre Debut <lysandre@huggingface.co> Co-authored-by: sgugger <sylvain.gugger@gmail.com>
This commit is contained in:
@@ -216,6 +216,29 @@ except ImportError:
|
||||
_tokenizers_available = False
|
||||
|
||||
|
||||
try:
|
||||
import pandas # noqa: F401
|
||||
|
||||
_pandas_available = True
|
||||
|
||||
except ImportError:
|
||||
_pandas_available = False
|
||||
|
||||
|
||||
try:
|
||||
import torch_scatter
|
||||
|
||||
# Check we're not importing a "torch_scatter" directory somewhere
|
||||
_scatter_available = hasattr(torch_scatter, "__version__") and hasattr(torch_scatter, "scatter")
|
||||
if _scatter_available:
|
||||
logger.debug(f"Succesfully imported torch-scatter version {torch_scatter.__version__}")
|
||||
else:
|
||||
logger.debug("Imported a torch_scatter object but this doesn't seem to be the torch-scatter library.")
|
||||
|
||||
except ImportError:
|
||||
_scatter_available = False
|
||||
|
||||
|
||||
old_default_cache_path = os.path.join(torch_cache_home, "transformers")
|
||||
# New default cache, shared with the Datasets library
|
||||
hf_cache_home = os.path.expanduser(
|
||||
@@ -325,6 +348,14 @@ def is_in_notebook():
|
||||
return _in_notebook
|
||||
|
||||
|
||||
def is_scatter_available():
|
||||
return _scatter_available
|
||||
|
||||
|
||||
def is_pandas_available():
|
||||
return _pandas_available
|
||||
|
||||
|
||||
def torch_only_method(fn):
|
||||
def wrapper(*args, **kwargs):
|
||||
if not _torch_available:
|
||||
@@ -427,6 +458,13 @@ installation page: https://github.com/google/flax and follow the ones that match
|
||||
"""
|
||||
|
||||
|
||||
# docstyle-ignore
|
||||
SCATTER_IMPORT_ERROR = """
|
||||
{0} requires the torch-scatter library but it was not found in your environment. You can install it with pip as
|
||||
explained here: https://github.com/rusty1s/pytorch_scatter.
|
||||
"""
|
||||
|
||||
|
||||
def requires_datasets(obj):
|
||||
name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
|
||||
if not is_datasets_available():
|
||||
@@ -481,6 +519,12 @@ def requires_protobuf(obj):
|
||||
raise ImportError(PROTOBUF_IMPORT_ERROR.format(name))
|
||||
|
||||
|
||||
def requires_scatter(obj):
|
||||
name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
|
||||
if not is_scatter_available():
|
||||
raise ImportError(SCATTER_IMPORT_ERROR.format(name))
|
||||
|
||||
|
||||
def add_start_docstrings(*docstr):
|
||||
def docstring_decorator(fn):
|
||||
fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")
|
||||
|
||||
Reference in New Issue
Block a user