split_dataset_group_by

ethology.datasets.split.split_dataset_group_by(dataset, group_by_var, list_fractions, samples_coordinate='image_id', method='auto', seed=42, epsilon=0)

Split an annotations dataset by grouping variable.

Split an ethology annotations dataset into two subsets that are disjoint in the grouping variable (i.e., no group appears in both subsets). When method is auto, the function chooses between a “group k-fold” approach and an “approximate subset-sum” approach based on the number of unique groups and the requested split fractions.

Parameters:
  • dataset (xarray.Dataset) – The annotations dataset to split.

  • group_by_var (str) – The grouping variable to use for splitting the dataset. Must be 1-dimensional along the samples_coordinate.

  • list_fractions (list[float, float]) – The fractions of the input annotations dataset to allocate to each subset. Must contain exactly two elements, each between 0 and 1, that sum to 1.

  • samples_coordinate (str, optional) – The coordinate along which to split the dataset. Default is image_id.

  • method (str, optional) – Method to use: auto, kfold, or apss. When auto, the function selects kfold or apss based on the number of unique groups relative to the number of folds implied by list_fractions. See Notes for further details. Default is auto.

  • seed (int, optional) – Random seed for reproducibility, used in the “group k-fold” approach. It controls both the shuffling of the sample indices and the random selection of the output split among all possible ones. Only used when method is kfold or when auto selects kfold. Default is 42.

  • epsilon (float, optional) – The approximation tolerance for the “approximate subset-sum” approach, expressed as a fraction (e.g., 0.1 for 10%). The total number of samples in the smaller subset is guaranteed to be at least (1 - epsilon) times the optimal total for the requested fractions and grouping variable. When epsilon is 0, the algorithm finds the exact optimal sum. Larger values result in faster computation but may yield subsets whose total number of samples is further from the optimal. Only used when method is apss or when auto selects apss. Default is 0.

Returns:

The two subsets of the input dataset. The subsets are returned in the same order as the input list of fractions list_fractions.

Return type:

tuple[xarray.Dataset, xarray.Dataset]

Raises:

ValueError – If list_fractions does not contain exactly two elements, if any of its elements is not between 0 and 1, or if its elements do not sum to 1; if group_by_var is not 1-dimensional along the samples_coordinate; or if method is kfold but there are too few unique groups for the requested split fractions.
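The checks on list_fractions listed above can be sketched in plain Python as follows. This is a minimal illustration, not ethology's code: the helper name check_fractions is hypothetical, and treating 0 and 1 themselves as invalid values is an assumption.

```python
import math


def check_fractions(list_fractions: list[float]) -> None:
    """Hypothetical re-implementation of the list_fractions checks under Raises."""
    if len(list_fractions) != 2:
        raise ValueError("list_fractions must contain exactly two elements.")
    # Assumption: the endpoints 0 and 1 are rejected (strict inequalities)
    if not all(0 < f < 1 for f in list_fractions):
        raise ValueError("Each fraction must be between 0 and 1.")
    # Compare with a tolerance to avoid spurious failures from floating-point rounding
    if not math.isclose(sum(list_fractions), 1.0):
        raise ValueError("list_fractions must sum to 1.")


check_fractions([0.8, 0.2])  # passes silently
```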

Notes

When method is auto, the function automatically selects between two approaches:

  • Group k-fold method (default when enough groups exist): used when the number of unique groups is greater than or equal to the number of required folds, calculated as 1 / min(list_fractions). This method uses the sklearn.model_selection.GroupKFold cross-validator to compute all possible partitions of the groups into folds, and then randomly selects one of them as the output split. The selection is controlled by the seed parameter for reproducibility.

  • Approximate subset-sum method (fallback): used when there are too few unique groups for group k-fold splitting. This method deterministically finds a subset of groups whose combined sample count best matches the requested fractions. The epsilon parameter controls the speed-accuracy tradeoff: when epsilon is 0, the algorithm finds the exact optimal sum, while larger values result in faster computation but may yield subsets whose total number of samples is further from the optimal. If no valid split exists (e.g., every group has more samples than the target), one subset may be empty and a warning is logged.
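The automatic choice between the two approaches can be summarized as a short rule. In this sketch, select_method is a hypothetical helper, and rounding 1 / min(list_fractions) to the nearest integer is an assumption about how the ratio is turned into a fold count.

```python
def select_method(groups: list, list_fractions: list[float]) -> str:
    """Hypothetical sketch of the rule applied when method="auto"."""
    n_groups = len(set(groups))
    # Assumption: the fold count 1 / min(list_fractions) is rounded to the nearest integer
    n_folds = round(1 / min(list_fractions))
    return "kfold" if n_groups >= n_folds else "apss"


# 10 groups vs. 5 folds -> kfold; 3 groups vs. 5 folds -> apss
print(select_method([i % 10 for i in range(100)], [0.8, 0.2]))  # kfold
print(select_method([1, 0, 0, 2, 2, 2, 0, 0, 0, 0, 0, 1], [0.8, 0.2]))  # apss
```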

See also

sklearn.model_selection.GroupKFold

Group k-fold cross-validator.
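For intuition about the “group k-fold” path, the sketch below mimics the classic greedy assignment used by GroupKFold (largest groups first, each placed in the currently lightest fold) and then picks one fold at random as the smaller subset. It is a hypothetical, dependency-free illustration: group_kfold_split is not part of ethology, shuffling the group order with the seed and rounding the fold count are assumptions, and scikit-learn's own implementation may differ in details.

```python
import random
from collections import Counter


def group_kfold_split(groups: list, list_fractions: list[float], seed: int = 42):
    """Hypothetical sketch: return two lists of sample indices with disjoint groups."""
    rng = random.Random(seed)
    # Assumption: 1 / min(list_fractions) rounded to the nearest integer
    n_folds = round(1 / min(list_fractions))
    counts = Counter(groups)
    # Shuffle the group order so that ties between equally sized groups are broken randomly
    unique_groups = list(counts)
    rng.shuffle(unique_groups)
    # Greedy assignment: largest groups first (stable sort keeps the shuffled tie order),
    # each one placed into the fold that currently holds the fewest samples
    fold_sizes = [0] * n_folds
    fold_of_group = {}
    for g in sorted(unique_groups, key=lambda g: counts[g], reverse=True):
        k = fold_sizes.index(min(fold_sizes))
        fold_of_group[g] = k
        fold_sizes[k] += counts[g]
    # Randomly pick one fold as the smaller subset; the remaining folds form the larger one
    held_out = rng.randrange(n_folds)
    small = [i for i, g in enumerate(groups) if fold_of_group[g] == held_out]
    large = [i for i, g in enumerate(groups) if fold_of_group[g] != held_out]
    # Return the subsets in the same order as list_fractions
    return (large, small) if list_fractions[0] >= list_fractions[1] else (small, large)


groups = [i % 10 for i in range(100)]  # 10 groups of 10 samples each
subset_1, subset_2 = group_kfold_split(groups, [0.8, 0.2], seed=42)
print(len(subset_1), len(subset_2))  # 80 20
```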

Examples

Split a dataset of 100 images extracted from 10 different videos. The xarray dataset has a single data variable, video_id, defined along the image_id dimension. We would like to compute an 80/20 split, ensuring the subsets are disjoint in the grouping variable video_id. Since there are enough unique groups (10 video IDs for the 5 folds implied by a minimum fraction of 0.2), the function automatically selects the “group k-fold” method.

>>> import numpy as np
>>> import xarray as xr
>>> from ethology.datasets.split import split_dataset_group_by
>>> ds_large = xr.Dataset(
...     data_vars=dict(
...         video_id=("image_id", np.tile(np.arange(10), 10)),
...     ),  # 10 different video IDs across 100 images
...     coords=dict(image_id=range(100)),
... )
>>> ds_subset_1, ds_subset_2 = split_dataset_group_by(
...     ds_large, "video_id", [0.8, 0.2], seed=42
... )
>>> print(len(ds_subset_1.image_id) / len(ds_large.image_id))  # 0.8
>>> print(len(ds_subset_2.image_id) / len(ds_large.image_id))  # 0.2

Using different seeds produces different splits when the “group k-fold” method is used:

>>> ds_subset_1b, ds_subset_2b = split_dataset_group_by(
...     ds_large, "video_id", [0.8, 0.2], seed=123
... )
>>> assert not ds_subset_1.equals(ds_subset_1b)
>>> assert not ds_subset_2.equals(ds_subset_2b)
>>> print(len(ds_subset_1b.image_id) / len(ds_large.image_id))  # 0.8
>>> print(len(ds_subset_2b.image_id) / len(ds_large.image_id))  # 0.2

The function automatically selects the appropriate method. The example below uses a smaller dataset with 3 unique video IDs. With a minimum fraction of 0.2 (requiring 5 folds), the “group k-fold” method cannot be used, since there would be more folds than groups. Therefore, the “approximate subset-sum” method is selected automatically and an approximate split is returned. Note that the seed value is ignored when the apss method is used.

>>> ds_small = xr.Dataset(
...     data_vars=dict(
...         video_id=("image_id", [1, 0, 0, 2, 2, 2, 0, 0, 0, 0, 0, 1]),
...     ),
...     coords=dict(image_id=range(12)),
... )
>>> ds_subset_1, ds_subset_2 = split_dataset_group_by(
...     ds_small,
...     group_by_var="video_id",
...     list_fractions=[0.8, 0.2],
... )
>>> print(len(ds_subset_1.image_id) / len(ds_small.image_id))  # 0.8333...
>>> print(len(ds_subset_2.image_id) / len(ds_small.image_id))  # 0.1666...
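To see where the 10/12 and 2/12 split comes from, count the samples per video_id and enumerate every subset of groups, keeping the largest total that does not exceed 20% of the 12 images. This brute-force search is only an illustration for tiny group counts, and it assumes the optimum is the largest group total not exceeding the target; it is not how ethology computes the split.

```python
from collections import Counter
from itertools import combinations

video_id = [1, 0, 0, 2, 2, 2, 0, 0, 0, 0, 0, 1]
counts = Counter(video_id)  # samples per group: 0 -> 7, 1 -> 2, 2 -> 3
target = 0.2 * len(video_id)  # about 2.4 samples for the smaller subset

# Enumerate every non-empty subset of groups and keep the largest total <= target
best_groups, best_sum = (), 0
for r in range(1, len(counts) + 1):
    for combo in combinations(counts, r):
        total = sum(counts[g] for g in combo)
        if best_sum < total <= target:
            best_groups, best_sum = combo, total

print(best_groups, best_sum)  # (1,) 2
# Fractions of the two subsets: approximately 0.833 and 0.167
print((len(video_id) - best_sum) / len(video_id), best_sum / len(video_id))
```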

The epsilon parameter controls the approximation tolerance of the “approximate subset-sum” method, whether it is selected automatically or explicitly with method="apss":

>>> ds_subset_1, ds_subset_2 = split_dataset_group_by(
...     ds_small,
...     group_by_var="video_id",
...     list_fractions=[0.8, 0.2],
...     epsilon=0.1,  # accept a solution >= 90% of the optimal
...     method="apss",
... )
>>> print(len(ds_subset_1.image_id) / len(ds_small.image_id))  # 0.8333...
>>> print(len(ds_subset_2.image_id) / len(ds_small.image_id))  # 0.1666...
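A textbook way to obtain a (1 - epsilon) guarantee like the one described for epsilon is the trimmed-list approximation scheme for subset sum (APPROX-SUBSET-SUM in CLRS). The sketch below applies it to per-group sample counts and tracks which groups form each partial sum. The function approx_subset_sum is hypothetical and ethology's implementation may differ; with epsilon=0, trimming only drops duplicate sums, so the result is exact.

```python
def approx_subset_sum(sizes: dict, target: float, epsilon: float = 0.0):
    """Hypothetical trimmed subset-sum over group sizes.

    Returns (best_sum, groups) where best_sum <= target and
    best_sum >= (1 - epsilon) * (largest achievable sum <= target).
    """
    n = len(sizes)
    delta = epsilon / (2 * n) if n else 0.0
    # Each entry is (partial_sum, tuple_of_groups); start from the empty subset
    entries = [(0, ())]
    for group, size in sizes.items():
        extended = [(s + size, gs + (group,)) for s, gs in entries]
        merged = sorted(entries + extended, key=lambda e: e[0])
        # Trim: keep a sum only if it exceeds the last kept sum by more than a factor (1 + delta)
        trimmed = [merged[0]]
        for s, gs in merged[1:]:
            if s > trimmed[-1][0] * (1 + delta):
                trimmed.append((s, gs))
        # Discard partial sums above the target
        entries = [e for e in trimmed if e[0] <= target]
    return max(entries, key=lambda e: e[0])


sizes = {1: 2, 0: 7, 2: 3}  # samples per video_id in the small example above
print(approx_subset_sum(sizes, target=0.2 * 12, epsilon=0.1))  # (2, (1,))
```

Trimming keeps the candidate list short, which is where the speed-up for larger epsilon comes from; each dropped sum is within a factor (1 + epsilon / 2n) of a kept one, and these errors compound to at most (1 + epsilon) over n groups.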

Examples using split_dataset_group_by

Split an annotations dataset
