
zamba.pytorch.dataloaders

get_datasets(train_metadata=None, predict_metadata=None, transform=None, video_loader_config=None)

Gets training and/or prediction datasets.

Parameters:

    train_metadata (pathlike, optional): Path to a CSV or DataFrame with columns:
        - filepath: path to a video, relative to `video_dir`
        - label: label of the species that appears in the video
        - split (optional): one of "train", "val", or "holdout", indicating which dataset split the video will be included in. If not provided and a "site" column exists, a site-specific split is generated; otherwise a random split is generated using `split_proportions`.
        - site (optional): if there is no "split" column, a site-specific split is generated using the values in this column.
        Default: None
    predict_metadata (pathlike, optional): Path to a CSV or DataFrame with a "filepath" column. Default: None
    transform (torchvision.transforms.transforms.Compose, optional): Default: None
    video_loader_config (VideoLoaderConfig, optional): Default: None

Returns:

    Tuple[Optional[FfmpegZambaVideoDataset], Optional[FfmpegZambaVideoDataset], Optional[FfmpegZambaVideoDataset], Optional[FfmpegZambaVideoDataset]]: A tuple of (train_dataset, val_dataset, test_dataset, predict_dataset) where each dataset can be None if not specified.
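
A minimal usage sketch (not from the zamba docs): by the time metadata reaches this function, labels are expected to be one-hot encoded into `species_*` columns and a `split` column should be present, since the implementation filters annotation columns with `filter(regex="species")` and subsets on `split`. The file paths and species names below are placeholders.

```python
import pandas as pd
from zamba.pytorch.dataloaders import get_datasets

# Hypothetical, already-preprocessed training metadata: one-hot species columns
# plus an explicit split assignment for each video.
train_metadata = pd.DataFrame(
    {
        "filepath": ["vid_001.mp4", "vid_002.mp4", "vid_003.mp4", "vid_004.mp4"],
        "species_elephant": [1, 0, 1, 0],
        "species_blank": [0, 1, 0, 1],
        "split": ["train", "train", "val", "holdout"],
    }
)

train_dataset, val_dataset, test_dataset, predict_dataset = get_datasets(
    train_metadata=train_metadata
)
# predict_dataset is None here because no predict_metadata was passed.
```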

Source code in zamba/pytorch/dataloaders.py
def get_datasets(
    train_metadata: Optional[pd.DataFrame] = None,
    predict_metadata: Optional[pd.DataFrame] = None,
    transform: Optional[torchvision.transforms.transforms.Compose] = None,
    video_loader_config: Optional[VideoLoaderConfig] = None,
) -> Tuple[
    Optional["FfmpegZambaVideoDataset"],
    Optional["FfmpegZambaVideoDataset"],
    Optional["FfmpegZambaVideoDataset"],
    Optional["FfmpegZambaVideoDataset"],
]:
    """Gets training and/or prediction datasets.

    Args:
        train_metadata (pathlike, optional): Path to a CSV or DataFrame with columns:
          - filepath: path to a video, relative to `video_dir`
          - label: label of the species that appears in the video
          - split (optional): If provided, "train", "val", or "holdout" indicating which dataset
            split the video will be included in. If not provided, and a "site" column exists,
            generate a site-specific split. Otherwise, generate a random split using
            `split_proportions`.
          - site (optional): If no "split" column, generate a site-specific split using the values
            in this column.
        predict_metadata (pathlike, optional): Path to a CSV or DataFrame with a "filepath" column.
        transform (torchvision.transforms.transforms.Compose, optional)
        video_loader_config (VideoLoaderConfig, optional)

    Returns:
        A tuple of (train_dataset, val_dataset, test_dataset, predict_dataset) where each dataset
        can be None if not specified.
    """
    if predict_metadata is not None:
        # enable filtering the same way on all datasets
        predict_metadata["species_"] = 0

    def subset_metadata_or_none(
        metadata: Optional[pd.DataFrame] = None, subset: Optional[str] = None
    ) -> Optional[pd.DataFrame]:
        if metadata is None:
            return None
        else:
            metadata_subset = metadata.loc[metadata.split == subset] if subset else metadata
            if len(metadata_subset) > 0:
                return FfmpegZambaVideoDataset(
                    annotations=metadata_subset.set_index("filepath").filter(regex="species"),
                    transform=transform,
                    video_loader_config=video_loader_config,
                )
            else:
                return None

    train_dataset = subset_metadata_or_none(train_metadata, "train")
    val_dataset = subset_metadata_or_none(train_metadata, "val")
    test_dataset = subset_metadata_or_none(train_metadata, "holdout")
    predict_dataset = subset_metadata_or_none(predict_metadata)

    return train_dataset, val_dataset, test_dataset, predict_dataset