from collections.abc import Hashable, Iterable, Mapping
from dataclasses import dataclass, field, fields
from typing import Any, ClassVar, Literal, Self
import numpy as np
import pandas as pd
from ...collections.dicts import compute_if_absent
from ..types import ResultsDict
from .types import Candidate, FuncOrData, Kind, TestData, Unit
# Column names of the run-results frame; `_make_df` converts both columns to categoricals.
DATA: TestData = "Test data"
FUNC: Candidate = "Candidate"
[docs]
@dataclass(frozen=True, kw_only=True)
class CatplotParams:
data: pd.DataFrame = field(repr=False)
x: FuncOrData
y: str
hue: FuncOrData
kind: Kind
names: list[str] = field(metadata={"skip": True})
user_kwargs: Mapping[str, Any] = field(metadata={"skip": True})
DEFAULTS: ClassVar[dict[str, Any]] = {"errorbar": "sd", "estimator": "min", "aspect": 2}
DEFAULTS_BY_KIND: ClassVar[dict[Kind, dict[str, Any]]] = {
"bar": {"capsize": 0.2},
"point": {"capsize": 0.2},
}
COMPUTED_DEFAULTS: ClassVar[tuple[str, ...]] = ("order", "hue_order", "log_scale")
[docs]
@classmethod
def reserved_keys(cls) -> set[str]:
return {f.name for f in fields(cls) if not f.metadata.get("skip")}
def __post_init__(self) -> None:
if keys := self.reserved_keys().intersection(self.user_kwargs):
msg = f"Bad `kwargs`: {keys=} are reserved."
raise ValueError(msg)
[docs]
@classmethod
def make(
cls,
run_results: ResultsDict | pd.DataFrame,
*,
x: Literal["candidate", "data"] | None = None,
unit: Unit | None = None,
kind: Kind = "bar",
names: Iterable[str] = (),
**kwargs: Any,
) -> Self:
"""Create instance from run results."""
df = _make_df(run_results)
# MyPy thinks these are both string
x_col, hue_col = (DATA, FUNC) if _is_data_x(x, df=df) else (FUNC, DATA)
return cls(
data=df,
x=x_col, # type: ignore[arg-type]
y=_resolve_y(unit, df=df),
hue=hue_col, # type: ignore[arg-type]
kind=kind,
names=list(names),
user_kwargs=dict(kwargs),
)
[docs]
def to_kwargs(self) -> dict[str, Any]:
"""Convert to :func:`seaborn.catplot` keyword arguments."""
kwargs = dict(self.user_kwargs)
for key in self.reserved_keys():
kwargs[key] = getattr(self, key)
for key, default in self.DEFAULTS.items():
kwargs.setdefault(key, default)
for key, default in self.DEFAULTS_BY_KIND.get(self.kind, {}).items():
kwargs.setdefault(key, default)
for key in self.COMPUTED_DEFAULTS:
compute_if_absent(kwargs, key, self._compute_default)
if self.names:
self._handle_row_col(kwargs)
return kwargs
@property
def want_log_scale_hack(self) -> bool:
return self.kind == "bar"
user_log_scale = self.user_kwargs.get("log_scale")
if not user_log_scale:
return True # No explicit user choice - apply
# TODO(me): bool, number, or pair of bools or numbers
# and isinstance(self.user_kwargs.get("log_scale"), tuple)
def _compute_default(self, key: str) -> Any:
df = self.data
match key:
case "order":
return df[self.x].unique().tolist()
case "hue_order":
return df[self.hue].unique().tolist()
case "log_scale":
column = next(filter(lambda c: c.startswith("Time ["), df.columns))
means = df.groupby([DATA, FUNC], observed=True)[column].mean()
return (False, True) if means.max() / means.min() > 20 else None # noqa: PLR2004
raise KeyError(f"Bad {key=}. Possible choices={self.COMPUTED_DEFAULTS}.")
def _handle_row_col(self, kwargs: dict[str, Any]) -> None:
names = self.names
updates = {}
# We could print something like f"{label}: [row]", but this could add a lot of additional text to the legend
# or x-axis. I think this is a better way to do it.
hidden = {names.index(name) for key in ("row", "col") if (name := kwargs.get(key)) in names}
formatters = {i: f"{name} = {{}}" for i, name in enumerate(names) if i not in hidden}
def format_label(label: tuple[Hashable]) -> str:
if not isinstance(label, tuple) or len(label) != len(names):
raise TypeError(f"Cannot format {label=} using {names=}.")
g = (template.format(lv) for i, lv in enumerate(label) if (template := formatters.get(i)))
return " | ".join(g)
order_key = "hue_order" if self.hue == DATA else "order"
if old_order := kwargs.get(order_key):
updates[order_key] = np.unique([format_label(label) for label in old_order]).tolist()
# assert sorted(self.data.columns.intersection(names)) == sorted(names)
data = kwargs["data"].copy()
data[DATA] = data[DATA].map(format_label)
updates["data"] = data
# Don't write updates until we're done; allows caller to skip if needed.
kwargs.update(updates)
def _make_df(run_results: ResultsDict | pd.DataFrame) -> pd.DataFrame:
    """Return a new frame of run results with the test-data and candidate columns as categoricals.

    A dict is converted with ``rics.performance.to_dataframe``; a DataFrame is copied, so the caller's object
    is never modified.
    """
    from rics.performance import to_dataframe

    if isinstance(run_results, dict):
        frame = to_dataframe(run_results)
    else:
        frame = run_results.copy()

    categorical_columns = [DATA, FUNC]
    frame[categorical_columns] = frame[categorical_columns].astype("category")
    return frame
def _is_data_x(x_col: Literal["candidate", "data"] | None, *, df: pd.DataFrame) -> bool:
    """Decide whether test data (rather than candidates) belongs on the x-axis.

    An explicit `x_col` wins: any value starting with ``"d"`` (case-insensitive) selects test data. Otherwise
    test data goes on the x-axis only when it has strictly more distinct values than the candidates.
    """
    if x_col is not None:
        return x_col.lower().startswith("d")

    distinct_data, distinct_candidates = df[[DATA, FUNC]].nunique()
    return distinct_data > distinct_candidates  # type: ignore[no-any-return]
def _resolve_y(unit: Unit | None, *, df: pd.DataFrame) -> str:
    """Return the name of the time column to plot on the y-axis.

    Args:
        unit: Time unit; ``"us"`` is accepted as an alias of ``"μs"``. If ``None``, the column is chosen by
            :func:`_compute_nice_y`.
        df: Run-results frame.

    Returns:
        A column name of the form ``"Time [<unit>]"``.

    Raises:
        TypeError: If the column for `unit` is not present in `df`.
    """
    if unit is None:
        return _compute_nice_y(df)

    unit = "μs" if unit == "us" else unit
    y = f"Time [{unit}]"
    if y in df:
        return y

    msg = f"Bad {unit=}; column '{y}' not present in data."
    raise TypeError(msg)
def _compute_nice_y(df: pd.DataFrame) -> str:
    """Pick the unit with the most "human" scale; whole numbers around one hundred.

    For every ``Time [...]`` column, the mean time per (test data, candidate) pair is computed. The column
    whose average ``log10(mean) - 2`` is closest to zero wins. That is the unit in which the geometric average
    of those means is nearest to 100.
    """
    time_columns = [column for column in df.columns if column.startswith("Time [")]
    pair_means = df.groupby([DATA, FUNC], observed=True)[time_columns].mean()
    # Signed distance from 10**2 in decades, averaged over all (test data, candidate) pairs.
    decades_from_hundred = (np.log10(pair_means) - 2).mean(axis="index")
    best = decades_from_hundred.abs().idxmin()
    assert isinstance(best, str)  # noqa: S101
    return best