from typing import Any, Iterable, Optional, Sequence, Tuple, cast, get_args
from numpy import array, datetime64, logical_and, ndarray, nonzero
from numpy.typing import NDArray
from ..._backend import DatetimeIndexSplitter
from ..._docstrings import docs
from ...types import DatetimeIterable, DatetimeSplits, DatetimeTypes, Flex, Schedule, Span
from .._log_progress import LogProgressArg, handle_log_progress_arg
try:
from sklearn.model_selection import BaseCrossValidator # type: ignore[import-untyped]
except ModuleNotFoundError:
BaseCrossValidator = object
# A (train_indices, test_indices) pair, as yielded per fold by ScikitLearnSplitter.split.
IndexTuple = Tuple[Sequence[int], Sequence[int]]
@docs
class ScikitLearnSplitter(BaseCrossValidator):  # type: ignore[misc]
    """A scikit-learn compatible datetime splitter.

    This class may be used to create temporal folds from heterogeneous/unaggregated data, typically used for training
    models (e.g. on raw transaction data). If your data is a well-formed time series, consider using the
    `TimeSeriesSplit`_ class from scikit-learn instead.

    .. _TimeSeriesSplit: https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.TimeSeriesSplit.html

    .. note::

       Test coverage is limited for this integration.
       Please report issues to https://github.com/rsundqvist/rics/issues/new.

    If a ``pandas`` type is passed to the :meth:`ScikitLearnSplitter.split`- or
    :meth:`ScikitLearnSplitter.get_n_splits`-methods, the index will be used.

    Args:
        schedule: {schedule}
        before: {before}
        after: {after}
        n_splits: {n_splits}
        flex: {flex}
        log_progress: {log_progress}
        verify_xy: If ``True``, split X and y independently and verify that they are equal. If ``False``, only `X`
            is split when it is given.

    {USER_GUIDE}
    """

    def __init__(
        self,
        schedule: Schedule,
        *,
        before: Span = "7d",
        after: Span = 1,
        n_splits: Optional[int] = None,
        flex: Flex = "auto",
        log_progress: LogProgressArg = False,
        verify_xy: bool = True,
    ) -> None:
        super().__init__()
        # All fold computation is delegated to the backend splitter; this class only adapts it to scikit-learn.
        self._splitter = DatetimeIndexSplitter(schedule, before, after=after, n_splits=n_splits, flex=flex)
        self.log_progress = log_progress
        self.verify_xy = verify_xy

    def get_n_splits(
        self,
        X: Optional[DatetimeIterable] = None,
        y: Optional[DatetimeIterable] = None,
        groups: Any = None,
    ) -> int:
        """Returns the number of splitting iterations in the cross-validator.

        Equivalent to ``len(list(split(X, y, groups)))``.

        Args:
            X: Training data (features).
            y: Target variable.
            groups: Always ignored, exists for compatibility.

        Returns:
            Number of splits with given arguments.

        Raises:
            ValueError: If both `X` and `y` are ``None``.
            ValueError: If splits of `X` and `y` are not equal when ``verify_xy=True``.
            TypeError: If `X` or `y` have an ``index``-attribute, but index elements are not datetime-like.
        """
        # Resolve pandas inputs exactly like `split` does, so that the two methods always agree.
        _, splits = self._get_splits(self._handle_pandas(X, "X"), self._handle_pandas(y, "y"))
        return len(splits)

    def split(
        self,
        X: Optional[DatetimeIterable] = None,
        y: Optional[DatetimeIterable] = None,
        groups: Any = None,
    ) -> Iterable[IndexTuple]:
        """Generate indices to split data into training and test set.

        Args:
            X: Training data (features).
            y: Target variable.
            groups: Always ignored, exists for compatibility.

        Yields:
            The training/test set indices for that split.

        Raises:
            ValueError: If both `X` and `y` are ``None``.
            ValueError: If splits of `X` and `y` are not equal when ``verify_xy=True``.
            TypeError: If `X` or `y` have an ``index``-attribute, but index elements are not datetime-like.
        """
        index, splits = self._get_splits(self._handle_pandas(X, "X"), self._handle_pandas(y, "y"))
        for fold in handle_log_progress_arg(self.log_progress, splits=splits) or splits:
            # Half-open intervals: training data is [start, mid), test data is [mid, end).
            train_mask = logical_and(fold.start <= index, index < fold.mid)
            test_mask = logical_and(fold.mid <= index, index < fold.end)
            yield cast(IndexTuple, (nonzero(train_mask)[0], nonzero(test_mask)[0]))

    @staticmethod
    def _handle_pandas(arg: Optional[DatetimeIterable], name: str) -> Optional[DatetimeIterable]:
        """Replace `arg` by its ``index``-attribute if it has one, otherwise return `arg` unchanged.

        Raises:
            TypeError: If the index is non-empty and its first element is not of a datetime-like type.
        """
        index = getattr(arg, "index", None)
        # Plain sequences such as ``list`` and ``tuple`` have an ``index()``-*method*; only a (non-callable) index
        # object, such as the one on a pandas Series or DataFrame, should replace the argument.
        if index is None or callable(index):
            return arg

        # An empty index has no element to inspect; pass it on and let the splitter handle it.
        if len(index) == 0:
            return cast(DatetimeIterable, index)

        # Only real classes can be used with isinstance; this also accepts subclasses of the listed types.
        datetime_classes = tuple(t for t in get_args(DatetimeTypes) if isinstance(t, type))
        if isinstance(index[0], datetime_classes):
            return cast(DatetimeIterable, index)

        raise TypeError(f"{name}.index does not appear to be a datetime-like iterable.")

    def _get_splits(
        self,
        X: Optional[DatetimeIterable] = None,
        y: Optional[DatetimeIterable] = None,
    ) -> Tuple[NDArray[datetime64], DatetimeSplits]:
        """Compute folds and the timestamp array that fold bounds are compared against.

        `X` is the primary source of both splits and timestamps when given, since scikit-learn applies the returned
        indices to `X`. `y` is only split on its own when `X` is missing, or when ``verify_xy=True``.

        Raises:
            ValueError: If both `X` and `y` are ``None``.
            ValueError: If splits of `X` and `y` are not equal when ``verify_xy=True``.
            NotImplementedError: If the timestamps are not one-dimensional.
        """
        if X is None and y is None:
            raise ValueError("At least one of (X, y) must be given.")

        timestamps: Any = X if X is not None else y
        splits = self._splitter.get_splits(timestamps)

        # Compare against None explicitly; an empty list of splits must still be verified.
        if self.verify_xy and X is not None and y is not None:
            if self._splitter.get_splits(y) != splits:
                raise ValueError("Splits of X and y are not equal.")

        timestamps = timestamps if isinstance(timestamps, ndarray) else array(timestamps)
        if timestamps.ndim > 1:
            raise NotImplementedError(f"shape {timestamps.shape} not supported")
        return timestamps, splits