# Source code for rics.ml.time_split.integration.sklearn._impl

from collections.abc import Iterable, Sequence
from typing import Any, cast, get_args

from numpy import array, datetime64, logical_and, ndarray, nonzero
from numpy.typing import NDArray

from ..._backend import DatetimeIndexSplitter
from ..._docstrings import docs
from ...types import (
    DatetimeIterable,
    DatetimeSplits,
    DatetimeTypes,
    Flex,
    Schedule,
    Span,
)
from .._log_progress import LogProgressArg, handle_log_progress_arg

# scikit-learn is an optional dependency: when it is not installed, `object` is used as the base class instead, so
# that the class definition below does not fail on a missing name.
try:
    from sklearn.model_selection import BaseCrossValidator  # type: ignore[import-untyped]

except ModuleNotFoundError:
    BaseCrossValidator = object

# A (train_indices, test_indices) pair, as yielded per fold by `ScikitLearnSplitter.split`. At runtime the elements are
# integer arrays produced by `numpy.nonzero`, cast to this alias.
IndexTuple = tuple[Sequence[int], Sequence[int]]


@docs
class ScikitLearnSplitter(BaseCrossValidator):  # type: ignore[misc]
    """A scikit-learn compatible datetime splitter.

    This class may be used to create temporal folds from heterogeneous/unaggregated data, typically used for training
    models (e.g. on raw transaction data). If your data is a well-formed time series, consider using the
    `TimeSeriesSplit`_ class from scikit-learn instead.

    .. _TimeSeriesSplit: https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.TimeSeriesSplit.html

    If a ``pandas`` type is passed to the :meth:`ScikitLearnSplitter.split` or :meth:`ScikitLearnSplitter.get_n_splits`
    methods, the index will be used.

    Args:
        schedule: {schedule}
        before: {before}
        after: {after}
        step: {step}
        n_splits: {n_splits}
        flex: {flex}
        log_progress: {log_progress}
        verify_xy: If ``True``, split X and y independently and verify that they are equal.

    {USER_GUIDE}
    """

    def __init__(
        self,
        schedule: Schedule,
        *,
        before: Span = "7d",
        after: Span = 1,
        n_splits: int | None = None,
        flex: Flex = "auto",
        step: int = 1,
        log_progress: LogProgressArg = False,
        verify_xy: bool = True,
    ) -> None:
        super().__init__()
        # All fold computation is delegated to the backend splitter.
        self._splitter = DatetimeIndexSplitter(
            schedule,
            before=before,
            after=after,
            step=step,
            n_splits=n_splits,
            flex=flex,
        )
        self.log_progress = log_progress
        self.verify_xy = verify_xy

    def get_n_splits(
        self,
        X: DatetimeIterable | None = None,
        y: DatetimeIterable | None = None,
        groups: Any = None,
    ) -> int:
        """Returns the number of splitting iterations in the cross-validator.

        Equivalent to ``len(list(split(X, y, groups)))``.

        Args:
            X: Training data (features).
            y: Target variable.
            groups: Always ignored, exists for compatibility.

        Returns:
            Number of splits with given arguments.

        Raises:
            ValueError: If both `X` and `y` are ``None``.
            ValueError: If splits of `X` and `y` are not equal when ``verify_xy=True``.
            TypeError: If `X` or `y` have a (non-callable) ``index``-attribute, but index elements are not
                datetime-like.
        """
        # Unwrap pandas inputs exactly like `split` does; otherwise a DataFrame/Series would be split on its values
        # (or rejected as 2-D) instead of its index, and the two methods would disagree.
        _, splits = self._get_splits(self._handle_pandas(X, "X"), self._handle_pandas(y, "y"))
        return len(splits)

    def split(
        self,
        X: DatetimeIterable | None = None,
        y: DatetimeIterable | None = None,
        groups: Any = None,
    ) -> Iterable[IndexTuple]:
        """Generate indices to split data into training and test set.

        Args:
            X: Training data (features).
            y: Target variable.
            groups: Always ignored, exists for compatibility.

        Yields:
            The training/test set indices for that split.

        Raises:
            ValueError: If both `X` and `y` are ``None``.
            ValueError: If splits of `X` and `y` are not equal when ``verify_xy=True``.
            TypeError: If `X` or `y` have a (non-callable) ``index``-attribute, but index elements are not
                datetime-like.
        """
        index, splits = self._get_splits(self._handle_pandas(X, "X"), self._handle_pandas(y, "y"))

        # Fall back to the plain splits when the progress helper returns a falsy value (presumably when progress
        # logging is disabled -- see handle_log_progress_arg).
        for fold in handle_log_progress_arg(self.log_progress, splits=splits) or splits:
            # Training rows are timestamps in [start, mid); test rows are timestamps in [mid, end).
            train = nonzero(logical_and(fold.start <= index, index < fold.mid))[0]
            test = nonzero(logical_and(fold.mid <= index, index < fold.end))[0]
            yield cast(IndexTuple, (train, test))

    @staticmethod
    def _handle_pandas(arg: DatetimeIterable | None, name: str) -> DatetimeIterable | None:
        """Replace a pandas-like `arg` by its ``index``; return any other input unchanged.

        Args:
            arg: Input data, or ``None``.
            name: Argument name used in error messages.

        Returns:
            ``arg.index`` for pandas-like inputs, otherwise `arg` itself.

        Raises:
            TypeError: If the first index element is not one of the ``DatetimeTypes``.
        """
        if arg is None:
            return arg

        index = getattr(arg, "index", None)
        if index is None or callable(index):
            # Not pandas-like. Plain sequences (list, tuple, str) expose an ``index()`` *method*, which must not be
            # mistaken for a pandas index.
            return arg

        if len(index) == 0:
            # Nothing to type-check; defer to the splitter, just like for empty non-pandas input.
            return cast(DatetimeIterable, index)

        if type(index[0]) in get_args(DatetimeTypes):
            return cast(DatetimeIterable, index)

        raise TypeError(f"{name}.index does not appear to be a datetime-like iterable.")

    def _get_splits(
        self,
        X: DatetimeIterable | None = None,
        y: DatetimeIterable | None = None,
    ) -> tuple[NDArray[datetime64], DatetimeSplits]:
        """Compute the folds for `X` and/or `y`, and the timestamps used to assign rows to folds.

        Args:
            X: Timestamps of the training data, or ``None``.
            y: Timestamps of the target variable, or ``None``.

        Returns:
            A tuple ``(timestamps, splits)``. When both inputs are given, the timestamps are taken from `y` and the
            splits from `X`.

        Raises:
            ValueError: If both `X` and `y` are ``None``.
            ValueError: If splits of `X` and `y` are not equal when ``verify_xy=True``.
            NotImplementedError: If the timestamps are not one-dimensional.
        """
        splits: DatetimeSplits | None = None
        y_splits: DatetimeSplits | None = None
        timestamps: Any | None = None

        if X is not None:
            splits = self._splitter.get_splits(X)
            timestamps = X
        if y is not None:
            y_splits = self._splitter.get_splits(y)
            timestamps = y

        if splits is None and y_splits is None:
            raise ValueError("At least one of (X, y) must be given.")

        # Compare against None explicitly: an empty list of splits is still a result that must be verified.
        if self.verify_xy and splits is not None and y_splits is not None and splits != y_splits:
            raise ValueError("Splits of X and y are not equal.")

        timestamps = timestamps if isinstance(timestamps, ndarray) else array(timestamps)
        if timestamps.ndim > 1:
            raise NotImplementedError(f"shape {timestamps.shape} not supported")

        # The cast is only for MyPy, which cannot tell that at least one of the two is non-None at this point.
        return timestamps, cast(DatetimeSplits, splits if splits is not None else y_splits)