Source code for rics.performance._multi_case_timer

import functools
import logging
import warnings
from collections.abc import Callable, Collection, Hashable, Iterator, Mapping
from dataclasses import dataclass
from inspect import Parameter, signature
from time import perf_counter
from timeit import Timer
from typing import Any, ClassVar, Generic, TypeAlias

from rics.misc import format_kwargs, tname
from rics.strings import format_perf_counter
from rics.strings import format_seconds as fmt_time

from .types import CandFunc, DataFunc, DataType, ResultsDict, Ts

# Per-call timings at or below this value never trigger the "results may be unreliable" warning emitted while
# running; only the worst timing of a combination is compared against it.
UNRELIABLE_RESULTS_LIMIT = 1e-6  # Prevent spurious "4x" warnings.

# Accepted forms of the `candidate_method` argument: ``{label: func}``, a collection of functions, or one function.
CandidateMethodArg: TypeAlias = (
    Mapping[str, CandFunc[DataType]]
    | Collection[CandFunc[DataType]]
    | CandFunc[DataType]
)

# Accepted forms of a non-callable `test_data` argument: ``{label: data}`` or a plain collection of data.
TestDataArg: TypeAlias = (
    Mapping[Any, DataType]
    | Collection[DataType]
)


[docs] class MultiCaseTimer(Generic[DataType, *Ts]): """Performance testing implementation for multiple candidates and data sets. Test data: * Typically a dict ``{label: data}`` to evaluate candidates. * Other collections are converted to ``dict`` using :meth:`process_test_data`. String label will then be based on sample data. * Labels may also be ``tuple``. This may then be used to plot different categories of data in different facets; see the :func:`.plot_run` function with the `names` argument. * For non-dict inputs, string labels will be generated automatically. * If `test_data` is :py:func:`callable`, test data will be generated from the `case_args`. * The `case_args` will be passed as positional arguments. * The `case_args` will be used as the output labels when using :meth:`run` (similar to the ``setup`` option provided by the built-in :py:mod:`timeit` module). Data access time is *not* measured by the ``run`` method. Args: candidate_method: A dict ``{label: function}``. Alternatively, you may pass a collection of functions or a single function. test_data: A ``{label: data}`` to evaluate candidates on. You may also pass a list of data, which will be converted to a dict as above. Data may also be generated by passing a callable. case_args: These are positional arguments for the `test_data` callable. kwargs: Shared keyword arguments for the `test_data` callable. Raises: TypeError: If `args` or `kwargs` are set when `test_data` is not a callable. ValueError: If `args` is empty and `test_data` is a callable. 
""" LOGGER: ClassVar[logging.Logger | logging.LoggerAdapter[Any]] = logging.getLogger(__package__) """Class logger instance.""" def __init__( self, candidate_method: CandidateMethodArg[DataType], test_data: TestDataArg[DataType] | DataFunc[*Ts, DataType], # DataFunc[DataFuncP, DataType] *, case_args: Collection[tuple[*Ts]] | None = None, kwargs: Any | None = None, ) -> None: self._candidates = self.process_candidates(candidate_method) self._data: dict[Hashable, DataType] | GeneratedData[DataType, *Ts] if callable(test_data): if not case_args: raise ValueError("No case data given.") cases = [*case_args] if len(cases) != len({*cases}): raise ValueError("Cases are not unique.") self._data = GeneratedData(test_data, cases=cases, kwargs=kwargs, logger=self.LOGGER) else: if case_args or kwargs: msg = "Cannot pass `case_args` or `kwargs` when `test_data` is not a callable." raise TypeError(msg) self._data = self.process_test_data(test_data)
[docs] @classmethod def process_candidates(cls, candidates: CandidateMethodArg[DataType]) -> dict[str, CandFunc[DataType]]: """Convert input candidates to the internal format.""" rv = cls._process_candidates(candidates) if rv: return rv raise ValueError("No candidates given.") # pragma: no cover
[docs] @classmethod def process_test_data(cls, test_data: TestDataArg[DataType]) -> dict[Hashable, DataType]: """Convert input test data to the internal format.""" rv = {**test_data} if isinstance(test_data, Mapping) else cls._dict_from_collection(test_data) if rv: return rv raise ValueError("No case data given.") # pragma: no cover
[docs] def derive_names(self) -> list[str]: """Derive names argument. Raises: TypeError: If `test_data` is not callable. """ if not isinstance(self._data, GeneratedData): raise TypeError("Cannot derive names without callable `test_data`.") return self._data.derive_names()
@property def is_data_generated(self) -> bool: """Returns ``True`` if the `test_data` is callable.""" return isinstance(self._data, GeneratedData)
[docs] def run( self, *, time_per_candidate: float = 6.0, repeat: int = 5, number: int | None = None, skip_if: Callable[["SkipIfParams[DataType, *Ts]"], bool] | None = None, progress: bool = False, ) -> ResultsDict: """Run for all cases. Note that the test case variant data isn't included in the expected runtime computation, so increasing the amount of test data variants (at initialization) will reduce the amount of times each candidate is evaluated. Args: time_per_candidate: Minimum runtime per repetition and candidate label. Ignored if `number` is set. repeat: Number of times to repeat for all candidates per data label. number: Number of times to execute each candidate, per repetition. skip_if: A callable ``(skip_if) -> bool``; see the :class:`params <SkipIfParams>` type. progress: If ``True``, display a progress bar. Requires ``tqdm``. Examples: If `repeat=5` and `time_per_candidate=3` for an instance with and 2 candidates, the total runtime will be approximately ``5 * 3 * 2 = 30`` seconds. Returns: A dict `run_results` on the form ``{candidate_label: {data_label: [runtime, ...]}}``. Raises: ValueError: If the total expected runtime exceeds `max_expected_runtime`. Notes: * Precomputed runtime is inaccurate for functions where a single call are longer than `time_per_candidate`. See Also: The :py:class:`timeit.Timer` class which this implementation depends on. 
""" logger = self.LOGGER n_cand = len(self._candidates) n_data = len(self._data) total = n_cand * n_data logger.debug("Begin evaluating %i combinations: %i candidates and %i test cases.", total, n_cand, n_data) per_candidate_number = self._compute_number_of_iterations(number, repeat, time_per_candidate, progress) if progress: from tqdm.auto import tqdm pbar = tqdm(total=total) else: pbar = None i = 0 run_results: ResultsDict = {} for candidate_label, func in self._candidates.items(): candidate_number, candidate_est_time = per_candidate_number[candidate_label] run_results[candidate_label] = candidate_results = {} logger.info(f"Evaluate candidate {candidate_label!r} {repeat}x{candidate_number} times per datum..") for data_label, test_data in self._data.items(): i += 1 if pbar: pbar.desc = f"{candidate_label}({data_label})" pbar.refresh() if skip_if: skip_if_params: SkipIfParams[DataType, *Ts] = SkipIfParams( candidate=func, candidate_label=candidate_label, data=test_data, data_label=data_label, est_time=None if candidate_est_time is None else candidate_est_time * repeat, results_so_far=run_results, ) if skip_if(skip_if_params): if pbar: pbar.update() logger.debug(f"Skip combination {i}/{total}: {candidate_label!r} @ {data_label!r}.") continue logger.debug(f"Start evaluating combination {i}/{total}: {candidate_label!r} @ {data_label!r}.") raw_timings = self._get_raw_timings(func, test_data, repeat, candidate_number) timings = [dt / candidate_number for dt in raw_timings] # Same heuristic as the IPython cell magic. best = min(timings) worst = max(timings) if best > 0 and worst >= best * 4 and worst > UNRELIABLE_RESULTS_LIMIT: t = (candidate_label, data_label) warnings.warn( f"Results may be unreliable for {t}. The worst time {fmt_time(worst)} " f"was ~{worst / best:.1f} times slower than the best time ({fmt_time(best)}).", UserWarning, stacklevel=1, ) candidate_results[data_label] = timings if pbar: pbar.update() return run_results
@staticmethod def _get_raw_timings(func: CandFunc[DataType], test_data: DataType, repeat: int, number: int) -> list[float]: """Exists so that it can be overridden for testing.""" return Timer(lambda: func(test_data)).repeat(repeat, number) def _compute_number_of_iterations( self, number: int | None, repeat: int, time_allocation: float, progress: bool, ) -> dict[str, tuple[int, float]] | dict[str, tuple[int, None]]: logger = self.LOGGER if isinstance(number, int): return {c: (number, None) for c in self._candidates} logger.debug("Computing number of iterations with repeat=%i and time_allocation=%f.", repeat, time_allocation) start = perf_counter() if progress: from tqdm.auto import tqdm pbar = tqdm(total=len(self._candidates)) else: pbar = None candidate_numbers = {} for label in self._candidates: if pbar: pbar.desc = f"autorange('{label}')" pbar.refresh() number, time = self._autonumber(time_allocation, label) candidate_numbers[label] = number, time logger.debug("Candidate number for candidate=%r: %i (time=%f).", label, number, time) if pbar: pbar.update() if pbar: pbar.clear() logger.info( f"Computed shared number for {len(candidate_numbers)} candidates in {format_perf_counter(start)}: " f"{candidate_numbers}." 
) return candidate_numbers def _autonumber(self, time_allocation: float, candidate_label: str) -> tuple[int, float]: """Based on Timer.autorange().""" i = 1 while True: for j in 1, 2, 3, 5: number = i * j total_time_taken = 0.0 for data in self._data.values(): func = self._candidates[candidate_label] partial = functools.partial(func, data) total_time_taken += Timer(partial).timeit(number) if total_time_taken >= time_allocation: if total_time_taken > 1: total_time_taken = round(total_time_taken, 2) return number, total_time_taken i *= 10 @staticmethod def _process_candidates(candidates: CandidateMethodArg[DataType]) -> dict[str, CandFunc[DataType]]: if isinstance(candidates, Mapping): return {**candidates} if callable(candidates): return {tname(candidates, prefix_classname=True): candidates} def make_label(a: Any) -> str: name = tname(a, prefix_classname=True) return name.removeprefix("candidate_") labeled_candidates = {make_label(c): c for c in candidates} if len(labeled_candidates) != len(candidates): raise ValueError( f"Derived names for input {candidates=} are not unique. Use a dict to assign candidate names." ) return labeled_candidates @staticmethod def _dict_from_collection(test_data: Collection[DataType]) -> dict[Hashable, DataType]: result: dict[Hashable, DataType] = {} for data in test_data: s = str(data) if isinstance(data, (bool, float, int, str, tuple)): key = s else: key = f"{s[:29]}..." if len(s) > 32 else s # noqa: PLR2004 key = f"Sample data: '{key}'" result[key] = data return result
class GeneratedData(Generic[DataType, *Ts]): def __init__( self, func: DataFunc[*Ts, DataType], cases: list[tuple[*Ts]], *, kwargs: dict[str, Any] | None, logger: logging.Logger | logging.LoggerAdapter[Any], ) -> None: self._func = func self._cases = cases self._kwargs = kwargs or {} if hasattr(logger, "getChild"): logger = logger.getChild("data") self._logger = logger def derive_names(self) -> list[str]: case_args = self._cases n = len(case_args[0]) parameters = [ name for name, parameter in signature(self._func).parameters.items() if parameter.kind in {Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD} ] if len(parameters) < n: msg = ( f"Could not derive names for test_data=" f"{tname(self._func, prefix_classname=True)}({', '.join(parameters)}, ...)." f" Expected at least {n} positional parameters since {case_args[0]=}." f" Got {len(parameters)}: {parameters}." ) raise RuntimeError(msg) return parameters[:n] def __len__(self) -> int: return len(self._cases) def keys(self) -> list[tuple[*Ts]]: return self._cases def values(self) -> Collection[DataType]: return [data for _, data in self.items()] def items(self) -> Iterator[tuple[tuple[*Ts], DataType]]: logger = self._logger kwargs = self._kwargs func = self._func logger_enabled = logger.isEnabledFor(logging.DEBUG) if logger_enabled: name = tname(func, prefix_classname=True) total = len(self._cases) for n, case in enumerate(self._cases, start=1): data = func(*case, **kwargs) if logger_enabled: args = ", ".join(map(repr, case)) if kwargs: args = args + ", " + format_kwargs(kwargs) logger.debug(f"Yield case {n}/{total}: {name}({args})={tname(data, prefix_classname=True)}.") yield case, data
@dataclass(frozen=True)
class SkipIfParams(Generic[DataType, *Ts]):
    """Data type for a `skip_if` predicate.

    :meth:`MultiCaseTimer.run` creates one instance per candidate/data combination and passes it to the `skip_if`
    callable before that combination is timed. If the predicate returns ``True``, the combination is skipped and no
    timings are recorded for it.
    """

    candidate: CandFunc[DataType]
    """Current candidate function."""
    candidate_label: str
    """Candidate label."""
    data: DataType
    """Current data."""
    data_label: Hashable | tuple[*Ts]
    """Data label. This is the case-argument tuple when test data is generated."""
    est_time: float | None
    """Estimated time to finish all repetitions. Only when `number` is derived; ``None`` otherwise."""
    results_so_far: ResultsDict
    """Timings collected so far, on the form ``{candidate_label: {data_label: [runtime, ...]}}``.

    This is the results dict that ``run`` is building (not a copy), so it keeps changing after the predicate
    returns. Treat it as read-only.
    """