split_dataset_group_by
- ethology.datasets.split.split_dataset_group_by(dataset, group_by_var, list_fractions, samples_coordinate='image_id', method='auto', seed=42, epsilon=0)
Split an annotations dataset by grouping variable.
Split an ethology annotations dataset into two subsets, ensuring that the subsets are disjoint in the grouping variable (i.e., no group appears in both subsets). The function automatically chooses between a “group k-fold” approach and an “approximate subset-sum” approach based on the number of unique groups and the requested split fractions.
- Parameters:
dataset (xarray.Dataset) – The annotations dataset to split.
group_by_var (str) – The grouping variable to use for splitting the dataset. Must be 1-dimensional along the samples_coordinate.
list_fractions (list[float, float]) – The fractions of the input annotations dataset to allocate to each subset. Must contain exactly two elements that sum to 1.
samples_coordinate (str, optional) – The coordinate along which to split the dataset. Default is image_id.
method (str, optional) – Method to use: auto, kfold, or apss. When auto, the function selects between kfold and apss based on the number of unique groups. See Notes for further details. Default is auto.
seed (int, optional) – Random seed for reproducibility used in the “group k-fold” approach. Controls both the shuffling of the sample indices and the random selection of the output split from all possible ones. Only used when method is kfold or when auto selects kfold. Default is 42.
epsilon (float, optional) – The approximation tolerance for the “approximate subset-sum” approach, expressed as a fraction of 1. The sum of samples in the smallest subset is guaranteed to be at least (1 - epsilon) times the optimal sum for the requested fractions and grouping variable. When epsilon is 0, the algorithm finds the exact optimal sum. Larger values result in faster computation but may yield subsets with a total number of samples further from the optimal. Only used when method is apss or when auto selects apss. Default is 0.
- Returns:
The two subsets of the input dataset. The subsets are returned in the same order as the input list of fractions list_fractions.
- Return type:
- Raises:
ValueError – If list_fractions does not contain exactly two elements, if its elements are not between 0 and 1, or if they do not sum to 1. If group_by_var is not 1-dimensional along the samples_coordinate. If method is kfold but there are insufficient groups for the requested split fractions.
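The checks on list_fractions listed above can be sketched as follows. This is only an illustration: the helper name validate_fractions, the error messages, the inclusive bounds, and the use of math.isclose for the sum are all assumptions, not the library's actual implementation.

```python
import math


def validate_fractions(list_fractions):
    """Raise ValueError if the fractions are not a valid two-way split.

    Hypothetical helper: bounds are treated as inclusive and the sum is
    compared with a floating-point tolerance (both are assumptions).
    """
    if len(list_fractions) != 2:
        raise ValueError("list_fractions must contain exactly two elements.")
    if any(not 0 <= f <= 1 for f in list_fractions):
        raise ValueError("Each fraction must be between 0 and 1.")
    if not math.isclose(sum(list_fractions), 1.0):
        raise ValueError("The fractions must sum to 1.")


validate_fractions([0.8, 0.2])  # passes silently
try:
    validate_fractions([0.5, 0.6])
except ValueError as err:
    print(err)  # The fractions must sum to 1.
```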
Notes
When method is auto, the function automatically selects between two approaches:
- Group k-fold method (default when sufficient groups exist): used when the number of unique groups is greater than or equal to the number of required folds (calculated as 1 / min(list_fractions)). This method computes all possible partitions of groups into folds and randomly selects one of them as the output split. The selection is controlled by the seed parameter for reproducibility. We use the sklearn.model_selection.GroupKFold cross-validator to compute all possible partitions of groups into folds.
- Approximate subset-sum method (fallback): used when there are too few unique groups for group k-fold splitting. This method deterministically finds a subset of groups whose combined sample count best matches the requested fractions. The epsilon parameter controls the speed-accuracy tradeoff. When epsilon is 0, the algorithm finds the exact optimal sum. Larger values of epsilon result in faster computation but may yield subsets with a total number of samples further from the optimal. In cases where no valid split exists (e.g., all groups have more samples than the target), one subset may be empty and a warning is logged.
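The selection rule described in these notes can be sketched in a few lines. The helper name choose_method is hypothetical. The documentation does not say how a non-integer value of 1 / min(list_fractions) is rounded, so this sketch compares against the unrounded value, which is an assumption.

```python
def choose_method(group_ids, list_fractions):
    """Pick "kfold" or "apss" following the rule described in the Notes.

    Hypothetical helper, not part of the ethology API.
    """
    n_groups = len(set(group_ids))
    # Folds needed so that one fold matches the smallest fraction.
    # Assumption: no rounding is applied before the comparison.
    required_folds = 1 / min(list_fractions)
    return "kfold" if n_groups >= required_folds else "apss"


print(choose_method(range(10), [0.8, 0.2]))  # kfold
print(choose_method([1, 0, 0, 2], [0.8, 0.2]))  # apss
```

With 10 unique groups and a minimum fraction of 0.2 (5 folds), group k-fold is feasible. With only 3 unique groups it is not, so the subset-sum fallback is chosen.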
See also
sklearn.model_selection.GroupKFold – Group k-fold cross-validator.
Examples
Split a dataset of 100 images extracted from 10 different videos. The xarray dataset has a single data variable video_id defined along the image_id dimension. We would like to compute an 80/20 split, ensuring the subsets of the dataset are disjoint in the grouping variable video_id. Since there are many unique groups (i.e., unique video IDs), the function automatically selects the “group k-fold” method.

>>> import numpy as np
>>> import xarray as xr
>>> from ethology.datasets.split import split_dataset_group_by
>>> ds_large = xr.Dataset(
...     data_vars=dict(
...         video_id=("image_id", np.tile(np.arange(10), 10)),
...     ),  # 10 different video IDs across 100 images
...     coords=dict(image_id=range(100)),
... )
>>> ds_subset_1, ds_subset_2 = split_dataset_group_by(
...     ds_large, "video_id", [0.8, 0.2], seed=42
... )
>>> print(len(ds_subset_1.image_id) / len(ds_large.image_id))  # 0.8
>>> print(len(ds_subset_2.image_id) / len(ds_large.image_id))  # 0.2
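To confirm that a split is disjoint in the grouping variable, the group labels of the two subsets can be compared directly. A minimal sketch follows. The helper name groups_are_disjoint is hypothetical and not part of ethology, and plain lists stand in for the video_id values of each subset.

```python
def groups_are_disjoint(groups_a, groups_b):
    """Return True if no group label appears in both sequences."""
    return set(groups_a).isdisjoint(groups_b)


# Plain lists stand in for the video_id values of each subset.
subset_1_groups = [0, 0, 1, 1, 2, 3]
subset_2_groups = [4, 4, 5]
print(groups_are_disjoint(subset_1_groups, subset_2_groups))  # True
print(groups_are_disjoint([0, 1], [1, 2]))  # False
```

For the datasets in the example above, one would pass ds_subset_1.video_id.values and ds_subset_2.video_id.values.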
Using different seeds produces different splits when the “group k-fold” method is used:
>>> ds_subset_1b, ds_subset_2b = split_dataset_group_by(
...     ds_large, "video_id", [0.8, 0.2], seed=123
... )
>>> assert not ds_subset_1.equals(ds_subset_1b)
>>> assert not ds_subset_2.equals(ds_subset_2b)
>>> print(len(ds_subset_1b.image_id) / len(ds_large.image_id))  # 0.8
>>> print(len(ds_subset_2b.image_id) / len(ds_large.image_id))  # 0.2
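The role of the seed can be illustrated with a simplified, pure-Python stand-in for the “group k-fold” idea: shuffle the unique groups with a seeded generator, deal them into folds, and pick one fold at random as the smaller subset. This is not the library's code and does not reproduce the fold-assignment logic of sklearn.model_selection.GroupKFold. The helper name pick_group_fold and the round-robin assignment are assumptions made for illustration.

```python
import random


def pick_group_fold(group_ids, n_folds, seed):
    """Return the set of groups forming one randomly chosen fold.

    Hypothetical helper: round-robin fold assignment is a simplification.
    """
    rng = random.Random(seed)
    unique_groups = sorted(set(group_ids))
    rng.shuffle(unique_groups)
    # Deal the shuffled groups round-robin into n_folds folds.
    folds = [unique_groups[i::n_folds] for i in range(n_folds)]
    return set(rng.choice(folds))


# 100 images from 10 videos, 10 images per video (as in ds_large).
video_ids = [i % 10 for i in range(100)]
held_out = pick_group_fold(video_ids, n_folds=5, seed=42)
small = [v for v in video_ids if v in held_out]
large = [v for v in video_ids if v not in held_out]
print(len(large) / len(video_ids), len(small) / len(video_ids))  # 0.8 0.2
```

Calling pick_group_fold again with the same seed returns the same fold, while a different seed may select a different pair of videos. The split sizes stay at 80/20 either way.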
The function automatically selects the appropriate method. In the example below, a smaller dataset with 3 unique video IDs is used. With a 0.2 minimum fraction (requiring 5 folds), the “group k-fold” method cannot be used, since there would be more folds than groups. Therefore, the “approximate subset-sum” method is selected automatically and an approximate split is returned. Note that when using the apss method, the seed value is ignored.

>>> ds_small = xr.Dataset(
...     data_vars=dict(
...         video_id=("image_id", [1, 0, 0, 2, 2, 2, 0, 0, 0, 0, 0, 1]),
...     ),
...     coords=dict(image_id=range(12)),
... )
>>> ds_subset_1, ds_subset_2 = split_dataset_group_by(
...     ds_small,
...     group_by_var="video_id",
...     list_fractions=[0.8, 0.2],
... )
>>> print(len(ds_subset_1.image_id) / len(ds_small.image_id))  # 0.833
>>> print(len(ds_subset_2.image_id) / len(ds_small.image_id))  # 0.166
The epsilon parameter controls the approximation for the subset-sum method when auto-selected or explicitly specified:

>>> ds_subset_1, ds_subset_2 = split_dataset_group_by(
...     ds_small,
...     group_by_var="video_id",
...     list_fractions=[0.8, 0.2],
...     epsilon=0.1,  # accept a solution >= 90% of the optimal
...     method="apss",
... )
>>> print(len(ds_subset_1.image_id) / len(ds_small.image_id))  # 0.833
>>> print(len(ds_subset_2.image_id) / len(ds_small.image_id))  # 0.166
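To see why both calls return a 10/2 split of the 12 images, the subset-sum step can be sketched with the textbook trimming scheme for approximate subset-sum. This is a sketch under stated assumptions rather than the library's implementation: the function name approx_subset_sum, the trimming factor epsilon / (2 * n), and the choice to target the smaller fraction are all illustrative.

```python
from collections import Counter


def approx_subset_sum(group_sizes, target, epsilon=0.0):
    """Return (best_sum, groups) with best_sum <= target.

    Textbook trimming scheme (hypothetical helper); exact when epsilon is 0.
    """
    n = len(group_sizes)
    delta = epsilon / (2 * n) if n else 0.0
    # Each candidate is (sum_of_sizes, tuple_of_group_ids); start empty.
    candidates = [(0, ())]
    for group, size in group_sizes.items():
        extended = [(s + size, ids + (group,)) for s, ids in candidates]
        merged = sorted(candidates + extended, key=lambda item: item[0])
        trimmed = [merged[0]]
        for total, ids in merged[1:]:
            # Keep sums that fit the target and differ enough from the
            # last kept sum; near-duplicates are dropped.
            if total > trimmed[-1][0] * (1 + delta) and total <= target:
                trimmed.append((total, ids))
        candidates = trimmed
    return max(candidates, key=lambda item: item[0])


video_ids = [1, 0, 0, 2, 2, 2, 0, 0, 0, 0, 0, 1]  # same values as ds_small
sizes = Counter(video_ids)  # video 0: 7 images, video 1: 2, video 2: 3
target = 0.2 * len(video_ids)  # 2.4 images for the smaller subset
best_sum, groups = approx_subset_sum(sizes, target, epsilon=0.1)
print(best_sum, groups)  # 2 (1,)
```

Only video 1 (2 images) fits under the target of 2.4 images, so the smaller subset holds 2 of 12 samples (about 0.166) and the larger one 10 of 12 (about 0.833), for both epsilon=0 and epsilon=0.1.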