Make forward asynchrone to avoid long computation timing out.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
This commit is contained in:
committed by
Lysandre Debut
parent
6e6c8c52ed
commit
908cd5ea27
@@ -28,8 +28,9 @@ from . import __version__
|
||||
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
try:
|
||||
if os.environ.get("USE_TORCH", 'AUTO').upper() in ("1", "ON", "YES", "AUTO") and \
|
||||
os.environ.get("USE_TF", 'AUTO').upper() not in ("1", "ON", "YES"):
|
||||
USE_TF = os.environ.get("USE_TF", "AUTO").upper()
|
||||
USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
|
||||
if USE_TORCH in ("1", "ON", "YES", "AUTO") and USE_TF not in ("1", "ON", "YES"):
|
||||
import torch
|
||||
|
||||
_torch_available = True # pylint: disable=invalid-name
|
||||
@@ -41,8 +42,10 @@ except ImportError:
|
||||
_torch_available = False # pylint: disable=invalid-name
|
||||
|
||||
try:
|
||||
if os.environ.get("USE_TF", 'AUTO').upper() in ("1", "ON", "YES", "AUTO") and \
|
||||
os.environ.get("USE_TORCH", 'AUTO').upper() not in ("1", "ON", "YES"):
|
||||
USE_TF = os.environ.get("USE_TF", "AUTO").upper()
|
||||
USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
|
||||
|
||||
if USE_TF in ("1", "ON", "YES", "AUTO") and USE_TORCH not in ("1", "ON", "YES"):
|
||||
import tensorflow as tf
|
||||
|
||||
assert hasattr(tf, "__version__") and int(tf.__version__[0]) >= 2
|
||||
|
||||
Reference in New Issue
Block a user