Source code for rics.performance.plot._params

from collections.abc import Hashable, Iterable, Mapping
from dataclasses import dataclass, field, fields
from typing import Any, ClassVar, Literal, Self

import numpy as np
import pandas as pd

from ...collections.dicts import compute_if_absent
from ..types import ResultsDict
from .types import Candidate, FuncOrData, Kind, TestData, Unit

# Column names used for the candidate functions and the test data in run-result frames;
# other code in this module groups by / reads these columns.
FUNC: Candidate = "Candidate"
DATA: TestData = "Test data"


[docs] @dataclass(frozen=True, kw_only=True) class CatplotParams: data: pd.DataFrame = field(repr=False) x: FuncOrData y: str hue: FuncOrData kind: Kind names: list[str] = field(metadata={"skip": True}) user_kwargs: Mapping[str, Any] = field(metadata={"skip": True}) DEFAULTS: ClassVar[dict[str, Any]] = {"errorbar": "sd", "estimator": "min", "aspect": 2} DEFAULTS_BY_KIND: ClassVar[dict[Kind, dict[str, Any]]] = { "bar": {"capsize": 0.2}, "point": {"capsize": 0.2}, } COMPUTED_DEFAULTS: ClassVar[tuple[str, ...]] = ("order", "hue_order", "log_scale")
[docs] @classmethod def reserved_keys(cls) -> set[str]: return {f.name for f in fields(cls) if not f.metadata.get("skip")}
def __post_init__(self) -> None: if keys := self.reserved_keys().intersection(self.user_kwargs): msg = f"Bad `kwargs`: {keys=} are reserved." raise ValueError(msg)
[docs] @classmethod def make( cls, run_results: ResultsDict | pd.DataFrame, *, x: Literal["candidate", "data"] | None = None, unit: Unit | None = None, kind: Kind = "bar", names: Iterable[str] = (), **kwargs: Any, ) -> Self: """Create instance from run results.""" df = _make_df(run_results) # MyPy thinks these are both string x_col, hue_col = (DATA, FUNC) if _is_data_x(x, df=df) else (FUNC, DATA) return cls( data=df, x=x_col, # type: ignore[arg-type] y=_resolve_y(unit, df=df), hue=hue_col, # type: ignore[arg-type] kind=kind, names=list(names), user_kwargs=dict(kwargs), )
[docs] def to_kwargs(self) -> dict[str, Any]: """Convert to :func:`seaborn.catplot` keyword arguments.""" kwargs = dict(self.user_kwargs) for key in self.reserved_keys(): kwargs[key] = getattr(self, key) for key, default in self.DEFAULTS.items(): kwargs.setdefault(key, default) for key, default in self.DEFAULTS_BY_KIND.get(self.kind, {}).items(): kwargs.setdefault(key, default) for key in self.COMPUTED_DEFAULTS: compute_if_absent(kwargs, key, self._compute_default) if self.names: self._handle_row_col(kwargs) return kwargs
@property def want_log_scale_hack(self) -> bool: return self.kind == "bar" user_log_scale = self.user_kwargs.get("log_scale") if not user_log_scale: return True # No explicit user choice - apply # TODO(me): bool, number, or pair of bools or numbers # and isinstance(self.user_kwargs.get("log_scale"), tuple) def _compute_default(self, key: str) -> Any: df = self.data match key: case "order": return df[self.x].unique().tolist() case "hue_order": return df[self.hue].unique().tolist() case "log_scale": column = next(filter(lambda c: c.startswith("Time ["), df.columns)) means = df.groupby([DATA, FUNC], observed=True)[column].mean() return (False, True) if means.max() / means.min() > 20 else None # noqa: PLR2004 raise KeyError(f"Bad {key=}. Possible choices={self.COMPUTED_DEFAULTS}.") def _handle_row_col(self, kwargs: dict[str, Any]) -> None: names = self.names updates = {} # We could print something like f"{label}: [row]", but this could add a lot of additional text to the legend # or x-axis. I think this is a better way to do it. hidden = {names.index(name) for key in ("row", "col") if (name := kwargs.get(key)) in names} formatters = {i: f"{name} = {{}}" for i, name in enumerate(names) if i not in hidden} def format_label(label: tuple[Hashable]) -> str: if not isinstance(label, tuple) or len(label) != len(names): raise TypeError(f"Cannot format {label=} using {names=}.") g = (template.format(lv) for i, lv in enumerate(label) if (template := formatters.get(i))) return " | ".join(g) order_key = "hue_order" if self.hue == DATA else "order" if old_order := kwargs.get(order_key): updates[order_key] = np.unique([format_label(label) for label in old_order]).tolist() # assert sorted(self.data.columns.intersection(names)) == sorted(names) data = kwargs["data"].copy() data[DATA] = data[DATA].map(format_label) updates["data"] = data # Don't write updates until we're done; allows caller to skip if needed. kwargs.update(updates)
def _make_df(run_results: ResultsDict | pd.DataFrame) -> pd.DataFrame:
    """Return a new frame of run results with the test data and candidate columns as categoricals.

    Dicts are converted with ``rics.performance.to_dataframe``; frames are copied, so the caller's object is
    not modified.
    """
    from rics.performance import to_dataframe  # Local import; presumably avoids a circular import.

    df = to_dataframe(run_results) if isinstance(run_results, dict) else run_results.copy()
    as_category = [DATA, FUNC]
    df[as_category] = df[as_category].astype("category")
    return df


def _is_data_x(x_col: Literal["candidate", "data"] | None, *, df: pd.DataFrame) -> bool:
    """Return ``True`` if the test data should go on the x-axis.

    If `x_col` is ``None`` and the test data has more distinct values than the candidates, the test data goes
    on the x-axis; ties go to the candidates. Otherwise, any value starting with ``"d"`` (case-insensitive)
    selects the test data.
    """
    if x_col is None:
        n_data, n_func = df[[DATA, FUNC]].nunique()
        # Comparing numpy integers yields `numpy.bool_`; convert to honour the `bool` return type.
        return bool(n_data > n_func)
    return x_col.lower().startswith("d")


def _resolve_y(unit: Unit | None, *, df: pd.DataFrame) -> str:
    """Return the name of the time column to plot.

    Raises:
        TypeError: If the column for `unit` is not present in `df`.
    """
    if unit is None:
        return _compute_nice_y(df)

    if unit == "us":
        unit = "μs"  # Accept the ASCII spelling of microseconds.
    y = f"Time [{unit}]"
    if y not in df:
        msg = f"Bad {unit=}; column '{y}' not present in data."
        raise TypeError(msg)
    return y


def _compute_nice_y(df: pd.DataFrame) -> str:
    """Pick the unit with the most "human" scale; whole numbers around one hundred.

    For every ``"Time [...]"`` column, compute the log10 residuals of the per-(test data, candidate) means
    relative to 100. Return the column whose average residual has the smallest magnitude.

    Raises:
        ValueError: If `df` has no ``"Time [...]"`` columns.
    """
    # Non-string labels are skipped instead of crashing on `.startswith`.
    columns = [c for c in df.columns if isinstance(c, str) and c.startswith("Time [")]
    if not columns:
        msg = f"No 'Time [...]' columns in data; columns={df.columns.tolist()}."
        raise ValueError(msg)

    means = df.groupby([DATA, FUNC], observed=True)[columns].mean()
    residuals = np.log10(means) - 2  # log10(100) == 2
    avg_residual_by_time_column = residuals.mean(axis="index")
    y = avg_residual_by_time_column.abs().idxmin()
    assert isinstance(y, str)  # noqa: S101
    return y