[WIP] Tapas v4 (tres) (#9117)

* First commit: adding all files from tapas_v3

* Fix multiple bugs including soft dependency and new structure of the library

* Improve testing by adding torch_device to inputs and adding dependency on scatter

* Use Python 3 inheritance rather than Python 2

* First draft model cards of base sized models

* Remove model cards as they are already on the hub

* Fix multiple bugs with integration tests

* All model integration tests pass

* Remove print statement

* Add test for convert_logits_to_predictions method of TapasTokenizer

* Incorporate suggestions by Google authors

* Fix remaining tests

* Change position embeddings sizes to 512 instead of 1024

* Comment out positional embedding sizes

* Update PRETRAINED_VOCAB_FILES_MAP and PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

* Added more model names

* Fix truncation when no max length is specified

* Disable torchscript test

* Make style & make quality

* Quality

* Address CI needs

* Test the Masked LM model

* Fix the masked LM model

* Truncate when overflowing

* More much needed docs improvements

* Fix some URLs

* Some more docs improvements

* Test PyTorch scatter

* Set to slow + minify

* Calm flake8 down

* First commit: adding all files from tapas_v3

* Fix multiple bugs including soft dependency and new structure of the library

* Improve testing by adding torch_device to inputs and adding dependency on scatter

* Use Python 3 inheritance rather than Python 2

* First draft model cards of base sized models

* Remove model cards as they are already on the hub

* Fix multiple bugs with integration tests

* All model integration tests pass

* Remove print statement

* Add test for convert_logits_to_predictions method of TapasTokenizer

* Incorporate suggestions by Google authors

* Fix remaining tests

* Change position embeddings sizes to 512 instead of 1024

* Comment out positional embedding sizes

* Update PRETRAINED_VOCAB_FILES_MAP and PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

* Added more model names

* Fix truncation when no max length is specified

* Disable torchscript test

* Make style & make quality

* Quality

* Address CI needs

* Test the Masked LM model

* Fix the masked LM model

* Truncate when overflowing

* More much needed docs improvements

* Fix some URLs

* Some more docs improvements

* Add add_pooling_layer argument to TapasModel

Fix comments by @sgugger and @patrickvonplaten

* Fix issue in docs + fix style and quality

* Clean up conversion script and add task parameter to TapasConfig

* Revert the task parameter of TapasConfig

Some minor fixes

* Improve conversion script and add test for absolute position embeddings

* Improve conversion script and add test for absolute position embeddings

* Fix bug with reset_position_index_per_cell arg of the conversion cli

* Add notebooks to the examples directory and fix style and quality

* Apply suggestions from code review

* Move from `nielsr/` to `google/` namespace

* Apply Sylvain's comments

Co-authored-by: sgugger <sylvain.gugger@gmail.com>

Co-authored-by: Rogge Niels <niels.rogge@howest.be>
Co-authored-by: LysandreJik <lysandre.debut@reseau.eseo.fr>
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: sgugger <sylvain.gugger@gmail.com>
This commit is contained in:
NielsRogge
2020-12-15 23:08:49 +01:00
committed by GitHub
parent ad895af98d
commit 1551e2dc6d
22 changed files with 8497 additions and 78 deletions

View File

@@ -216,6 +216,29 @@ except ImportError:
_tokenizers_available = False
try:
import pandas # noqa: F401
_pandas_available = True
except ImportError:
_pandas_available = False
try:
import torch_scatter
# Check we're not importing a "torch_scatter" directory somewhere
_scatter_available = hasattr(torch_scatter, "__version__") and hasattr(torch_scatter, "scatter")
if _scatter_available:
logger.debug(f"Succesfully imported torch-scatter version {torch_scatter.__version__}")
else:
logger.debug("Imported a torch_scatter object but this doesn't seem to be the torch-scatter library.")
except ImportError:
_scatter_available = False
old_default_cache_path = os.path.join(torch_cache_home, "transformers")
# New default cache, shared with the Datasets library
hf_cache_home = os.path.expanduser(
@@ -325,6 +348,14 @@ def is_in_notebook():
return _in_notebook
def is_scatter_available():
return _scatter_available
def is_pandas_available():
return _pandas_available
def torch_only_method(fn):
def wrapper(*args, **kwargs):
if not _torch_available:
@@ -427,6 +458,13 @@ installation page: https://github.com/google/flax and follow the ones that match
"""
# docstyle-ignore
SCATTER_IMPORT_ERROR = """
{0} requires the torch-scatter library but it was not found in your environment. You can install it with pip as
explained here: https://github.com/rusty1s/pytorch_scatter.
"""
def requires_datasets(obj):
name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
if not is_datasets_available():
@@ -481,6 +519,12 @@ def requires_protobuf(obj):
raise ImportError(PROTOBUF_IMPORT_ERROR.format(name))
def requires_scatter(obj):
name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
if not is_scatter_available():
raise ImportError(SCATTER_IMPORT_ERROR.format(name))
def add_start_docstrings(*docstr):
def docstring_decorator(fn):
fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")