"""Utilities for splitting annotations datasets."""
from collections import Counter
from typing import TypedDict
import numpy as np
import xarray as xr
from loguru import logger
from sklearn.model_selection import GroupKFold
def split_dataset_group_by(
    dataset: xr.Dataset,
    group_by_var: str,
    list_fractions: list[float],
    samples_coordinate: str = "image_id",
    method: str = "auto",
    seed: int = 42,
    epsilon: float = 0,
) -> tuple[xr.Dataset, xr.Dataset]:
    """Split an annotations dataset by grouping variable.

    Split an ``ethology`` annotations dataset into two subsets ensuring that
    the subsets are disjoint in the grouping variable (i.e., no group appears
    in both subsets). The function automatically chooses between a
    "group k-fold" approach and an "approximate subset-sum" approach based on
    the number of unique groups and requested split fractions.

    Parameters
    ----------
    dataset : xarray.Dataset
        The annotations dataset to split.
    group_by_var : str
        The grouping variable to use for splitting the dataset. Must be
        1-dimensional along the ``samples_coordinate``.
    list_fractions : list[float, float]
        The fractions of the input annotations dataset to allocate to
        each subset. Must contain only two elements and sum to 1.
    samples_coordinate : str, optional
        The coordinate along which to split the dataset. Default is
        ``image_id``.
    method : str, optional
        Method to use: ``auto``, ``kfold``, or ``apss``. When ``auto``,
        it automatically selects between ``kfold`` or ``apss`` based on
        the number of unique groups. See Notes for further details.
        Default is ``auto``.
    seed : int, optional
        Random seed for reproducibility used in the "group k-fold" approach.
        Controls both the shuffling of the sample indices and the random
        selection of the output split from all possible ones. Only used when
        ``method`` is ``kfold`` or when ``auto`` selects ``kfold``.
        Default is 42.
    epsilon : float, optional
        The approximation tolerance for the "approximate subset-sum" approach,
        expressed as a fraction of 1. The sum of samples in the smallest subset
        is guaranteed to be at least ``(1 - epsilon)`` times the optimal sum
        for the requested fractions and grouping variable. When ``epsilon``
        is 0, the algorithm finds the exact optimal sum. Larger values result
        in faster computation but may yield subsets with a total number of
        samples further from the optimal. Only used when ``method`` is
        ``apss`` or when ``auto`` selects ``apss``. Default is 0.

    Returns
    -------
    tuple[xarray.Dataset, xarray.Dataset]
        The two subsets of the input dataset. The subsets are returned in the
        same order as the input list of fractions ``list_fractions``.

    Raises
    ------
    ValueError
        If the elements of ``list_fractions`` are not exactly two, are not
        between 0 and 1, or do not sum to 1. If ``group_by_var`` is not
        1-dimensional along the ``samples_coordinate``. If ``method`` is
        not one of ``auto``, ``kfold``, or ``apss``.

    Notes
    -----
    When ``method`` is ``auto``, the function automatically selects between
    two approaches:

    - **Group k-fold method** (default when sufficient groups exist): used
      when the number of unique groups is greater than or equal to the number
      of required folds (calculated as ``1 / min(list_fractions)``). This
      method computes all possible partitions of groups into folds and
      randomly selects one of them as the output split. The selection is
      controlled by the ``seed`` parameter for reproducibility. We use
      :class:`sklearn.model_selection.GroupKFold` cross-validator to compute
      all possible partitions of groups into folds.
    - **Approximate subset-sum method** (fallback): used when there are too
      few unique groups for group k-fold splitting. This method
      deterministically finds a subset of groups whose combined sample count
      best matches the requested fractions. The ``epsilon`` parameter
      controls the speed-accuracy tradeoff. When ``epsilon`` is 0, the
      algorithm finds the exact optimal sum. Larger values of ``epsilon``
      result in faster computation but may yield subsets with a total number
      of samples further from the optimal. In cases where no valid split
      exists (e.g., all groups have more samples than the target), one subset
      may be empty and a warning is logged.

    See Also
    --------
    :class:`sklearn.model_selection.GroupKFold` : Group k-fold
        cross-validator.

    Examples
    --------
    Split a dataset of 100 images extracted from 10 different videos.
    The xarray dataset has a single data variable ``video_id`` defined along
    the ``image_id`` dimension. We would like to compute an 80/20 split,
    ensuring the subsets of the dataset are disjoint in the grouping
    variable ``video_id``. Since there are many unique groups
    (i.e., unique video IDs), the function automatically selects
    the "group k-fold" method.

    >>> import xarray as xr
    >>> from ethology.datasets.split import split_dataset_group_by
    >>> ds_large = xr.Dataset(
    ...     data_vars=dict(
    ...         video_id=("image_id", np.tile(np.arange(10), 10)),
    ...     ),  # 10 different video IDs across 100 images
    ...     coords=dict(image_id=range(100)),
    ... )
    >>> ds_subset_1, ds_subset_2 = split_dataset_group_by(
    ...     ds_large, "video_id", [0.8, 0.2], seed=42
    ... )
    >>> print(len(ds_subset_1.image_id) / len(ds_large.image_id))  # 0.8
    >>> print(len(ds_subset_2.image_id) / len(ds_large.image_id))  # 0.2

    Using different seeds produces different splits when the "group k-fold"
    method is used:

    >>> ds_subset_1b, ds_subset_2b = split_dataset_group_by(
    ...     ds_large, "video_id", [0.8, 0.2], seed=123
    ... )
    >>> assert not ds_subset_1.equals(ds_subset_1b)
    >>> assert not ds_subset_2.equals(ds_subset_2b)
    >>> print(len(ds_subset_1b.image_id) / len(ds_large.image_id))  # 0.8
    >>> print(len(ds_subset_2b.image_id) / len(ds_large.image_id))  # 0.2

    The function automatically selects the appropriate method. In the example
    below, a smaller dataset with 3 unique video IDs is used. With a 0.2
    minimum fraction (requiring 5 folds), the "group k-fold" method cannot be
    used, since there would be more folds than groups. Therefore, the
    "approximate subset-sum" method is selected automatically and an
    approximate split is returned. Note that when using the ``apss`` method,
    the seed value is ignored.

    >>> ds_small = xr.Dataset(
    ...     data_vars=dict(
    ...         video_id=("image_id", [1, 0, 0, 2, 2, 2, 0, 0, 0, 0, 0, 1]),
    ...     ),
    ...     coords=dict(image_id=range(12)),
    ... )
    >>> ds_subset_1, ds_subset_2 = split_dataset_group_by(
    ...     ds_small,
    ...     group_by_var="video_id",
    ...     list_fractions=[0.8, 0.2],
    ... )
    >>> print(len(ds_subset_1.image_id) / len(ds_small.image_id))  # 0.833
    >>> print(len(ds_subset_2.image_id) / len(ds_small.image_id))  # 0.166

    The ``epsilon`` parameter controls the approximation for the subset-sum
    method when auto-selected or explicitly specified:

    >>> ds_subset_1, ds_subset_2 = split_dataset_group_by(
    ...     ds_small,
    ...     group_by_var="video_id",
    ...     list_fractions=[0.8, 0.2],
    ...     epsilon=0.1,  # accept a solution >= 90% of the optimal
    ...     method="apss",
    ... )
    >>> print(len(ds_subset_1.image_id) / len(ds_small.image_id))  # 0.833
    >>> print(len(ds_subset_2.image_id) / len(ds_small.image_id))  # 0.166

    """
    # Checks
    # Validate the number of fractions first: the remaining checks (and the
    # min() below) assume a well-formed, two-element list.
    if len(list_fractions) != 2:
        raise ValueError("The list of fractions must have only two elements.")
    if any(fraction < 0 or fraction > 1 for fraction in list_fractions):
        raise ValueError("The split fractions must be between 0 and 1.")
    # Compare the sum with a tolerance: exact float equality would reject
    # valid inputs such as [0.7, 0.3], whose floating-point sum is
    # 0.9999999999999999 rather than exactly 1.
    if not np.isclose(sum(list_fractions), 1.0):
        raise ValueError("The split fractions must sum to 1.")
    if len(dataset[group_by_var].shape) != 1:
        raise ValueError(
            f"The grouping variable {group_by_var} must be 1-dimensional along"
            f" {samples_coordinate}."
        )

    # Count unique groups and the folds required by the k-fold approach
    n_unique_groups = len(np.unique(dataset[group_by_var].values))
    n_required_folds = int(np.rint(1 / min(list_fractions)))

    # Auto-select method: k-fold needs at least as many groups as folds
    if method == "auto":
        if n_unique_groups >= n_required_folds:
            method = "kfold"
        else:
            method = "apss"
            logger.info(
                f"Only {n_unique_groups} unique groups exist but "
                f"{n_required_folds} are required for k-fold method. "
                "Auto-selected approximate subset-sum method "
                f"with epsilon={epsilon}. Seed setting is ignored."
            )

    # Dispatch to the appropriate method
    if method == "kfold":
        logger.info(
            f"Using group k-fold method with {n_required_folds} folds "
            f"and seed={seed}."
        )
        return _split_dataset_group_by_kfold(
            dataset, group_by_var, list_fractions, samples_coordinate, seed
        )
    elif method == "apss":
        logger.info(
            f"Using approximate subset-sum method with epsilon={epsilon}."
        )
        return _split_dataset_group_by_apss(
            dataset, group_by_var, list_fractions, epsilon, samples_coordinate
        )
    else:
        raise ValueError(f"Unknown method: {method}")
def _split_dataset_group_by_kfold(
    dataset: xr.Dataset,
    group_by_var: str,
    list_fractions: list[float],
    samples_coordinate: str = "image_id",
    seed: int = 42,
) -> tuple[xr.Dataset, xr.Dataset]:
    """Split an annotations dataset via scikit-learn's GroupKFold.

    The dataset is split into two subsets that are disjoint in the grouping
    variable ``group_by_var``. All candidate train/test partitions produced
    by :class:`sklearn.model_selection.GroupKFold` are enumerated, and one
    is picked at random: its test fold becomes the smaller subset and the
    remaining folds together form the larger subset.

    Parameters
    ----------
    dataset : xarray.Dataset
        The annotations dataset to split.
    group_by_var : str
        The grouping variable to use for splitting the dataset. Must be
        1-dimensional along the ``samples_coordinate``.
    list_fractions : list[float, float]
        The fractions of the input annotations dataset to allocate to each
        subset. Must contain only two elements and sum to 1.
    samples_coordinate : str, optional
        The coordinate along which to split the dataset. Default is
        ``image_id``.
    seed : int, optional
        Random seed for reproducibility. Controls both the GroupKFold
        shuffling of indices and the random choice of the output split among
        all candidates. The same seed on the same dataset always yields the
        same split; different seeds may yield different splits. Default
        is 42.

    Returns
    -------
    tuple[xarray.Dataset, xarray.Dataset]
        The two subsets of the input dataset, returned in the same order as
        the input list of fractions.

    Examples
    --------
    >>> from ethology.datasets.split import _split_dataset_group_by_kfold
    >>> import xarray as xr
    >>> ds = xr.Dataset(
    ...     data_vars=dict(
    ...         foo=("image_id", [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]),
    ...     ),  # 0: 10 counts, 1: 2 counts
    ...     coords=dict(image_id=range(12)),
    ... )
    >>> ds_subset_1, ds_subset_2 = _split_dataset_group_by_kfold(
    ...     ds, group_by_var="foo", list_fractions=[0.2, 0.8], seed=42
    ... )

    """
    # One fold per smallest-fraction share of the data; e.g. a minimum
    # fraction of 0.2 requires 5 folds.
    n_folds = int(np.rint(1 / min(list_fractions)))
    splitter = GroupKFold(n_splits=n_folds, shuffle=True, random_state=seed)

    # Enumerate every candidate (train, test) partition; in each one, a
    # single fold plays the role of the test set.
    candidate_splits = list(
        splitter.split(
            dataset[samples_coordinate].values,
            groups=dataset[group_by_var].values,
        )
    )

    # Draw one candidate partition at random (reproducible via the seed)
    rng = np.random.default_rng(seed)
    train_idcs, test_idcs = candidate_splits[rng.choice(len(candidate_splits))]

    # Materialise the two subsets, ordered from smallest to largest
    # (the test fold is the smaller one).
    subsets_by_size = [
        dataset.isel({samples_coordinate: test_idcs}),
        dataset.isel({samples_coordinate: train_idcs}),
    ]

    # Reorder to match the caller's fraction order: argsort applied twice
    # yields the inverse permutation of the size-sorting.
    inverse_perm = np.argsort(np.argsort(list_fractions))
    return tuple(subsets_by_size[i] for i in inverse_perm)
def _split_dataset_group_by_apss(
    dataset: xr.Dataset,
    group_by_var: str,
    list_fractions: list[float],
    epsilon: float = 0,
    samples_coordinate: str = "image_id",
) -> tuple[xr.Dataset, xr.Dataset]:
    """Split an annotations dataset using an approximate subset-sum approach.

    Split an ``ethology`` annotations dataset into two subsets
    ensuring that the subsets are disjoint in the grouping variable.

    Parameters
    ----------
    dataset : xarray.Dataset
        The annotations dataset to split.
    group_by_var : str
        The grouping variable to use for splitting the dataset. Must be
        1-dimensional along the ``samples_coordinate``.
    list_fractions : list[float, float]
        The fractions of the input annotations dataset to allocate to
        each subset. Must contain only two elements and sum to 1.
    epsilon : float, optional
        The approximation tolerance for the subset sum as a fraction of 1.
        The sum of samples in the smallest subset is guaranteed to be at
        least ``(1 - epsilon)`` times the optimal sum for the requested
        fractions and grouping variable. When ``epsilon`` is 0, the algorithm
        finds the exact optimal sum. Larger values result in faster
        computation but may yield subsets with a total number of samples
        further from optimal. Default is 0.
    samples_coordinate : str, optional
        The coordinate along which to split the dataset. Default is
        ``image_id``.

    Returns
    -------
    tuple[xarray.Dataset, xarray.Dataset]
        The two subsets of the input dataset. The subsets are returned in the
        same order as the input list of fractions.

    Notes
    -----
    Input validation (fractions, dimensionality) is performed by the public
    wrapper :func:`split_dataset_group_by`; this helper assumes valid inputs.
    If no non-empty subset can match the target (e.g. every group has more
    samples than the target), one returned subset may be empty and a warning
    is logged.

    Examples
    --------
    Split a dataset with a single data variable ``foo`` defined along the
    ``image_id`` dimension into an approximate 80/20 split, ensuring that the
    subsets are disjoint in the grouping variable ``foo``.

    >>> from ethology.datasets.split import _split_dataset_group_by_apss
    >>> import xarray as xr
    >>> ds = xr.Dataset(
    ...     data_vars=dict(
    ...         foo=("image_id", [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]),
    ...     ),  # 0: 10 counts, 1: 2 counts
    ...     coords=dict(image_id=range(12)),
    ... )
    >>> ds_subset_1, ds_subset_2 = _split_dataset_group_by_apss(
    ...     ds, group_by_var="foo", list_fractions=[0.2, 0.8], epsilon=0
    ... )
    >>> print(len(ds_subset_1.image_id) / len(ds.image_id))  # 0.166
    >>> print(len(ds_subset_2.image_id) / len(ds.image_id))  # 0.833

    """
    # Compute the number of samples in the target subset:
    # the target subset is the subset with the smallest fraction.
    target_subset_count = int(
        min(list_fractions) * len(dataset.get(samples_coordinate))
    )

    # Count the number of samples per group, sorted by count in
    # ascending order.
    count_per_group_id = Counter(dataset[group_by_var].values).most_common()[
        ::-1
    ]

    # Relabel every group with its enumeration index. Using the index
    # unconditionally (instead of casting IDs to int when possible, with the
    # index only as a fallback) guarantees unique integer keys: a castable
    # ID such as 5 could otherwise collide with the fallback index of a
    # non-castable group, silently overwriting the mapping.
    map_group_idx_to_original = {
        idx: group_id
        for idx, (group_id, _) in enumerate(count_per_group_id)
    }
    count_per_group_idx = [
        (idx, count) for idx, (_, count) in enumerate(count_per_group_id)
    ]

    # Select the (relabelled) groups that make up the target subset
    subset_dict = _approximate_subset_sum(
        count_per_group_idx,
        target_subset_count,
        epsilon=epsilon,
    )

    # Map the selected indices back to the original group IDs
    # (they are not necessarily integers)
    subset_group_ids = [
        map_group_idx_to_original[x] for x in subset_dict["ids"]
    ]

    # Extract datasets for the target subset and its complement
    is_in_subset = dataset[group_by_var].isin(subset_group_ids)
    ds_subset = dataset.isel({samples_coordinate: is_in_subset})
    ds_not_subset = dataset.isel({samples_coordinate: ~is_in_subset})

    # Throw warning if a subset is empty. Check along samples_coordinate
    # rather than a hard-coded "image_id", so the check also works for
    # datasets split along a different coordinate.
    if any(
        len(ds[samples_coordinate]) == 0 for ds in (ds_subset, ds_not_subset)
    ):
        logger.warning("At least one of the subset datasets is empty.")

    # Return datasets in the same order as the input list of fractions
    # (argsort twice gives the inverse permutation)
    idcs_sorted = np.argsort(list_fractions)  # idcs to map input -> sorted
    idcs_original = np.argsort(idcs_sorted)  # idcs to map sorted -> input
    list_ds_sorted = [ds_subset, ds_not_subset]
    return tuple(list_ds_sorted[i] for i in idcs_original)
def split_dataset_random(
    dataset: xr.Dataset,
    list_fractions: list[float],
    seed: int = 42,
    samples_coordinate: str = "image_id",
) -> tuple[xr.Dataset, ...]:
    """Split an annotations dataset using random sampling.

    Split an ``ethology`` annotations dataset into multiple subsets by
    randomly shuffling all samples and then partitioning them sequentially
    according to the specified fractions.

    Parameters
    ----------
    dataset : xarray.Dataset
        The annotations dataset to split.
    list_fractions : list[float, ...]
        The fractions of the input annotations dataset to allocate to
        each subset. The list must contain at least two elements, all
        elements must be between 0 and 1, and add up to 1.
    seed : int, optional
        Seed to use for the random number generator. Default is 42.
    samples_coordinate : str, optional
        The coordinate along which to split the dataset. Default is
        ``image_id``.

    Returns
    -------
    tuple[xarray.Dataset, ...]
        The subsets of the input dataset. The subsets are returned in the
        same order as the input list of fractions.

    Raises
    ------
    ValueError
        If the elements of ``list_fractions`` are less than two, are not
        between 0 and 1, or do not sum to 1.

    Examples
    --------
    Split a dataset with a single data variable ``foo``, with 100 values
    defined along the ``image_id`` dimension into 70/20/10 splits.

    >>> from ethology.datasets.split import split_dataset_random
    >>> import numpy as np
    >>> import xarray as xr
    >>> ds = xr.Dataset(
    ...     data_vars=dict(
    ...         foo=("image_id", np.random.randint(0, 100, size=100)),
    ...     ),
    ...     coords=dict(
    ...         image_id=range(100),
    ...     ),
    ... )
    >>> ds_train, ds_val, ds_test = split_dataset_random(
    ...     ds,
    ...     list_fractions=[0.7, 0.2, 0.1],
    ...     seed=42,
    ... )
    >>> print(len(ds_train.image_id))  # 70
    >>> print(len(ds_val.image_id))  # 20
    >>> print(len(ds_test.image_id))  # 10

    Notes
    -----
    The function operates in two steps: first, it shuffles all sample indices
    along the ``samples_coordinate`` dimension using the provided random
    seed; then, it partitions the shuffled indices into contiguous blocks,
    one for each subset.

    The size of each block is determined by rounding down (floor) the product
    of the subset's fraction and the total number of samples. To ensure all
    samples are included, the last subset receives any remaining samples
    after the earlier subsets have been allocated. Due to this rounding
    behavior, the actual fraction for the last subset may differ slightly
    from the requested fraction.

    """
    # Checks
    if len(list_fractions) < 2:
        raise ValueError(
            "The list of fractions must have at least two elements."
        )
    if any(fraction < 0 or fraction > 1 for fraction in list_fractions):
        raise ValueError("The split fractions must be between 0 and 1.")
    # Compare the sum with a tolerance: exact float equality would reject
    # valid inputs such as [0.7, 0.2, 0.1], whose floating-point sum is
    # 0.9999999999999999 rather than exactly 1.
    if not np.isclose(sum(list_fractions), 1.0):
        raise ValueError("The split fractions must sum to 1.")

    # Compute the number of samples for each split. Each block is floored;
    # the last split absorbs any remainder so all samples are allocated.
    list_n_samples: list[int] = []
    n_total_samples = len(dataset.get(samples_coordinate))
    for fraction in list_fractions[:-1]:
        list_n_samples.append(int(fraction * n_total_samples))
    # append the remaining samples to the last split
    list_n_samples.append(n_total_samples - sum(list_n_samples))

    # Shuffle all indices
    rng = np.random.default_rng(seed)
    shuffled_idcs = rng.permutation(n_total_samples)

    # Extract datasets for each split as contiguous blocks of the
    # shuffled indices.
    list_ds = []
    start_idx = 0
    for n_samples in list_n_samples:
        end_idx = start_idx + n_samples
        list_ds.append(
            dataset.isel(
                {samples_coordinate: shuffled_idcs[start_idx:end_idx]}
            )
        )
        start_idx = end_idx

    # Throw warning if a subset is empty. Check along samples_coordinate
    # rather than a hard-coded "image_id", so the check also works for
    # datasets split along a different coordinate.
    if any(len(ds[samples_coordinate]) == 0 for ds in list_ds):
        logger.warning("At least one of the subset datasets is empty.")

    # Return subsets in the same order as the input list of fractions
    # (argsort twice gives the inverse permutation)
    idcs_sorted = np.argsort(list_fractions)  # idcs to map input -> sorted
    idcs_original = np.argsort(idcs_sorted)  # idcs to map sorted -> input
    return tuple(list_ds[i] for i in idcs_original)
class _SubsetDict(TypedDict):
"""Subset dictionary.
Each subset dictionary is made up of a list of ``group IDs`` ("ids") and
their total ``group count`` ("sum"). Used as a type definition for the
approximate subset sum algorithm.
Attributes
----------
sum : int
The total ``group count`` of the subset.
ids : list[int]
The list of ``group IDs`` in the subset.
"""
sum: int
ids: list[int]
def _approximate_subset_sum(
    list_id_counts: list[tuple[int, int]], target: int, epsilon: float
) -> _SubsetDict:
    """Find a subset of the input list whose sum is maximum but below target.

    The input is a list of pairs (``group IDs``, ``group count``). We want
    to extract a subset of elements from this list such that their total
    ``group count`` is as close as possible but does not exceed the
    ``target`` value.

    Parameters
    ----------
    list_id_counts : list[tuple[int, int]]
        The list of pairs (``group IDs``, ``group count``).
    target : int
        The target value for the total ``group count`` of the subset.
    epsilon : float
        The approximation tolerance for the subset sum as a fraction of 1.
        When ``epsilon`` is 0, the algorithm finds the optimal subset for the
        requested target value and input list. Larger values of ``epsilon``
        result in faster computation but may yield subsets with a total
        ``group count`` further from the optimal.

    Returns
    -------
    _SubsetDict
        The subset dictionary, holding the selected ``group IDs`` ("ids")
        and their total ``group count`` ("sum"). An empty subset is returned
        (with a logged warning) when the input list is empty or when every
        group has more samples than the target value.

    Notes
    -----
    The function uses a fully polynomial-time approximation scheme (FPTAS) to
    approximately solve this subset sum problem. When ``epsilon`` is 0, it
    finds the exact optimal subset. When ``epsilon`` > 0, it returns an
    approximate solution guaranteed to be within ``epsilon`` times the
    optimal sum from below. Using an ``epsilon`` value larger than 0 may be
    convenient in cases with a large number of subsets for faster runtime.

    The algorithm iteratively processes each element in the input list,
    maintaining a list of candidate subsets whose total ``group counts``
    falls below the target. At each iteration, it removes near duplicate
    subsets from the list of candidates ("trimming") to prevent exponential
    growth, while ensuring the total approximation error stays within
    ``epsilon``. Two subsets are near duplicates if their total
    ``group count`` is sufficiently close.

    Note that ``epsilon`` bounds the error from below relative to the
    optimal subset sum, not the ``target``. If ``OPT`` is the best possible
    subset sum below or equal to ``target``, the result is guaranteed to be
    at least ``(1 - epsilon)*OPT``. E.g. for ``epsilon = 0.2``,
    the result is guaranteed to be at least ``0.8*OPT``.

    References
    ----------
    .. [1] https://en.wikipedia.org/wiki/Subset_sum_problem#Fully-polynomial_time_approximation_scheme
    .. [2] https://nerderati.com/bartering-for-beers-with-approximate-subset-sums/

    """
    # Checks
    # Guard against an empty input list: there would be nothing to select,
    # and both min() and the delta computation below would fail on it.
    if not list_id_counts:
        logger.warning(
            "The input list of group counts is empty. "
            "Returning empty subset."
        )
        return {"sum": 0, "ids": []}
    if min(count for _, count in list_id_counts) > target:
        logger.warning(
            "All groups have more samples than the target value. "
            "Returning empty subset."
        )
        return {"sum": 0, "ids": []}

    # Initialize the list of candidate subsets with the empty subset
    list_subsets: list[_SubsetDict] = [{"sum": 0, "ids": []}]

    # Trimming tolerance per iteration. Each trimming pass introduces a
    # small error; over the n iterations (one per item) these errors
    # compound, and ``delta = epsilon / (2n)`` keeps the total error
    # within ``epsilon``.
    delta = float(epsilon) / (2 * len(list_id_counts))

    # Loop thru the list of (id, count) pairs
    # (avoid naming the loop variable ``id``, which shadows the builtin)
    for group_id, count in list_id_counts:
        # Add the current (id, count) pair to each candidate subset in the
        # list, keeping only the results whose sum stays below the target.
        list_subsets.extend(
            [
                {
                    "sum": subset["sum"] + count,
                    "ids": subset["ids"] + [group_id],
                }
                for subset in list_subsets
                if subset["sum"] + count <= target
            ]
        )
        # Remove near-duplicate subsets in terms of total group count
        # ("sum") to prevent exponential growth of the candidate list.
        list_subsets = _remove_near_duplicate_subsets(
            list_subsets,
            delta=delta,
        )

    # The trimmed list is sorted by sum in ascending order, so the last
    # element is the subset with the highest sum below the target.
    return list_subsets[-1]
def _remove_near_duplicate_subsets(
    list_subsets: list[_SubsetDict], delta: float
) -> list[_SubsetDict]:
    """Trim candidate subsets whose total sums are nearly identical.

    Sorts the candidates by total sum in ascending order, then keeps a
    candidate only when its sum exceeds the last kept sum by more than a
    factor of ``(1 + delta)``. Of two near-duplicate sums, the smaller one
    is kept, since it is visited first after sorting. The first (zero)
    subset is always retained.

    Parameters
    ----------
    list_subsets : list[_SubsetDict]
        The list of candidate subsets. Each subset is a dictionary with a
        list of ``group IDs`` ("ids") and their total ``group count``
        ("sum").
    delta : float
        Relative tolerance: two subsets whose sums are within a factor of
        ``(1 + delta)`` of each other are considered near duplicates, and
        the larger one is dropped.

    Returns
    -------
    list[_SubsetDict]
        The list of subsets after trimming, sorted by ascending sum.

    """
    # Work on a sum-sorted copy so near duplicates are adjacent
    ordered = sorted(list_subsets, key=lambda s: s["sum"])

    # Always retain the zero subset; [{"sum": 0, "ids": []}]
    kept = [ordered[0]]
    last_kept_sum = 0

    # Keep only candidates whose sum clears the (1 + delta) bar relative
    # to the most recently kept sum.
    for candidate in ordered[1:]:
        if candidate["sum"] > last_kept_sum * (1 + delta):
            kept.append(candidate)
            last_kept_sum = candidate["sum"]

    return kept