@@ -20,6 +20,7 @@ import os
|
||||
import re
|
||||
import shlex
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
@@ -27,7 +28,7 @@ from collections.abc import Mapping
|
||||
from distutils.util import strtobool
|
||||
from io import StringIO
|
||||
from pathlib import Path
|
||||
from typing import Iterator, Union
|
||||
from typing import Iterator, List, Union
|
||||
from unittest import mock
|
||||
|
||||
from transformers import logging as transformers_logging
|
||||
@@ -1561,3 +1562,25 @@ def to_2tuple(x):
|
||||
if isinstance(x, collections.abc.Iterable):
|
||||
return x
|
||||
return (x, x)
|
||||
|
||||
|
||||
# These utils relate to ensuring the right error message is received when running scripts
|
||||
class SubprocessCallException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def run_command(command: List[str], return_stdout=False):
|
||||
"""
|
||||
Runs `command` with `subprocess.check_output` and will potentially return the `stdout`. Will also properly capture
|
||||
if an error occured while running `command`
|
||||
"""
|
||||
try:
|
||||
output = subprocess.check_output(command, stderr=subprocess.STDOUT)
|
||||
if return_stdout:
|
||||
if hasattr(output, "decode"):
|
||||
output = output.decode("utf-8")
|
||||
return output
|
||||
except subprocess.CalledProcessError as e:
|
||||
raise SubprocessCallException(
|
||||
f"Command `{' '.join(command)}` failed with the following error:\n\n{e.output.decode()}"
|
||||
) from e
|
||||
|
||||
Reference in New Issue
Block a user