"""Plotting utility methods."""
import typing as _t
from matplotlib.axis import Axis as _Axis
from matplotlib.axis import XAxis as _XAxis
from matplotlib.ticker import FuncFormatter as _FuncFormatter
from matplotlib.ticker import IndexLocator as _IndexLocator
# Cap size for error bars. Presumably intended for the seaborn `catplot` override that is
# commented out further down in this module -- TODO confirm it is still needed.
ERROR_BAR_CAPSIZE: float = 0.1
# Accepted spellings of the `half_rep` argument of `pi_ticks`: the full words and their
# abbreviations for the fractional and the decimal representation.
HalfRep = _t.Literal["fraction", "decimal", "frac", "dec", "f", "d"]
[docs]
class HasXAxis(_t.Protocol):
    """Structural type for any object that has an X-axis.

    Anything exposing an ``xaxis`` attribute satisfies this protocol; `pi_ticks`
    accepts such an object in place of a bare axis and decorates its ``xaxis``.
    """

    xaxis: _XAxis
    """X-Axis attribute."""
# Disabled: globally overriding seaborn's `catplot` defaults doesn't play nice with all plot kinds.
# sns.catplot = functools.partial(sns.catplot, capsize=ERROR_BAR_CAPSIZE, height=5)
[docs]
def pi_ticks(ax: _Axis | HasXAxis, half_rep: HalfRep | None = None) -> None:
    """Label an axis in multiples of π instead of plain numbers.

    The major tick locator and the major tick formatter of the axis are replaced
    in place; nothing is returned.

    .. image:: ../_images/pi_ticks.png

    .. list-table:: Options for the `half_rep` argument.
       :header-rows: 1

       * - Value
         - Interpretation
         - Example output
       * - ``None``
         - Show integer multiples only.
         - `0, π, 2π, 3π..`
       * - `'f'` or `'fraction'`
         - Halves of `π` use fractional representation.
         - `0/2π, 1/2π, 2/2π, 3/2π..`
       * - `'d'` or `'decimal'`
         - Halves of `π` use decimal representation.
         - `0.0π, 0.5π, 1.0π, 1.5π..`

    Args:
        ax: The axis to decorate, or any object carrying one in an `xaxis` attribute.
        half_rep: How half multiples of `π` are written on the axis; ``None`` (the
            default) shows whole multiples only.
    """
    # Use the object's `xaxis` when it has one, otherwise treat the argument as the axis itself.
    target: _Axis = getattr(ax, "xaxis", ax)

    # The lower end of the data interval is the reference point the ticks are aligned to.
    data_interval = target.get_data_interval()
    tick_helper = _PiTickHelper(half_rep, start=data_interval[0])

    target.set_major_locator(tick_helper.locator)
    target.set_major_formatter(tick_helper.formatter)
class _PiTickHelper:
    """Build the tick locator and the label formatter that `pi_ticks` installs.

    Attributes:
        half: Normalised representation of half multiples: ``"frac"``, ``"dec"``,
            or ``None`` when only whole multiples of pi are shown.
        offset: Distance from `start` to the multiple of the tick spacing nearest
            to it; handed to the locator as its offset.
        locator: `IndexLocator` spaced by pi, or by pi/2 when halves are shown.
        formatter: `FuncFormatter` wrapping `_format`.
    """

    PI: float = 3.14159265359
    HALF_PI: float = PI / 2
    # Greek small letter pi, written as an escape so the label survives any re-encoding of this file.
    PI_SYMBOL: str = "\u03c0"

    def __init__(
        self,
        half_rep: HalfRep | None,
        *,
        start: float,
    ) -> None:
        """Create the locator and formatter for an axis whose data starts at `start`.

        Args:
            half_rep: Requested representation of half multiples, or ``None``.
            start: Lower end of the data interval of the axis being decorated.

        Raises:
            TypeError: If `half_rep` is not ``None`` and not a recognised spelling.
        """
        self.half = self._parse_half_rep(half_rep)
        # A single spacing drives both the offset and the locator so the two cannot disagree.
        base = self.HALF_PI if self.half else self.PI
        # Offset from nearest (half) multiple of pi.
        self.offset = base * round(start / base) - start
        self.locator = _IndexLocator(base=base, offset=self.offset)
        self.formatter = _FuncFormatter(self._format)

    @staticmethod
    def _parse_half_rep(
        half_rep: str | None,
    ) -> _t.Literal["frac", "dec"] | None:
        """Normalise `half_rep` to ``"frac"``, ``"dec"`` or ``None``.

        Matching is case-insensitive and only looks at the first letter.

        Raises:
            TypeError: If `half_rep` is neither ``None`` nor a string starting with
                ``f`` or ``d``.
        """
        if half_rep is None:
            return None
        # Non-string values fall through to the TypeError below instead of failing on `.lower()`.
        if isinstance(half_rep, str):
            lowered = half_rep.lower()
            if lowered.startswith("f"):
                return "frac"
            if lowered.startswith("d"):
                return "dec"
        raise TypeError(f"Argument {half_rep=} not in ('[f]raction', '[d]ecimal', None).")

    def _format(self, x: float, _pos: int) -> str:
        """Return the label for the tick at position `x`.

        Args:
            x: Tick position in data coordinates.
            _pos: Tick index supplied by the formatter machinery; unused.

        Returns:
            `x` expressed as a multiple of pi in the representation selected by `self.half`.
        """
        # Adding 0.0 turns a negative zero into a positive one, so a label never reads "-0.0".
        n = round(x / self.PI, 1) + 0.0
        if self.half is None:
            # Compare the rounded multiple rather than `x` itself: tick positions are the
            # result of float arithmetic and are rarely *exactly* 0 or pi.
            if n == 0:
                return "0"
            if n == 1:
                return self.PI_SYMBOL
            if n == -1:
                return f"-{self.PI_SYMBOL}"
            return f"{int(n)}{self.PI_SYMBOL}"
        if self.half == "dec":
            return f"{n}{self.PI_SYMBOL}"
        return f"{int(n * 2)}/2{self.PI_SYMBOL}"