Source code for rics.performance.plot._params

from dataclasses import dataclass, field, fields
from typing import Any, ClassVar, Hashable, Iterable, Literal, Mapping, Self

import numpy as np
import pandas as pd

from ...collections.dicts import compute_if_absent
from ..types import ResultsDict
from .types import Candidate, FuncOrData, Kind, TestData, Unit

# Name of the results-frame column that identifies the candidate; used for grouping and as an axis/hue variable.
FUNC: Candidate = "Candidate"
# Name of the results-frame column that identifies the test data; used for grouping and as an axis/hue variable.
DATA: TestData = "Test data"


[docs] @dataclass(frozen=True, kw_only=True) class CatplotParams: data: pd.DataFrame = field(repr=False) x: FuncOrData y: str hue: FuncOrData kind: Kind names: list[str] = field(metadata={"skip": True}) user_kwargs: Mapping[str, Any] = field(metadata={"skip": True}) DEFAULTS: ClassVar[dict[str, Any]] = {"errorbar": "sd", "estimator": "min", "aspect": 2} DEFAULTS_BY_KIND: ClassVar[dict[Kind, dict[str, Any]]] = { "bar": {"capsize": 0.2}, "point": {"capsize": 0.2}, } COMPUTED_DEFAULTS: ClassVar[tuple[str, ...]] = ("order", "hue_order", "log_scale")
[docs] @classmethod def reserved_keys(cls) -> set[str]: return {f.name for f in fields(cls) if not f.metadata.get("skip")}
def __post_init__(self) -> None: if keys := self.reserved_keys().intersection(self.user_kwargs): msg = f"Bad `kwargs`: {keys=} are reserved." raise ValueError(msg)
[docs] @classmethod def make( cls, run_results: ResultsDict | pd.DataFrame, *, x: Literal["candidate", "data"] | None = None, unit: Unit | None = None, kind: Kind = "bar", names: Iterable[str] = (), **kwargs: Any, ) -> Self: """Create instance from run results.""" df = _make_df(run_results) # MyPy thinks these are both string x_col, hue_col = (DATA, FUNC) if _is_data_x(x, df=df) else (FUNC, DATA) return cls( data=df, x=x_col, # type: ignore[arg-type] y=_resolve_y(unit, df=df), hue=hue_col, # type: ignore[arg-type] kind=kind, names=list(names), user_kwargs=dict(kwargs), )
[docs] def to_kwargs(self) -> dict[str, Any]: """Convert to :func:`seaborn.catplot` keyword arguments.""" kwargs = dict(self.user_kwargs) for key in self.reserved_keys(): kwargs[key] = getattr(self, key) for key, default in self.DEFAULTS.items(): kwargs.setdefault(key, default) for key, default in self.DEFAULTS_BY_KIND.get(self.kind, {}).items(): kwargs.setdefault(key, default) for key in self.COMPUTED_DEFAULTS: compute_if_absent(kwargs, key, self._compute_default) if self.names: self._handle_row_col(kwargs) return kwargs
@property def want_log_scale_hack(self) -> bool: return self.kind == "bar" user_log_scale = self.user_kwargs.get("log_scale") if not user_log_scale: return True # No explicit user choice - apply # TODO(me): bool, number, or pair of bools or numbers # and isinstance(self.user_kwargs.get("log_scale"), tuple) def _compute_default(self, key: str) -> Any: df = self.data match key: case "order": return df[self.x].unique().tolist() case "hue_order": return df[self.hue].unique().tolist() case "log_scale": column = next(filter(lambda c: c.startswith("Time ["), df.columns)) means = df.groupby([DATA, FUNC], observed=True)[column].mean() return (False, True) if means.max() / means.min() > 20 else None # noqa: PLR2004 raise KeyError(f"Bad {key=}. Possible choices={self.COMPUTED_DEFAULTS}.") def _handle_row_col(self, kwargs: dict[str, Any]) -> None: names = self.names updates = {} # We could print something like f"{label}: [row]", but this could add a lot of additional text to the legend # or x-axis. I think this is a better way to do it. hidden = {names.index(name) for key in ("row", "col") if (name := kwargs.get(key)) in names} formatters = {i: f"{name} = {{}}" for i, name in enumerate(names) if i not in hidden} def format_label(label: tuple[Hashable]) -> str: if not isinstance(label, tuple) or len(label) != len(names): raise TypeError(f"Cannot format {label=} using {names=}.") g = (template.format(lv) for i, lv in enumerate(label) if (template := formatters.get(i))) return " | ".join(g) order_key = "hue_order" if self.hue == DATA else "order" if old_order := kwargs.get(order_key): updates[order_key] = np.unique([format_label(label) for label in old_order]).tolist() # assert sorted(self.data.columns.intersection(names)) == sorted(names) data = kwargs["data"].copy() data[DATA] = data[DATA].map(format_label) updates["data"] = data # Don't write updates until we're done; allows caller to skip if needed. kwargs.update(updates)
def _make_df(run_results: ResultsDict | pd.DataFrame) -> pd.DataFrame:
    """Return run results as a frame in which the `DATA` and `FUNC` columns are categorical.

    A ``dict`` is converted using ``to_dataframe``. Anything else is copied, so the caller's frame is not modified.
    """
    # NOTE(review): imported locally, presumably to avoid a circular import - confirm.
    from rics.performance import to_dataframe

    df = to_dataframe(run_results) if isinstance(run_results, dict) else run_results.copy()
    as_category = [DATA, FUNC]
    df[as_category] = df[as_category].astype("category")
    return df


def _is_data_x(x_col: Literal["candidate", "data"] | None, *, df: pd.DataFrame) -> bool:
    """Return ``True`` if the test data label should be used as the x-variable.

    Without an explicit `x_col`, the label with the most unique values in `df` is chosen. Ties go to the candidate.
    """
    if x_col is None:
        n_data, n_func = df[[DATA, FUNC]].nunique()
        # The comparison yields a numpy boolean; convert it to honor the annotated return type.
        return bool(n_data > n_func)

    return x_col.lower().startswith("d")


def _resolve_y(unit: Unit | None, *, df: pd.DataFrame) -> str:
    """Return the name of the time column to plot.

    Args:
        unit: Time unit. ``"us"`` is treated as an alias of ``"μs"``. If ``None``, a unit is chosen automatically.
        df: Frame of run results.

    Raises:
        TypeError: If there is no ``"Time [<unit>]"`` column in `df`.
    """
    if unit is None:
        return _compute_nice_y(df)

    if unit == "us":
        unit = "μs"

    y = f"Time [{unit}]"
    if y not in df:
        msg = f"Bad {unit=}; column '{y}' not present in data."
        raise TypeError(msg)

    return y


def _compute_nice_y(df: pd.DataFrame) -> str:
    """Pick the unit with the most "human" scale; whole numbers around one hundred."""
    columns = [c for c in df.columns if c.startswith("Time [")]
    means = df.groupby([DATA, FUNC], observed=True)[columns].mean()

    # Distance from one hundred, in orders of magnitude: log10(100) == 2.
    residuals = np.log10(means) - 2
    avg_residual_by_time_column = residuals.mean(axis="index")
    y = avg_residual_by_time_column.abs().idxmin()
    assert isinstance(y, str)  # noqa: S101
    return y