Trainer (#3800)
* doc
* [tests] Add sample files for a regression task
* [HUGE] Trainer
* Feedback from @sshleifer
* Feedback from @thomwolf + logging tweak
* [file_utils] when downloading concurrently, get_from_cache will use the cached file for subsequent processes
* [glue] Use default max_seq_length of 128 like before
* [glue] move DataTrainingArguments around
* [ner] Change interface of InputExample, and align run_{tf,pl}
* Re-align the pl scripts a little bit
* ner
* [ner] Add integration test
* Fix language_modeling with API tweak
* [ci] Tweak loss target
* Don't break console output
* amp.initialize: model must be on right device before
* [multiple-choice] update for Trainer
* Re-align to 827d6d6ef0
This commit is contained in:
@@ -456,6 +456,11 @@ def get_from_cache(
|
||||
lock_path = cache_path + ".lock"
|
||||
with FileLock(lock_path):
|
||||
|
||||
# If the download just completed while the lock was activated.
|
||||
if os.path.exists(cache_path) and not force_download:
|
||||
# Even if returning early like here, the lock will be released.
|
||||
return cache_path
|
||||
|
||||
if resume_download:
|
||||
incomplete_path = cache_path + ".incomplete"
|
||||
|
||||
@@ -496,3 +501,50 @@ def get_from_cache(
|
||||
json.dump(meta, meta_file)
|
||||
|
||||
return cache_path
|
||||
|
||||
|
||||
class cached_property(property):
|
||||
"""
|
||||
Descriptor that mimics @property but caches output in member variable.
|
||||
|
||||
From tensorflow_datasets
|
||||
|
||||
Built-in in functools from Python 3.8.
|
||||
"""
|
||||
|
||||
def __get__(self, obj, objtype=None):
|
||||
# See docs.python.org/3/howto/descriptor.html#properties
|
||||
if obj is None:
|
||||
return self
|
||||
if self.fget is None:
|
||||
raise AttributeError("unreadable attribute")
|
||||
attr = "__cached_" + self.fget.__name__
|
||||
cached = getattr(obj, attr, None)
|
||||
if cached is None:
|
||||
cached = self.fget(obj)
|
||||
setattr(obj, attr, cached)
|
||||
return cached
|
||||
|
||||
|
||||
def torch_required(func):
|
||||
# Chose a different decorator name than in tests so it's clear they are not the same.
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
if is_torch_available():
|
||||
return func(*args, **kwargs)
|
||||
else:
|
||||
raise ImportError(f"Method `{func.__name__}` requires PyTorch.")
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def tf_required(func):
|
||||
# Chose a different decorator name than in tests so it's clear they are not the same.
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
if is_tf_available():
|
||||
return func(*args, **kwargs)
|
||||
else:
|
||||
raise ImportError(f"Method `{func.__name__}` requires TF.")
|
||||
|
||||
return wrapper
|
||||
|
||||
Reference in New Issue
Block a user