Coverage for src/videodataset/dataset/base_dataset.py: 71%
34 statements
« prev ^ index » next coverage.py v7.11.3, created at 2025-12-25 11:33 +0800
« prev ^ index » next coverage.py v7.11.3, created at 2025-12-25 11:33 +0800
1from __future__ import annotations
3import logging
4from pathlib import Path
6import torch
8from videodataset import VideoDecoder
10logger = logging.getLogger(__name__)
class BaseVideoDataset:
    """Base class that keeps a per-key cache of ``VideoDecoder`` objects and
    provides helpers for decoding frames into torch tensors.
    """

    def __init__(self) -> None:
        """Create an empty decoder cache and record the current CUDA device.

        Raises:
            RuntimeError: If ``torch.cuda.is_available()`` reports no device.
        """
        self.decoders: dict[str, VideoDecoder] = {}

        # Guard clause: nothing below makes sense without a CUDA device.
        if not torch.cuda.is_available():
            message = "No cuda device found, accelerated decoding is not available"
            raise RuntimeError(message)

        self.device_id = torch.cuda.current_device()

    @property
    def device(self) -> int:
        """CUDA device index that was captured when the dataset was built."""
        return self.device_id

    @property
    def num_decoders(self) -> int:
        """How many decoders are currently held in the cache."""
        return len(self.decoders)

    def get_decoder(self, decoder_key: str, codec: str) -> VideoDecoder:
        """Return the decoder cached under ``decoder_key``, creating it on
        first use.

        A new decoder is built with this dataset's device id and ``codec``
        and its creation is logged at debug level. When the key is already
        cached, the existing decoder is returned and ``codec`` is not
        consulted.
        """
        if decoder_key in self.decoders:
            return self.decoders[decoder_key]

        created = VideoDecoder(self.device_id, codec)
        self.decoders[decoder_key] = created
        logger.debug(
            "Created VideoDecoder %s with codec %s on device %s",
            decoder_key,
            codec,
            self.device_id,
        )
        return created

    def decode_video_frames(
        self,
        decoder: VideoDecoder,
        video_path: str | Path,
        frame_indices: list[int],
        to_cpu: bool = False,
    ) -> list[torch.Tensor]:
        """Decode the requested frames of a video into a list of tensors.

        The frames come from ``decoder.decode_to_nps`` and are each wrapped
        with ``torch.from_numpy``. When ``to_cpu`` is true the wrapped
        tensors are returned as they are; otherwise every tensor is moved
        with ``.cuda(decoder.gpu_id())``.

        NOTE(review): any pixel-format conversion (e.g. NV12 to RGB) is
        presumably done inside the decoder; it is not visible here.
        """
        frames = decoder.decode_to_nps(str(video_path), frame_indices)

        if to_cpu:
            return [torch.from_numpy(frame) for frame in frames]

        # gpu_id() is queried once per frame, as each tensor is moved.
        return [torch.from_numpy(frame).cuda(decoder.gpu_id()) for frame in frames]

    def decode_video_frame(
        self,
        decoder: VideoDecoder,
        video_path: str | Path,
        frame_idx: int,
        to_cpu: bool = False,
    ) -> torch.Tensor:
        """Decode a single frame of a video into a tensor.

        The tensor is whatever ``decoder.decode_to_tensor`` returns; when
        ``to_cpu`` is true its ``.cpu()`` result is returned instead.

        NOTE(review): any pixel-format conversion (e.g. NV12 to RGB) is
        presumably done inside the decoder; it is not visible here.
        """
        frame = decoder.decode_to_tensor(str(video_path), frame_idx)
        if to_cpu:
            return frame.cpu()
        return frame