Coverage for src/videodataset/dataset/base_dataset.py: 71%

34 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2025-12-25 11:33 +0800

1from __future__ import annotations 

2 

3import logging 

4from pathlib import Path 

5 

6import torch 

7 

8from videodataset import VideoDecoder 

9 

10logger = logging.getLogger(__name__) 

11 

12 

class BaseVideoDataset:
    """Base class that caches ``VideoDecoder`` objects and wraps their output.

    Decoders are stored in ``self.decoders`` keyed by a caller-chosen string
    and are created lazily by :meth:`get_decoder` on the CUDA device that was
    current when the dataset was constructed. The ``decode_*`` helpers call
    the decoder and hand back ``torch.Tensor`` objects.
    """

    def __init__(self) -> None:
        """Set up an empty decoder cache and record the current CUDA device.

        Raises:
            RuntimeError: If ``torch.cuda.is_available()`` is false.
        """
        self.decoders: dict[str, VideoDecoder] = {}

        # Fail fast: every decoder is constructed with a CUDA device index.
        if not torch.cuda.is_available():
            err_msg = "No cuda device found, accelerated decoding is not available"
            raise RuntimeError(err_msg)

        self.device_id: int = torch.cuda.current_device()

    @property
    def device(self) -> int:
        """Return the CUDA device index recorded at construction time."""
        return self.device_id

    @property
    def num_decoders(self) -> int:
        """Return how many decoders are currently cached."""
        return len(self.decoders)

    def get_decoder(self, decoder_key: str, codec: str) -> VideoDecoder:
        """Return the decoder cached under ``decoder_key``, creating it if absent.

        Args:
            decoder_key: Cache key identifying the decoder.
            codec: Codec passed to ``VideoDecoder`` when a new decoder has to
                be created.

        Returns:
            The cached (or newly created) decoder for ``decoder_key``.
        """
        # NOTE(review): when a decoder already exists under this key it is
        # returned as-is and ``codec`` is not consulted -- confirm that callers
        # never reuse a key with a different codec.
        if decoder_key not in self.decoders:
            self.decoders[decoder_key] = VideoDecoder(self.device_id, codec)
            logger.debug(
                "Created VideoDecoder %s with codec %s on device %s",
                decoder_key,
                codec,
                self.device_id,
            )
        return self.decoders[decoder_key]

    def decode_video_frames(
        self,
        decoder: VideoDecoder,
        video_path: str | Path,
        frame_indices: list[int],
        to_cpu: bool = False,
    ) -> list[torch.Tensor]:
        """Decode several frames and return them as tensors.

        The frames come from ``decoder.decode_to_nps`` and each one is wrapped
        with ``torch.from_numpy``.

        NOTE(review): earlier documentation described an NV12 -> RGB
        conversion; no conversion is performed in this method, so presumably
        the decoder already returns frames in their final layout -- confirm.

        Args:
            decoder: Decoder used to read the video.
            video_path: Video location; converted with ``str()`` before being
                passed to the decoder.
            frame_indices: Indices of the frames to decode.
            to_cpu: If true, return the tensors exactly as produced by
                ``torch.from_numpy``; otherwise move each one to the CUDA
                device reported by ``decoder.gpu_id()``.

        Returns:
            One tensor per decoded frame, in the order the decoder returned
            them.
        """
        decoded_frames = decoder.decode_to_nps(str(video_path), frame_indices)

        if to_cpu:
            return [torch.from_numpy(np_frame) for np_frame in decoded_frames]

        # gpu_id() is queried once per frame, as before, so a decoder that
        # yields no frames is never asked for its device.
        return [
            torch.from_numpy(np_frame).cuda(decoder.gpu_id())
            for np_frame in decoded_frames
        ]

    def decode_video_frame(
        self,
        decoder: VideoDecoder,
        video_path: str | Path,
        frame_idx: int,
        to_cpu: bool = False,
    ) -> torch.Tensor:
        """Decode a single frame and return it as a tensor.

        The frame comes from ``decoder.decode_to_tensor``.

        NOTE(review): earlier documentation described an NV12 -> RGB
        conversion; no conversion is performed in this method, so presumably
        the decoder handles it -- confirm.

        Args:
            decoder: Decoder used to read the video.
            video_path: Video location; converted with ``str()`` before being
                passed to the decoder.
            frame_idx: Index of the frame to decode.
            to_cpu: If true, return ``.cpu()`` of the decoded tensor;
                otherwise return the decoder's tensor unchanged.

        Returns:
            The decoded frame.
        """
        decoded_frame = decoder.decode_to_tensor(str(video_path), frame_idx)
        if to_cpu:
            return decoded_frame.cpu()
        return decoded_frame

87 return decoded_frame.cpu() if to_cpu else decoded_frame