# Source code for ethology.validators.utils

"""Utils for validating `ethology` objects."""

from abc import ABC, abstractmethod
from collections.abc import Callable
from functools import wraps

import xarray as xr
from attrs import define, field


@define
class ValidDataset(ABC):
    """An abstract base class for valid ``ethology`` datasets.

    This class validates that the input dataset:

    - is an xarray Dataset
    - contains all required dimensions
    - contains all required data variables
    - has the correct dimensions for each data variable

    Subclasses must define ``required_dims`` and ``required_data_vars``
    attributes.

    Attributes
    ----------
    dataset : xarray.Dataset
        The xarray dataset to validate.
    required_dims : set[str]
        A set of required dimension names. This attribute should be
        defined by any subclass inheriting from this class.
    required_data_vars : dict[str, set]
        A dictionary mapping data variable names to their required
        dimensions. This attribute should be defined by any subclass
        inheriting from this class.

    Raises
    ------
    TypeError
        If the input is not an xarray Dataset.
    ValueError
        If the dataset is missing required data variables or dimensions,
        or if any required dimensions are missing for any data variable.

    Notes
    -----
    The dataset can have other data variables and dimensions, but only
    the required ones are checked. The validators below run in the order
    they are defined and stop at the first failure, so later checks can
    assume the earlier ones passed (e.g. that every required data
    variable exists).

    """

    dataset: xr.Dataset = field()

    # Subclasses should override these abstract properties
    @property
    @abstractmethod
    def required_dims(self) -> set[str]:
        """Subclasses must provide a ``required_dims`` property."""
        pass  # pragma: no cover

    @property
    @abstractmethod
    def required_data_vars(self) -> dict[str, set]:
        """Subclasses must provide a ``required_data_vars`` property."""
        pass  # pragma: no cover

    # Validators
    @dataset.validator
    def _check_dataset_type(self, attribute, value):
        """Ensure the input is an xarray Dataset."""
        if not isinstance(value, xr.Dataset):
            raise TypeError(
                f"Expected an xarray Dataset, but got {type(value)}."
            )

    @dataset.validator
    def _check_required_data_variables(self, attribute, value):
        """Ensure the dataset has all required data variables."""
        missing_vars = self.required_data_vars.keys() - set(value.data_vars)
        if missing_vars:
            raise ValueError(
                f"Missing required data variables: {sorted(missing_vars)}"
            )

    @dataset.validator
    def _check_required_dimensions(self, attribute, value):
        """Ensure the dataset has all required dimensions."""
        missing_dims = self.required_dims - set(value.dims)
        if missing_dims:
            raise ValueError(
                f"Missing required dimensions: {sorted(missing_dims)}"
            )

    @dataset.validator
    def _check_dimensions_per_data_variable(self, attribute, value):
        """Ensure each required data variable spans its required dimensions.

        All offending data variables are collected and reported together
        in a single ``ValueError``.
        """
        error_messages = []
        for data_var, dims_per_data_var in self.required_data_vars.items():
            # Compare against the variable's *dimensions*, not its coords:
            # a dimension need not have a coordinate variable, and a
            # non-dimension coordinate must not count as a dimension.
            missing_dims = dims_per_data_var - set(
                value.data_vars[data_var].dims
            )
            if missing_dims:
                error_messages.append(
                    f"data variable '{data_var}' is missing "
                    f"dimensions {sorted(missing_dims)}"
                )
        if error_messages:
            raise ValueError(
                "Some data variables are missing required dimensions:\n - "
                + "\n - ".join(error_messages)
            )
def _check_output(validator: type): """Return a decorator that validates the output of a function.""" def decorator(function: Callable) -> Callable: @wraps(function) # to preserve function metadata def wrapper(*args, **kwargs): result = function(*args, **kwargs) validator(result) return result return wrapper return decorator def _check_input(validator: type, input_index: int = 0): """Return a decorator that validates a specific input of a function. By default, the first input is validated. If the input index is larger than the number of inputs, no validation is performed. """ def decorator(function: Callable) -> Callable: @wraps(function) def wrapper(*args, **kwargs): if len(args) > input_index: validator(args[input_index]) result = function(*args, **kwargs) return result return wrapper return decorator