Source code for rics.ml.time_split._frontend._progress

import logging
from collections.abc import Callable, Iterable, MutableMapping, Sequence
from dataclasses import dataclass
from time import perf_counter
from typing import Any, Union

from rics.misc import get_by_full_name

from ..settings import log_split_progress as settings
from ..types import DatetimeSplitBounds
from ._to_string import _PrettyTimestamp

# Accepted forms of the `logger` argument of `log_split_progress`: a logger, a logger adapter, or a logger name
# (names are resolved with `logging.getLogger`).
LoggerArg = Union[logging.Logger, logging.LoggerAdapter, str]  # type: ignore[type-arg]


def log_split_progress(
    splits: Sequence[DatetimeSplitBounds],
    *,
    logger: LoggerArg = "rics.ml.time_split",
    start_level: int = logging.INFO,
    end_level: int = logging.INFO,
    extra: dict[str, Any] | None = None,
) -> Iterable[DatetimeSplitBounds]:
    """Log iteration progress over `splits` using `logger`.

    Args:
        splits: Splits to iterate over.
        logger: Logger or logger name to use.
        start_level: Log level to use for the :attr:`fold-begin message <.settings.log_split_progress.START_MESSAGE>`.
        end_level: Log level to use for the :attr:`fold-end message <.settings.log_split_progress.END_MESSAGE>`.
        extra: User-defined `extra`-arguments to use when logging, merged with progress-related extras. The keys are
            also available as placeholders in all messages, including the ``fold`` format. This argument is mutable:
            the mapping is used as-is (not copied) and re-read at the start of every fold, so changes made to `extra`
            while iterating are reflected in the records of subsequent folds. Keys must not collide with the
            progress-related keys ``n``, ``n_splits``, ``start``, ``mid`` and ``end`` (a ``TypeError`` is raised
            during iteration if they do).

    Returns:
        An iterable over `splits`. Progress is logged lazily, as the iterable is consumed.

    Examples:
        Basic usage.

        >>> from rics.ml.time_split import split, log_split_progress
        >>> splits = split("36h", available=("2023-08-10", "2023-08-19"))
        >>> tracked_splits = log_split_progress(
        ...     splits, logger="progress", start_level=logging.DEBUG
        ... )
        >>> list(tracked_splits)  # doctest: +SKIP
        [progress:DEBUG] Begin fold 1/2: ('2023-08-11' <= [schedule: '2023-08-16' (Wednesday)] < '2023-08-17 12:00:00').
        [progress:INFO] Finished fold 1/2 [schedule: '2023-08-16' (Wednesday)] after 5m 18s.
        [progress:DEBUG] Begin fold 2/2: ('2023-08-12 12:00:00' <= [schedule: '2023-08-17 12:00:00' (Thursday)] < '2023-08-19').
        [progress:INFO] Finished fold 2/2 [schedule: '2023-08-17 12:00:00' (Thursday)] after 4m 3s.
    """
    if isinstance(logger, str):
        logger = logging.getLogger(logger)

    if isinstance(logger, logging.LoggerAdapter) and not getattr(logger, "merge_extra", False):
        # A non-merging adapter's `process` replaces the call-site `extra` with its own, which would drop the
        # progress extras. Python < 3.13 has no `merge_extra` attribute, and on 3.13+ it defaults to False. In both
        # cases, re-wrap with a backport of https://github.com/python/cpython/pull/107292 that merges the mappings.
        # Note that this builds a new adapter around `logger.logger`, bypassing any custom `process` of subclasses.
        logger = _MergingLoggerAdapter(logger.logger, logger.extra)

    # The setting is either a callable or a name that is looked up using `get_by_full_name`.
    formatter_setting = settings.SECONDS_FORMATTER
    seconds_formatter = formatter_setting if callable(formatter_setting) else get_by_full_name(formatter_setting)

    track = _ProgressTracker(
        logger=logger,
        fold_format=settings.FOLD_FORMAT,
        start_level=start_level,
        start_message=settings.START_MESSAGE,
        end_level=end_level,
        end_message=settings.END_MESSAGE,
        seconds_formatter=seconds_formatter,
        # Keep the caller's dict even when it is empty, so that later mutations stay visible. (`extra or {}` would
        # silently swap an empty dict for a fresh one.)
        user_extra={} if extra is None else extra,
    )
    return track(splits)
@dataclass(frozen=True)
class _ProgressTracker:
    """Callable that wraps a sequence of splits in a progress-logging generator.

    Created by :func:`log_split_progress`. For every split, a begin-record is logged before the split is yielded. An
    end-record, including the elapsed time, is logged when the consumer requests the next split.
    """

    logger: logging.Logger | logging.LoggerAdapter  # type: ignore[type-arg]
    fold_format: str  # Format string used to build the ``fold`` placeholder of the messages.
    start_level: int
    start_message: str
    end_level: int
    end_message: str
    seconds_formatter: Callable[[float], str]  # Produces the ``formatted_seconds`` placeholder.
    user_extra: dict[str, Any]  # Caller-owned mapping; not copied, re-read at the start of every fold.

    def __call__(self, splits: Sequence[DatetimeSplitBounds]) -> Iterable[DatetimeSplitBounds]:
        """Yield each of `splits` unchanged, logging a begin-record and an end-record around each one.

        Args:
            splits: Splits to iterate over. Must support ``len()``.

        Yields:
            The elements of `splits`, in order.

        Raises:
            TypeError: During iteration, if `user_extra` contains any of the keys ``n``, ``n_splits``, ``start``,
                ``mid`` or ``end``.
        """
        n_splits = len(splits)
        for n, split in enumerate(splits, start=1):
            # Attributes attached to the log records. Timestamps are passed as `isoformat()` strings.
            extra: dict[str, str | float | int | bool] = dict(
                n=n,
                n_splits=n_splits,
                start=split.start.isoformat(),
                mid=split.mid.isoformat(),
                end=split.end.isoformat(),
                **self.user_extra,
            )
            # Placeholders for the message templates. Timestamps are wrapped for pretty-printing.
            kwargs: dict[str, Any] = dict(
                n=n,
                n_splits=n_splits,
                start=_PrettyTimestamp(split.start),
                mid=_PrettyTimestamp(split.mid),
                end=_PrettyTimestamp(split.end),
                **self.user_extra,
            )
            kwargs.update(fold=self.fold_format.format(**kwargs))
            self.logger.log(self.start_level, self.start_message.format(**kwargs), extra=extra)

            # Measures the time the consumer spends on this split, until it asks for the next one. If the consumer
            # stops iterating early, the generator is closed at the `yield` and no end-record is logged.
            started = perf_counter()
            yield split
            seconds = round(perf_counter() - started, 6)

            kwargs.update(seconds=seconds, formatted_seconds=self.seconds_formatter(seconds))
            extra.update(seconds=seconds)
            self.logger.log(self.end_level, self.end_message.format(**kwargs), extra=extra)


class _MergingLoggerAdapter(logging.LoggerAdapter):  # type: ignore[type-arg]
    """Logger adapter that merges its own `extra` with the call-site `extra`.

    Backport of the ``merge_extra=True`` behaviour from https://github.com/python/cpython/pull/107292. Call-site keys
    take precedence over the adapter's keys.
    """

    def process(self, msg: Any, kwargs: MutableMapping[str, Any]) -> tuple[Any, MutableMapping[str, Any]]:
        """See https://github.com/python/cpython/pull/107292."""
        if "extra" in kwargs:
            # Either mapping may be `None` or empty. Call-site extras must be kept even when the adapter has
            # no extras of its own.
            kwargs["extra"] = {**(self.extra or {}), **(kwargs["extra"] or {})}
        else:
            kwargs["extra"] = self.extra
        return msg, kwargs