Source code for rics.ml.time_split.integration.sklearn._impl

from typing import Any, Iterable, Optional, Sequence, Tuple, cast, get_args

from numpy import array, datetime64, logical_and, ndarray, nonzero
from numpy.typing import NDArray

from ..._backend import DatetimeIndexSplitter
from ..._docstrings import docs
from ...types import DatetimeIterable, DatetimeSplits, DatetimeTypes, Flex, Schedule, Span
from .._log_progress import LogProgressArg, handle_log_progress_arg

# scikit-learn is treated as optional here: if it cannot be imported, plain
# ``object`` is used as the base class so that the class definition below
# does not fail with a NameError.
try:
    from sklearn.model_selection import BaseCrossValidator  # type: ignore[import-untyped]
except ModuleNotFoundError:
    BaseCrossValidator = object


# A pair of integer index sequences; ``ScikitLearnSplitter.split`` yields
# these as ``(training indices, test indices)``.
IndexTuple = Tuple[Sequence[int], Sequence[int]]


@docs
class ScikitLearnSplitter(BaseCrossValidator):  # type: ignore[misc]
    """A scikit-learn compatible datetime splitter.

    This class may be used to create temporal folds from heterogeneous/unaggregated data, typically used for training
    models (e.g. on raw transaction data). If your data is a well-formed time series, consider using the
    `TimeSeriesSplit`_ class from scikit-learn instead.

    .. _TimeSeriesSplit: https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.TimeSeriesSplit.html

    .. note::

       Test coverage is limited for this integration. Please report issues to
       https://github.com/rsundqvist/rics/issues/new.

    If a ``pandas`` type is passed to the :meth:`ScikitLearnSplitter.split`- or
    :meth:`ScikitLearnSplitter.get_n_splits`-method, the index will be used.

    Args:
        schedule: {schedule}
        before: {before}
        after: {after}
        n_splits: {n_splits}
        flex: {flex}
        log_progress: {log_progress}
        verify_xy: If ``True``, split X and y independently and verify that they are equal.

    {USER_GUIDE}
    """

    def __init__(
        self,
        schedule: Schedule,
        *,
        before: Span = "7d",
        after: Span = 1,
        n_splits: Optional[int] = None,
        flex: Flex = "auto",
        log_progress: LogProgressArg = False,
        verify_xy: bool = True,
    ) -> None:
        super().__init__()
        self._splitter = DatetimeIndexSplitter(schedule, before, after=after, n_splits=n_splits, flex=flex)
        self.log_progress = log_progress
        self.verify_xy = verify_xy

    def get_n_splits(self, X: DatetimeIterable = None, y: DatetimeIterable = None, groups: Any = None) -> int:
        """Returns the number of splitting iterations in the cross-validator.

        Equivalent to ``len(list(split(X, y, groups)))``.

        Args:
            X: Training data (features).
            y: Target variable.
            groups: Always ignored, exists for compatibility.

        Returns:
            Number of splits with given arguments.

        Raises:
            ValueError: If both `X` and `y` are ``None``.
            ValueError: If splits of `X` and `y` are not equal when ``verify_xy=True``.
            TypeError: If `X` or `y` have an ``index``-attribute, but index elements are not datetime-like.
        """
        # Resolve pandas-like arguments exactly as ``split`` does, so that the
        # documented equivalence with ``len(list(split(...)))`` actually holds.
        _, splits = self._get_splits(self._handle_pandas(X, "X"), self._handle_pandas(y, "y"))
        return len(splits)

    def split(
        self,
        X: DatetimeIterable = None,
        y: DatetimeIterable = None,
        groups: Any = None,
    ) -> Iterable[IndexTuple]:
        """Generate indices to split data into training and test set.

        Args:
            X: Training data (features).
            y: Target variable.
            groups: Always ignored, exists for compatibility.

        Yields:
            The training/test set indices for that split.

        Raises:
            ValueError: If both `X` and `y` are ``None``.
            ValueError: If splits of `X` and `y` are not equal when ``verify_xy=True``.
            TypeError: If `X` or `y` have an ``index``-attribute, but index elements are not datetime-like.
        """
        index, splits = self._get_splits(self._handle_pandas(X, "X"), self._handle_pandas(y, "y"))

        # The progress helper may return something falsy; iterate the plain splits in that case.
        for fold in handle_log_progress_arg(self.log_progress, splits=splits) or splits:
            # Half-open intervals: training is [start, mid), test is [mid, end).
            train = nonzero(logical_and(fold.start <= index, index < fold.mid))[0]
            test = nonzero(logical_and(fold.mid <= index, index < fold.end))[0]
            yield cast(IndexTuple, (train, test))

    @staticmethod
    def _handle_pandas(arg: Optional[DatetimeIterable], name: str) -> Optional[DatetimeIterable]:
        """Replace an argument that carries an ``index`` attribute with that index.

        Args:
            arg: Candidate argument; returned unchanged if ``None`` or if it has no ``index`` data attribute.
            name: Argument name, used in the error message.

        Returns:
            ``arg.index`` if `arg` has a non-callable ``index`` attribute, otherwise `arg` itself.

        Raises:
            TypeError: If the first element of ``arg.index`` is not one of the ``DatetimeTypes``.
        """
        index = getattr(arg, "index", None)
        # Built-in sequences (list, tuple, str) expose ``index`` as a *method*.
        # Those are plain iterables and must be passed through untouched.
        if index is None or callable(index):
            return arg

        # NOTE: only the first element is inspected; an empty index raises IndexError here.
        if type(index[0]) in get_args(DatetimeTypes):
            return cast(DatetimeIterable, index)

        raise TypeError(f"{name}.index does not appear to be a datetime-like iterable.")

    def _get_splits(
        self,
        X: DatetimeIterable = None,
        y: DatetimeIterable = None,
    ) -> Tuple[NDArray[datetime64], DatetimeSplits]:
        """Compute the timestamps array and the splits for the given arguments.

        Args:
            X: Training data (features), or ``None``.
            y: Target variable, or ``None``.

        Returns:
            A tuple ``(timestamps, splits)``, where timestamps is a one-dimensional array.

        Raises:
            ValueError: If both `X` and `y` are ``None``.
            ValueError: If splits of `X` and `y` are not equal when ``verify_xy=True``.
            NotImplementedError: If the timestamps array has more than one dimension.
        """
        if X is None and y is None:
            raise ValueError("At least one of (X, y) must be given.")

        x_splits: Optional[DatetimeSplits] = None if X is None else self._splitter.get_splits(X)
        y_splits: Optional[DatetimeSplits] = None if y is None else self._splitter.get_splits(y)

        # Compare against None rather than relying on truthiness, so that an
        # empty split collection is still verified and is never mistaken for
        # "argument not given".
        if self.verify_xy and x_splits is not None and y_splits is not None and x_splits != y_splits:
            raise ValueError("Splits of X and y are not equal.")

        # NOTE(review): when both arguments are given, timestamps are taken from
        # y while splits are taken from X. With verify_xy=False these may not
        # agree; confirm whether this precedence is intended.
        timestamps: Any = X if y is None else y
        splits = y_splits if x_splits is None else x_splits

        if not isinstance(timestamps, ndarray):
            timestamps = array(timestamps)

        if len(timestamps.shape) > 1:
            raise NotImplementedError(f"shape {timestamps.shape} not supported")

        # The cast should not be needed, but MyPy thinks this is nullable.
        return timestamps, cast(DatetimeSplits, splits)