diff --git a/pyproject.toml b/pyproject.toml index ad02f82..c3cc840 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ect" -version = "1.2.0" +version = "1.2.1" authors = [ { name="Liz Munch", email="muncheli@msu.edu" }, ] diff --git a/src/ect/results.py b/src/ect/results.py index 9ed0821..f7166dc 100644 --- a/src/ect/results.py +++ b/src/ect/results.py @@ -27,8 +27,15 @@ def __array_finalize__(self, obj): self.directions = getattr(obj, "directions", None) self.thresholds = getattr(obj, "thresholds", None) - def plot(self, ax=None): - """Plot ECT matrix with proper handling for both 2D and 3D""" + def plot(self, ax=None, *, radial=False, **kwargs): + """Plot ECT matrix with proper handling for both 2D and 3D. + + Set radial=True to render a polar visualization (2D only). Any extra + keyword arguments are forwarded to the radial renderer. + """ + if radial: + return self._plot_radial(ax=ax, **kwargs) + ax = ax or plt.gca() if self.thresholds is None: @@ -107,6 +114,46 @@ def smooth(self): # create new ECTResult with float type return ECTResult(sect.astype(np.float64), self.directions, self.thresholds) + # Internal plotting utilities + def _ensure_2d(self): + if self.directions is None or self.directions.dim != 2: + raise ValueError("This visualization is only supported for 2D ECT results") + + def _theta_threshold_mesh(self): + thetas = self.directions.thetas + thresholds = self.thresholds + THETA, R = np.meshgrid(thetas, thresholds) + return THETA, R + + def _configure_polar_axes( + self, ax, rmin=0.0, rmax=None, theta_zero="N", theta_dir=-1 + ): + ax.set_theta_zero_location(theta_zero) + ax.set_theta_direction(theta_dir) + if rmax is None: + rmax = float(np.max(self.thresholds)) + ax.set_ylim(float(rmin), float(rmax)) + return ax + + def _scale_overlay_radii(self, points, rmin=0.0, rmax=None, fit_to_thresholds=True): + x = points[:, 0] + y = points[:, 1] + r = np.sqrt(x**2 + y**2) + theta = np.arctan2(y, x) + + if rmax is None: + rmax = float(np.max(self.thresholds)) + + if not fit_to_thresholds: + return theta, r + + max_r_points = float(np.max(r)) if r.size else 0.0 + if max_r_points > 0.0: + scaled_r = (r / max_r_points) * (rmax - float(rmin)) + float(rmin) + else: + scaled_r = r + return theta, scaled_r + def _plot_ecc(self, theta): """Plot the Euler Characteristic Curve for a specific direction""" plt.step(self.thresholds, self.T, label="ECC") @@ -115,6 +162,120 @@ def _plot_ecc(self, theta): plt.xlabel("$a$") plt.ylabel(r"$\chi(K_a)$") + def _plot_radial( + self, + ax=None, + title=None, + cmap="viridis", + *, + rmin=0.0, + rmax=None, + colorbar=True, + overlay=None, + overlay_kwargs=None, + **kwargs, + ): + """ + Plot ECT matrix in polar coordinates (radial plot). + + Args: + ax: matplotlib axes object. If None, creates a new polar subplot + title: optional string for plot title + cmap: colormap for the plot (default: 'viridis') + rmin: minimum radius for the plot (default: 0.0) + rmax: maximum radius for the plot (default: None) + colorbar: whether to show the colorbar (default: True) + overlay: points to overlay on the plot (default: None) + + **kwargs: additional keyword arguments passed to pcolormesh + + Returns: + matplotlib.axes.Axes: The axes object used for plotting + """ + self._ensure_2d() + + if ax is None: + fig, ax = plt.subplots( + subplot_kw=dict(projection="polar"), figsize=(10, 10) + ) + + THETA, R = self._theta_threshold_mesh() + + im = ax.pcolormesh(THETA, R, self.T, cmap=cmap, **kwargs) + + self._configure_polar_axes(ax, rmin=rmin, rmax=rmax) + + if title: + ax.set_title(title) + + if colorbar: + plt.colorbar(im, ax=ax, label="ECT Value") + + if overlay is not None: + overlay_kwargs = overlay_kwargs or {} + theta, scaled_r = self._scale_overlay_radii( + overlay, rmin=rmin, rmax=rmax, fit_to_thresholds=True + ) + ax.plot( + theta, + scaled_r, + "-", + color=overlay_kwargs.get("color", "black"), + linewidth=overlay_kwargs.get("linewidth", 2), + alpha=overlay_kwargs.get("alpha", 0.5), + ) + + return ax + + def _overlay_points( + self, + points, + ax=None, + color="black", + linewidth=2, + alpha=0.5, + *, + rmin=0.0, + rmax=None, + fit_to_thresholds=True, + **kwargs, + ): + """ + Overlay original points on a radial ECT plot. + + Args: + points: numpy array of shape (N, 2) containing the original points + ax: matplotlib polar axes object. If None, uses current axes + color: color for the overlay line (default: 'white') + linewidth: line width for the overlay (default: 2) + alpha: transparency for the overlay (default: 0.5) + **kwargs: additional keyword arguments passed to plot + + Returns: + matplotlib.axes.Axes: The axes object used for plotting + """ + if ax is None: + ax = plt.gca() + + if not hasattr(ax, "name") or ax.name != "polar": + raise ValueError("overlay_points requires a polar axes object") + + theta, scaled_r = self._scale_overlay_radii( + points, rmin=rmin, rmax=rmax, fit_to_thresholds=fit_to_thresholds + ) + + ax.plot( + theta, + scaled_r, + "-", + color=color, + linewidth=linewidth, + alpha=alpha, + **kwargs, + ) + + return ax + def dist( self, other: Union["ECTResult", List["ECTResult"]],