Source code for videodataset.dataset.base_dataset
from __future__ import annotations
import logging
from pathlib import Path
import torch
from videodataset import VideoDecoder
logger = logging.getLogger(__name__)
[docs]
class BaseVideoDataset:
"""Decoder extension that defines decoder specific functionalities"""
def __init__(self) -> None:
"""Initialize the BaseVideoDataset with a dictionary to hold decoders
and set the device_id to the current CUDA device if available.
Raises a RuntimeError if no CUDA device is found.
"""
self.decoders: dict[str, VideoDecoder] = {}
if torch.cuda.is_available():
self.device_id = torch.cuda.current_device()
else:
err_msg = "No cuda device found, accelerated decoding is not available"
raise RuntimeError(err_msg)
@property
def device(self) -> int:
"""Return the device ID where decoders are running."""
return self.device_id
@property
def num_decoders(self) -> int:
"""Return the number of decoders currently managed by the dataset."""
return len(self.decoders)
[docs]
def get_decoder(self, decoder_key: str, codec: str) -> VideoDecoder:
"""Retrieve a VideoDecoder for a specific key and codec. If the decoder
does not exist, it creates a new one and logs the creation.
"""
if decoder_key not in self.decoders:
self.decoders[decoder_key] = VideoDecoder(self.device_id, codec)
logger.debug(
"Created VideoDecoder %s with codec %s on device %s",
decoder_key,
codec,
self.device_id,
)
return self.decoders[decoder_key]
[docs]
def decode_video_frames(
self,
decoder: VideoDecoder,
video_path: str | Path,
frame_indices: list[int],
to_cpu: bool = False,
) -> list[torch.Tensor]:
"""Decode specific frames from a video file using the provided decoder.
Converts the decoded frames from NV12 format to RGB and optionally moves
the tensors to the CPU.
"""
decoded_frames = decoder.decode_to_nps(str(video_path), frame_indices)
rgb_tensors = []
for np_frame in decoded_frames:
rgb_tensor = torch.from_numpy(np_frame)
rgb_tensors.append(
rgb_tensor.cuda(decoder.gpu_id()) if not to_cpu else rgb_tensor
)
return rgb_tensors
[docs]
def decode_video_frame(
self,
decoder: VideoDecoder,
video_path: str | Path,
frame_idx: int,
to_cpu: bool = False,
) -> torch.Tensor:
"""Decode a specific frame from a video file using the provided decoder.
Converts the decoded frame from NV12 format to RGB and optionally moves
the tensor to the CPU.
"""
decoded_frame = decoder.decode_to_tensor(str(video_path), frame_idx)
return decoded_frame.cpu() if to_cpu else decoded_frame