[video processor] support torchcodec and decrease cuda memory usage (#38880)
* don't move the whole video to GPU * add torchcodec * add tests * make style * instructblip as well * consistency * Update src/transformers/utils/import_utils.py Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> * Update src/transformers/utils/import_utils.py Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> * Update src/transformers/video_utils.py Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> --------- Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>
This commit is contained in:
committed by
GitHub
parent
11d0feacce
commit
e212ff9e6a
@@ -94,12 +94,18 @@ class InstructBlipVideoVideoProcessor(BaseVideoProcessor):
|
|||||||
fps: Optional[int] = None,
|
fps: Optional[int] = None,
|
||||||
num_frames: Optional[int] = None,
|
num_frames: Optional[int] = None,
|
||||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||||
|
device: Optional["torch.device"] = None,
|
||||||
) -> BatchFeature:
|
) -> BatchFeature:
|
||||||
if do_sample_frames:
|
if do_sample_frames:
|
||||||
videos = [
|
videos = [
|
||||||
self.sample_frames(video, metadata, num_frames, fps) for video, metadata in zip(videos, video_metadata)
|
self.sample_frames(video, metadata, num_frames, fps) for video, metadata in zip(videos, video_metadata)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# We need to sample frames first before moving to device, if `do_sample_frames=True`. Otherwise
|
||||||
|
# moving the whole video incurs high GPU mem usage for long videos
|
||||||
|
if device is not None:
|
||||||
|
videos = [video.to(device) for video in videos]
|
||||||
|
|
||||||
# Group videos by size for batched resizing
|
# Group videos by size for batched resizing
|
||||||
grouped_videos, grouped_videos_index = group_videos_by_shape(videos)
|
grouped_videos, grouped_videos_index = group_videos_by_shape(videos)
|
||||||
resized_videos_grouped = {}
|
resized_videos_grouped = {}
|
||||||
|
|||||||
@@ -147,6 +147,7 @@ class InternVLVideoProcessor(BaseVideoProcessor):
|
|||||||
num_frames: Optional[int] = None,
|
num_frames: Optional[int] = None,
|
||||||
initial_shift: Optional[Union[bool, float, int]] = None,
|
initial_shift: Optional[Union[bool, float, int]] = None,
|
||||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||||
|
device: Optional["torch.device"] = None,
|
||||||
) -> BatchFeature:
|
) -> BatchFeature:
|
||||||
if do_sample_frames:
|
if do_sample_frames:
|
||||||
# Sample video frames
|
# Sample video frames
|
||||||
@@ -155,6 +156,11 @@ class InternVLVideoProcessor(BaseVideoProcessor):
|
|||||||
for video, metadata in zip(videos, video_metadata)
|
for video, metadata in zip(videos, video_metadata)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# We need to sample frames first before moving to device, if `do_sample_frames=True`. Otherwise
|
||||||
|
# moving the whole video incurs high GPU mem usage for long videos
|
||||||
|
if device is not None:
|
||||||
|
videos = [video.to(device) for video in videos]
|
||||||
|
|
||||||
# Group videos by size for batched resizing
|
# Group videos by size for batched resizing
|
||||||
grouped_videos, grouped_videos_index = group_videos_by_shape(videos)
|
grouped_videos, grouped_videos_index = group_videos_by_shape(videos)
|
||||||
resized_videos_grouped = {}
|
resized_videos_grouped = {}
|
||||||
|
|||||||
@@ -213,6 +213,7 @@ class Qwen2VLVideoProcessor(BaseVideoProcessor):
|
|||||||
min_frames: Optional[int] = None,
|
min_frames: Optional[int] = None,
|
||||||
max_frames: Optional[int] = None,
|
max_frames: Optional[int] = None,
|
||||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||||
|
device: Optional["torch.device"] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
if do_sample_frames:
|
if do_sample_frames:
|
||||||
@@ -230,6 +231,11 @@ class Qwen2VLVideoProcessor(BaseVideoProcessor):
|
|||||||
for video, metadata in zip(videos, video_metadata)
|
for video, metadata in zip(videos, video_metadata)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# We need to sample frames first before moving to device, if `do_sample_frames=True`. Otherwise
|
||||||
|
# moving the whole video incurs high GPU mem usage for long videos
|
||||||
|
if device is not None:
|
||||||
|
videos = [video.to(device) for video in videos]
|
||||||
|
|
||||||
# Group videos by size for batched resizing
|
# Group videos by size for batched resizing
|
||||||
grouped_videos, grouped_videos_index = group_videos_by_shape(videos)
|
grouped_videos, grouped_videos_index = group_videos_by_shape(videos)
|
||||||
resized_videos_grouped = {}
|
resized_videos_grouped = {}
|
||||||
|
|||||||
@@ -332,6 +332,7 @@ class SmolVLMVideoProcessor(BaseVideoProcessor):
|
|||||||
num_frames: Optional[int] = None,
|
num_frames: Optional[int] = None,
|
||||||
skip_secs: Optional[int] = 0,
|
skip_secs: Optional[int] = 0,
|
||||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||||
|
device: Optional["torch.device"] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
# Group videos by size for batched resizing
|
# Group videos by size for batched resizing
|
||||||
@@ -356,6 +357,11 @@ class SmolVLMVideoProcessor(BaseVideoProcessor):
|
|||||||
]
|
]
|
||||||
durations_list = [len(video) // 24 for video in videos]
|
durations_list = [len(video) // 24 for video in videos]
|
||||||
|
|
||||||
|
# We need to sample frames first before moving to device, if `do_sample_frames=True`. Otherwise
|
||||||
|
# moving the whole video incurs high GPU mem usage for long videos
|
||||||
|
if device is not None:
|
||||||
|
processed_videos = [video.to(device) for video in processed_videos]
|
||||||
|
|
||||||
grouped_videos, grouped_videos_index = group_videos_by_shape(processed_videos)
|
grouped_videos, grouped_videos_index = group_videos_by_shape(processed_videos)
|
||||||
resized_videos_grouped = {}
|
resized_videos_grouped = {}
|
||||||
for shape, stacked_videos in grouped_videos.items():
|
for shape, stacked_videos in grouped_videos.items():
|
||||||
|
|||||||
@@ -158,6 +158,7 @@ from .utils import (
|
|||||||
is_torch_xpu_available,
|
is_torch_xpu_available,
|
||||||
is_torchao_available,
|
is_torchao_available,
|
||||||
is_torchaudio_available,
|
is_torchaudio_available,
|
||||||
|
is_torchcodec_available,
|
||||||
is_torchdynamo_available,
|
is_torchdynamo_available,
|
||||||
is_torchvision_available,
|
is_torchvision_available,
|
||||||
is_vision_available,
|
is_vision_available,
|
||||||
@@ -634,6 +635,16 @@ def require_torchvision(test_case):
|
|||||||
return unittest.skipUnless(is_torchvision_available(), "test requires Torchvision")(test_case)
|
return unittest.skipUnless(is_torchvision_available(), "test requires Torchvision")(test_case)
|
||||||
|
|
||||||
|
|
||||||
|
def require_torchcodec(test_case):
|
||||||
|
"""
|
||||||
|
Decorator marking a test that requires Torchcodec.
|
||||||
|
|
||||||
|
These tests are skipped when Torchcodec isn't installed.
|
||||||
|
|
||||||
|
"""
|
||||||
|
return unittest.skipUnless(is_torchcodec_available(), "test requires Torchcodec")(test_case)
|
||||||
|
|
||||||
|
|
||||||
def require_torch_or_tf(test_case):
|
def require_torch_or_tf(test_case):
|
||||||
"""
|
"""
|
||||||
Decorator marking a test that requires PyTorch or TensorFlow.
|
Decorator marking a test that requires PyTorch or TensorFlow.
|
||||||
|
|||||||
@@ -254,6 +254,7 @@ from .import_utils import (
|
|||||||
is_torch_xpu_available,
|
is_torch_xpu_available,
|
||||||
is_torchao_available,
|
is_torchao_available,
|
||||||
is_torchaudio_available,
|
is_torchaudio_available,
|
||||||
|
is_torchcodec_available,
|
||||||
is_torchdistx_available,
|
is_torchdistx_available,
|
||||||
is_torchdynamo_available,
|
is_torchdynamo_available,
|
||||||
is_torchdynamo_compiling,
|
is_torchdynamo_compiling,
|
||||||
|
|||||||
@@ -119,6 +119,7 @@ _aqlm_available = _is_package_available("aqlm")
|
|||||||
_vptq_available, _vptq_version = _is_package_available("vptq", return_version=True)
|
_vptq_available, _vptq_version = _is_package_available("vptq", return_version=True)
|
||||||
_av_available = importlib.util.find_spec("av") is not None
|
_av_available = importlib.util.find_spec("av") is not None
|
||||||
_decord_available = importlib.util.find_spec("decord") is not None
|
_decord_available = importlib.util.find_spec("decord") is not None
|
||||||
|
_torchcodec_available = importlib.util.find_spec("torchcodec") is not None
|
||||||
_bitsandbytes_available = _is_package_available("bitsandbytes")
|
_bitsandbytes_available = _is_package_available("bitsandbytes")
|
||||||
_eetq_available = _is_package_available("eetq")
|
_eetq_available = _is_package_available("eetq")
|
||||||
_fbgemm_gpu_available = _is_package_available("fbgemm_gpu")
|
_fbgemm_gpu_available = _is_package_available("fbgemm_gpu")
|
||||||
@@ -976,6 +977,10 @@ def is_decord_available():
|
|||||||
return _decord_available
|
return _decord_available
|
||||||
|
|
||||||
|
|
||||||
|
def is_torchcodec_available():
|
||||||
|
return _torchcodec_available
|
||||||
|
|
||||||
|
|
||||||
def is_ninja_available():
|
def is_ninja_available():
|
||||||
r"""
|
r"""
|
||||||
Code comes from *torch.utils.cpp_extension.is_ninja_available()*. Returns `True` if the
|
Code comes from *torch.utils.cpp_extension.is_ninja_available()*. Returns `True` if the
|
||||||
@@ -1502,6 +1507,14 @@ pip install decord
|
|||||||
Please note that you may need to restart your runtime after installation.
|
Please note that you may need to restart your runtime after installation.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
TORCHCODEC_IMPORT_ERROR = """
|
||||||
|
{0} requires the TorchCodec (https://github.com/pytorch/torchcodec) library, but it was not found in your environment. You can install it with:
|
||||||
|
```
|
||||||
|
pip install torchcodec
|
||||||
|
```
|
||||||
|
Please note that you may need to restart your runtime after installation.
|
||||||
|
"""
|
||||||
|
|
||||||
# docstyle-ignore
|
# docstyle-ignore
|
||||||
CV2_IMPORT_ERROR = """
|
CV2_IMPORT_ERROR = """
|
||||||
{0} requires the OpenCV library but it was not found in your environment. You can install it with:
|
{0} requires the OpenCV library but it was not found in your environment. You can install it with:
|
||||||
@@ -1882,6 +1895,7 @@ BACKENDS_MAPPING = OrderedDict(
|
|||||||
("tokenizers", (is_tokenizers_available, TOKENIZERS_IMPORT_ERROR)),
|
("tokenizers", (is_tokenizers_available, TOKENIZERS_IMPORT_ERROR)),
|
||||||
("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)),
|
("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)),
|
||||||
("torchvision", (is_torchvision_available, TORCHVISION_IMPORT_ERROR)),
|
("torchvision", (is_torchvision_available, TORCHVISION_IMPORT_ERROR)),
|
||||||
|
("torchcodec", (is_torchcodec_available, TORCHCODEC_IMPORT_ERROR)),
|
||||||
("vision", (is_vision_available, VISION_IMPORT_ERROR)),
|
("vision", (is_vision_available, VISION_IMPORT_ERROR)),
|
||||||
("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)),
|
("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)),
|
||||||
("accelerate", (is_accelerate_available, ACCELERATE_IMPORT_ERROR)),
|
("accelerate", (is_accelerate_available, ACCELERATE_IMPORT_ERROR)),
|
||||||
|
|||||||
@@ -294,7 +294,6 @@ class BaseVideoProcessor(BaseImageProcessorFast):
|
|||||||
videos: VideoInput,
|
videos: VideoInput,
|
||||||
video_metadata: VideoMetadata = None,
|
video_metadata: VideoMetadata = None,
|
||||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||||
device: Optional["torch.device"] = None,
|
|
||||||
) -> list["torch.Tensor"]:
|
) -> list["torch.Tensor"]:
|
||||||
"""
|
"""
|
||||||
Prepare the input videos for processing.
|
Prepare the input videos for processing.
|
||||||
@@ -313,10 +312,6 @@ class BaseVideoProcessor(BaseImageProcessorFast):
|
|||||||
# not using F.to_tensor as it doesn't handle (C, H, W) numpy arrays
|
# not using F.to_tensor as it doesn't handle (C, H, W) numpy arrays
|
||||||
video = torch.from_numpy(video).contiguous()
|
video = torch.from_numpy(video).contiguous()
|
||||||
|
|
||||||
# Now that we have torch tensors, we can move them to the right device
|
|
||||||
if device is not None:
|
|
||||||
video = video.to(device)
|
|
||||||
|
|
||||||
processed_videos.append(video)
|
processed_videos.append(video)
|
||||||
return processed_videos, batch_metadata
|
return processed_videos, batch_metadata
|
||||||
|
|
||||||
@@ -336,10 +331,9 @@ class BaseVideoProcessor(BaseImageProcessorFast):
|
|||||||
kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))
|
kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))
|
||||||
|
|
||||||
input_data_format = kwargs.pop("input_data_format")
|
input_data_format = kwargs.pop("input_data_format")
|
||||||
device = kwargs.pop("device")
|
|
||||||
video_metadata = kwargs.pop("video_metadata")
|
video_metadata = kwargs.pop("video_metadata")
|
||||||
videos, video_metadata = self._prepare_input_videos(
|
videos, video_metadata = self._prepare_input_videos(
|
||||||
videos=videos, video_metadata=video_metadata, input_data_format=input_data_format, device=device
|
videos=videos, video_metadata=video_metadata, input_data_format=input_data_format
|
||||||
)
|
)
|
||||||
|
|
||||||
kwargs = self._further_process_kwargs(**kwargs)
|
kwargs = self._further_process_kwargs(**kwargs)
|
||||||
@@ -378,6 +372,7 @@ class BaseVideoProcessor(BaseImageProcessorFast):
|
|||||||
fps: Optional[int] = None,
|
fps: Optional[int] = None,
|
||||||
num_frames: Optional[int] = None,
|
num_frames: Optional[int] = None,
|
||||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||||
|
device: Optional["torch.device"] = None,
|
||||||
) -> BatchFeature:
|
) -> BatchFeature:
|
||||||
if do_sample_frames:
|
if do_sample_frames:
|
||||||
# Sample video frames
|
# Sample video frames
|
||||||
@@ -386,6 +381,11 @@ class BaseVideoProcessor(BaseImageProcessorFast):
|
|||||||
for video, metadata in zip(videos, video_metadata)
|
for video, metadata in zip(videos, video_metadata)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# We need to sample frames first before moving to device, if `do_sample_frames=True`. Otherwise
|
||||||
|
# moving the whole video incurs high GPU mem usage for long videos
|
||||||
|
if device is not None:
|
||||||
|
videos = [video.to(device) for video in videos]
|
||||||
|
|
||||||
# Group videos by size for batched resizing
|
# Group videos by size for batched resizing
|
||||||
grouped_videos, grouped_videos_index = group_videos_by_shape(videos)
|
grouped_videos, grouped_videos_index = group_videos_by_shape(videos)
|
||||||
resized_videos_grouped = {}
|
resized_videos_grouped = {}
|
||||||
@@ -775,6 +775,8 @@ class BaseVideoProcessor(BaseImageProcessorFast):
|
|||||||
`dict[str, Any]`: Dictionary of all the attributes that make up this video processor instance.
|
`dict[str, Any]`: Dictionary of all the attributes that make up this video processor instance.
|
||||||
"""
|
"""
|
||||||
output = copy.deepcopy(self.__dict__)
|
output = copy.deepcopy(self.__dict__)
|
||||||
|
output.pop("model_valid_processing_keys", None)
|
||||||
|
output.pop("_valid_kwargs_names", None)
|
||||||
output["video_processor_type"] = self.__class__.__name__
|
output["video_processor_type"] = self.__class__.__name__
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|||||||
@@ -14,6 +14,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import warnings
|
||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
from contextlib import redirect_stdout
|
from contextlib import redirect_stdout
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
@@ -33,6 +34,7 @@ from .utils import (
|
|||||||
is_numpy_array,
|
is_numpy_array,
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
is_torch_tensor,
|
is_torch_tensor,
|
||||||
|
is_torchcodec_available,
|
||||||
is_torchvision_available,
|
is_torchvision_available,
|
||||||
is_vision_available,
|
is_vision_available,
|
||||||
is_yt_dlp_available,
|
is_yt_dlp_available,
|
||||||
@@ -425,6 +427,10 @@ def read_video_torchvision(
|
|||||||
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
|
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
|
||||||
- `VideoMetadata` object.
|
- `VideoMetadata` object.
|
||||||
"""
|
"""
|
||||||
|
warnings.warn(
|
||||||
|
"Using `torchvision` for video decoding is deprecated and will be removed in future versions. "
|
||||||
|
"Please use `torchcodec` instead."
|
||||||
|
)
|
||||||
video, _, info = torchvision_io.read_video(
|
video, _, info = torchvision_io.read_video(
|
||||||
video_path,
|
video_path,
|
||||||
start_pts=0.0,
|
start_pts=0.0,
|
||||||
@@ -449,11 +455,59 @@ def read_video_torchvision(
|
|||||||
return video, metadata
|
return video, metadata
|
||||||
|
|
||||||
|
|
||||||
|
def read_video_torchcodec(
|
||||||
|
video_path: str,
|
||||||
|
sample_indices_fn: Callable,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Decode the video with torchcodec decoder.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
video_path (`str`):
|
||||||
|
Path to the video file.
|
||||||
|
sample_indices_fn (`Callable`, *optional*):
|
||||||
|
A callable function that will return indices at which the video should be sampled. If the video has to be loaded using
|
||||||
|
by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
|
||||||
|
If not provided, simple uniform sampling with fps is performed.
|
||||||
|
Example:
|
||||||
|
def sample_indices_fn(metadata, **kwargs):
|
||||||
|
return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[`torch.Tensor`, `VideoMetadata`]: A tuple containing:
|
||||||
|
- Torch tensor of frames in RGB (shape: [num_frames, height, width, 3]).
|
||||||
|
- `VideoMetadata` object.
|
||||||
|
"""
|
||||||
|
# Lazy import torchcodec
|
||||||
|
requires_backends(read_video_torchcodec, ["torchcodec"])
|
||||||
|
from torchcodec.decoders import VideoDecoder
|
||||||
|
|
||||||
|
decoder = VideoDecoder(
|
||||||
|
video_path,
|
||||||
|
dimension_order="NHWC", # to be consistent with other decoders
|
||||||
|
# Interestingly `exact` mode takes less than approximate when we load the whole video
|
||||||
|
seek_mode="exact",
|
||||||
|
)
|
||||||
|
metadata = VideoMetadata(
|
||||||
|
total_num_frames=decoder.metadata.num_frames,
|
||||||
|
fps=decoder.metadata.average_fps,
|
||||||
|
duration=decoder.metadata.duration_seconds,
|
||||||
|
video_backend="torchcodec",
|
||||||
|
)
|
||||||
|
indices = sample_indices_fn(metadata=metadata, **kwargs)
|
||||||
|
|
||||||
|
video = decoder.get_frames_at(indices=indices).data.contiguous()
|
||||||
|
metadata.frames_indices = indices
|
||||||
|
return video, metadata
|
||||||
|
|
||||||
|
|
||||||
VIDEO_DECODERS = {
|
VIDEO_DECODERS = {
|
||||||
"decord": read_video_decord,
|
"decord": read_video_decord,
|
||||||
"opencv": read_video_opencv,
|
"opencv": read_video_opencv,
|
||||||
"pyav": read_video_pyav,
|
"pyav": read_video_pyav,
|
||||||
"torchvision": read_video_torchvision,
|
"torchvision": read_video_torchvision,
|
||||||
|
"torchcodec": read_video_torchcodec,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -477,7 +531,7 @@ def load_video(
|
|||||||
Number of frames to sample per second. Should be passed only when `num_frames=None`.
|
Number of frames to sample per second. Should be passed only when `num_frames=None`.
|
||||||
If not specified and `num_frames==None`, all frames are sampled.
|
If not specified and `num_frames==None`, all frames are sampled.
|
||||||
backend (`str`, *optional*, defaults to `"pyav"`):
|
backend (`str`, *optional*, defaults to `"pyav"`):
|
||||||
The backend to use when loading the video. Can be any of ["decord", "pyav", "opencv", "torchvision"]. Defaults to "pyav".
|
The backend to use when loading the video. Can be any of ["decord", "pyav", "opencv", "torchvision", "torchcodec"]. Defaults to "pyav".
|
||||||
sample_indices_fn (`Callable`, *optional*):
|
sample_indices_fn (`Callable`, *optional*):
|
||||||
A callable function that will return indices at which the video should be sampled. If the video has to be loaded using
|
A callable function that will return indices at which the video should be sampled. If the video has to be loaded using
|
||||||
by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
|
by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
|
||||||
@@ -535,7 +589,7 @@ def load_video(
|
|||||||
video_is_url = video.startswith("http://") or video.startswith("https://")
|
video_is_url = video.startswith("http://") or video.startswith("https://")
|
||||||
if video_is_url and backend in ["opencv", "torchvision"]:
|
if video_is_url and backend in ["opencv", "torchvision"]:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"If you are trying to load a video from URL, you can decode the video only with `pyav` or `decord` as backend"
|
"If you are trying to load a video from URL, you can decode the video only with `pyav`, `decord` or `torchcodec` as backend"
|
||||||
)
|
)
|
||||||
|
|
||||||
if file_obj is None:
|
if file_obj is None:
|
||||||
@@ -546,6 +600,7 @@ def load_video(
|
|||||||
or (not is_av_available() and backend == "pyav")
|
or (not is_av_available() and backend == "pyav")
|
||||||
or (not is_cv2_available() and backend == "opencv")
|
or (not is_cv2_available() and backend == "opencv")
|
||||||
or (not is_torchvision_available() and backend == "torchvision")
|
or (not is_torchvision_available() and backend == "torchvision")
|
||||||
|
or (not is_torchcodec_available() and backend == "torchcodec")
|
||||||
):
|
):
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
f"You chose backend={backend} for loading the video but the required library is not found in your environment "
|
f"You chose backend={backend} for loading the video but the required library is not found in your environment "
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ from transformers.testing_utils import (
|
|||||||
require_cv2,
|
require_cv2,
|
||||||
require_decord,
|
require_decord,
|
||||||
require_torch,
|
require_torch,
|
||||||
|
require_torchcodec,
|
||||||
require_torchvision,
|
require_torchvision,
|
||||||
require_vision,
|
require_vision,
|
||||||
)
|
)
|
||||||
@@ -261,6 +262,7 @@ class LoadVideoTester(unittest.TestCase):
|
|||||||
|
|
||||||
@require_decord
|
@require_decord
|
||||||
@require_torchvision
|
@require_torchvision
|
||||||
|
@require_torchcodec
|
||||||
@require_cv2
|
@require_cv2
|
||||||
def test_load_video_backend_url(self):
|
def test_load_video_backend_url(self):
|
||||||
video, _ = load_video(
|
video, _ = load_video(
|
||||||
@@ -269,6 +271,12 @@ class LoadVideoTester(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
self.assertEqual(video.shape, (243, 360, 640, 3))
|
self.assertEqual(video.shape, (243, 360, 640, 3))
|
||||||
|
|
||||||
|
video, _ = load_video(
|
||||||
|
"https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/sample_demo_1.mp4",
|
||||||
|
backend="torchcodec",
|
||||||
|
)
|
||||||
|
self.assertEqual(video.shape, (243, 360, 640, 3))
|
||||||
|
|
||||||
# Can't use certain backends with url
|
# Can't use certain backends with url
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
video, _ = load_video(
|
video, _ = load_video(
|
||||||
@@ -283,6 +291,7 @@ class LoadVideoTester(unittest.TestCase):
|
|||||||
|
|
||||||
@require_decord
|
@require_decord
|
||||||
@require_torchvision
|
@require_torchvision
|
||||||
|
@require_torchcodec
|
||||||
@require_cv2
|
@require_cv2
|
||||||
def test_load_video_backend_local(self):
|
def test_load_video_backend_local(self):
|
||||||
video_file_path = hf_hub_download(
|
video_file_path = hf_hub_download(
|
||||||
@@ -300,6 +309,10 @@ class LoadVideoTester(unittest.TestCase):
|
|||||||
self.assertEqual(video.shape, (243, 360, 640, 3))
|
self.assertEqual(video.shape, (243, 360, 640, 3))
|
||||||
self.assertIsInstance(metadata, VideoMetadata)
|
self.assertIsInstance(metadata, VideoMetadata)
|
||||||
|
|
||||||
|
video, metadata = load_video(video_file_path, backend="torchcodec")
|
||||||
|
self.assertEqual(video.shape, (243, 360, 640, 3))
|
||||||
|
self.assertIsInstance(metadata, VideoMetadata)
|
||||||
|
|
||||||
def test_load_video_num_frames(self):
|
def test_load_video_num_frames(self):
|
||||||
video, _ = load_video(
|
video, _ = load_video(
|
||||||
"https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/sample_demo_1.mp4",
|
"https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/sample_demo_1.mp4",
|
||||||
|
|||||||
Reference in New Issue
Block a user