* doc

* [tests] Add sample files for a regression task

* [HUGE] Trainer

* Feedback from @sshleifer

* Feedback from @thomwolf + logging tweak

* [file_utils] when downloading concurrently, get_from_cache will use the cached file for subsequent processes

* [glue] Use default max_seq_length of 128 like before

* [glue] move DataTrainingArguments around

* [ner] Change interface of InputExample, and align run_{tf,pl}

* Re-align the pl scripts a little bit

* ner

* [ner] Add integration test

* Fix language_modeling with API tweak

* [ci] Tweak loss target

* Don't break console output

* amp.initialize: model must be on right device before

* [multiple-choice] update for Trainer

* Re-align to 827d6d6ef0
This commit is contained in:
Julien Chaumond
2020-04-21 20:11:56 -04:00
committed by GitHub
parent eb5601b0a5
commit dd9d483d03
41 changed files with 2682 additions and 2567 deletions

View File

@@ -456,6 +456,11 @@ def get_from_cache(
lock_path = cache_path + ".lock"
with FileLock(lock_path):
# If the download just completed while the lock was activated.
if os.path.exists(cache_path) and not force_download:
# Even if returning early like here, the lock will be released.
return cache_path
if resume_download:
incomplete_path = cache_path + ".incomplete"
@@ -496,3 +501,50 @@ def get_from_cache(
json.dump(meta, meta_file)
return cache_path
class cached_property(property):
"""
Descriptor that mimics @property but caches output in member variable.
From tensorflow_datasets
Built-in in functools from Python 3.8.
"""
def __get__(self, obj, objtype=None):
# See docs.python.org/3/howto/descriptor.html#properties
if obj is None:
return self
if self.fget is None:
raise AttributeError("unreadable attribute")
attr = "__cached_" + self.fget.__name__
cached = getattr(obj, attr, None)
if cached is None:
cached = self.fget(obj)
setattr(obj, attr, cached)
return cached
def torch_required(func):
# Chose a different decorator name than in tests so it's clear they are not the same.
@wraps(func)
def wrapper(*args, **kwargs):
if is_torch_available():
return func(*args, **kwargs)
else:
raise ImportError(f"Method `{func.__name__}` requires PyTorch.")
return wrapper
def tf_required(func):
# Chose a different decorator name than in tests so it's clear they are not the same.
@wraps(func)
def wrapper(*args, **kwargs):
if is_tf_available():
return func(*args, **kwargs)
else:
raise ImportError(f"Method `{func.__name__}` requires TF.")
return wrapper