test_mixin

tests/test_mixin.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
from __future__ import annotations

import random

import torch

from videodataset.dataset.base_dataset import BaseVideoDataset


def test_mixin_with_torch_dataset(test_video):
    """Test the mixin class BaseVideoDataset for torch dataset."""

    class MyDataset(torch.utils.data.Dataset, BaseVideoDataset):
        def __init__(
            self,
            number_of_frames: int,
            video_path: str,
            codec: str = "h265",
        ):
            super().__init__()
            self.video_path = video_path
            self.decoder = self.get_decoder(video_path, codec)
            self.number_of_frames = number_of_frames

        def __len__(self):
            return self.number_of_frames

        def __getitem__(self, idx):
            got = False
            while not got:
                try:
                    result = self.decode_video_frame(
                        self.decoder,
                        self.video_path,
                        idx,
                    )
                    got = True
                except Exception:
                    idx = random.randint(0, len(self.frames) - 1)
            return result

    num_of_frames = 10
    dataset = MyDataset(num_of_frames, str(test_video))
    for batch in torch.utils.data.DataLoader(dataset, batch_size=2):
        assert batch.shape == (2, 720, 1280, 3)