Skip to content

Commit

Permalink
Updated render_shapes doc (#199)
Browse files Browse the repository at this point in the history
* Updated render_shapes doc

* Updated render_shapes

* halfway points

* Updated type hints and handling

* Fixed color behaviour

* Changed source of logger; fixed type in test

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fixed typos in documentation; minor correction to type checks

* Fixed test

* Added tests

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
timtreis and pre-commit-ci[bot] authored Jan 17, 2024
1 parent adb6bb8 commit a1788c8
Show file tree
Hide file tree
Showing 8 changed files with 633 additions and 214 deletions.
717 changes: 560 additions & 157 deletions src/spatialdata_plot/pl/basic.py

Large diffs are not rendered by default.

92 changes: 49 additions & 43 deletions src/spatialdata_plot/pl/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def _render_shapes(

for e in elements:
shapes = sdata.shapes[e]
n_shapes = sum([len(s) for s in shapes])
n_shapes = sum(len(s) for s in shapes)

if sdata.table is None:
table = AnnData(None, obs=pd.DataFrame(index=pd.Index(np.arange(n_shapes), dtype=str)))
Expand All @@ -94,11 +94,11 @@ def _render_shapes(
sdata=sdata_filt,
element=sdata_filt.shapes[e],
element_name=e,
value_to_plot=render_params.color,
value_to_plot=render_params.col_for_color,
layer=render_params.layer,
groups=render_params.groups,
palette=render_params.palette,
na_color=render_params.cmap_params.na_color,
na_color=render_params.color or render_params.cmap_params.na_color,
alpha=render_params.fill_alpha,
cmap_params=render_params.cmap_params,
)
Expand Down Expand Up @@ -162,14 +162,18 @@ def _render_shapes(
len(set(color_vector)) == 1 and list(set(color_vector))[0] == to_hex(render_params.cmap_params.na_color)
):
# necessary in case different shapes elements are annotated with one table
if color_source_vector is not None:
if color_source_vector is not None and render_params.col_for_color is not None:
color_source_vector = color_source_vector.remove_unused_categories()

# False if user specified color-like with 'color' parameter
colorbar = False if render_params.col_for_color is None else legend_params.colorbar

_ = _decorate_axs(
ax=ax,
cax=cax,
fig_params=fig_params,
adata=table,
value_to_plot=render_params.color,
value_to_plot=render_params.col_for_color,
color_source_vector=color_source_vector,
palette=palette,
alpha=render_params.fill_alpha,
Expand All @@ -179,7 +183,7 @@ def _render_shapes(
legend_loc=legend_params.legend_loc,
legend_fontoutline=legend_params.legend_fontoutline,
na_in_legend=legend_params.na_in_legend,
colorbar=legend_params.colorbar,
colorbar=colorbar,
scalebar_dx=scalebar_params.scalebar_dx,
scalebar_units=scalebar_params.scalebar_units,
)
Expand All @@ -194,12 +198,6 @@ def _render_points(
scalebar_params: ScalebarParams,
legend_params: LegendParams,
) -> None:
if render_params.groups is not None:
if isinstance(render_params.groups, str):
render_params.groups = [render_params.groups]
if not all(isinstance(g, str) for g in render_params.groups):
raise TypeError("All groups must be strings.")

elements = render_params.elements

sdata_filt = sdata.filter_by_coordinate_system(
Expand All @@ -214,43 +212,56 @@ def _render_points(

for e in elements:
points = sdata.points[e]
col_for_color = render_params.col_for_color

coords = ["x", "y"]
if render_params.color is not None:
color = [render_params.color] if isinstance(render_params.color, str) else render_params.color
coords.extend(color)
if col_for_color is not None:
if col_for_color not in points.columns:
# no error in case there are multiple elements, but onyl some have color key
msg = f"Color key '{col_for_color}' for element '{e}' not been found, using default colors."
logger.warning(msg)
else:
coords += [col_for_color]

points = points[coords].compute()
if render_params.groups is not None:
points = points[points[color].isin(render_params.groups).values]
points[color[0]] = points[color[0]].cat.set_categories(render_params.groups)
points = dask.dataframe.from_pandas(points, npartitions=1)
sdata_filt.points[e] = PointsModel.parse(points, coordinates={"x": "x", "y": "y"})

point_df = points[coords].compute()
if render_params.groups is not None and col_for_color is not None:
points = points[points[col_for_color].isin(render_params.groups)]

# we construct an anndata to hack the plotting functions
adata = AnnData(
X=point_df[["x", "y"]].values, obs=point_df[coords].reset_index(), dtype=point_df[["x", "y"]].values.dtype
X=points[["x", "y"]].values, obs=points[coords].reset_index(), dtype=points[["x", "y"]].values.dtype
)
if render_params.color is not None:
cols = sc.get.obs_df(adata, render_params.color)

# Convert back to dask dataframe to modify sdata
points = dask.dataframe.from_pandas(points, npartitions=1)
sdata_filt.points[e] = PointsModel.parse(points, coordinates={"x": "x", "y": "y"})

if render_params.col_for_color is not None:
cols = sc.get.obs_df(adata, render_params.col_for_color)
# maybe set color based on type
if is_categorical_dtype(cols):
_maybe_set_colors(
source=adata,
target=adata,
key=render_params.color,
key=render_params.col_for_color,
palette=render_params.palette,
)

# when user specified a single color, we overwrite na with it
default_color = (
render_params.color
if render_params.col_for_color is None and render_params.color is not None
else render_params.cmap_params.na_color
)

color_source_vector, color_vector, _ = _set_color_source_vec(
sdata=sdata_filt,
element=points,
element_name=e,
value_to_plot=render_params.color,
value_to_plot=render_params.col_for_color,
groups=render_params.groups,
palette=render_params.palette,
na_color=render_params.cmap_params.na_color,
na_color=default_color,
alpha=render_params.alpha,
cmap_params=render_params.cmap_params,
)
Expand Down Expand Up @@ -278,9 +289,7 @@ def _render_points(
)
cax = ax.add_collection(_cax)

if not (
len(set(color_vector)) == 1 and list(set(color_vector))[0] == to_hex(render_params.cmap_params.na_color)
):
if len(set(color_vector)) != 1 or list(set(color_vector))[0] != to_hex(render_params.cmap_params.na_color):
if color_source_vector is None:
palette = ListedColormap(dict.fromkeys(color_vector))
else:
Expand All @@ -291,7 +300,7 @@ def _render_points(
cax=cax,
fig_params=fig_params,
adata=adata,
value_to_plot=render_params.color,
value_to_plot=render_params.col_for_color,
color_source_vector=color_source_vector,
palette=palette,
alpha=render_params.alpha,
Expand Down Expand Up @@ -629,8 +638,8 @@ def _render_labels(
_cax = ax.imshow(
labels_infill,
rasterized=True,
cmap=render_params.cmap_params.cmap if not categorical else None,
norm=render_params.cmap_params.norm if not categorical else None,
cmap=None if categorical else render_params.cmap_params.cmap,
norm=None if categorical else render_params.cmap_params.norm,
alpha=render_params.fill_alpha,
origin="lower",
)
Expand All @@ -652,14 +661,11 @@ def _render_labels(
_cax = ax.imshow(
labels_contour,
rasterized=True,
cmap=render_params.cmap_params.cmap if not categorical else None,
norm=render_params.cmap_params.norm if not categorical else None,
cmap=None if categorical else render_params.cmap_params.cmap,
norm=None if categorical else render_params.cmap_params.norm,
alpha=render_params.outline_alpha,
origin="lower",
)
_cax.set_transform(trans_data)
cax = ax.add_image(_cax)

else:
# Default: no alpha, contour = infill
label = _map_color_seg(
Expand All @@ -676,13 +682,13 @@ def _render_labels(
_cax = ax.imshow(
label,
rasterized=True,
cmap=render_params.cmap_params.cmap if not categorical else None,
norm=render_params.cmap_params.norm if not categorical else None,
cmap=None if categorical else render_params.cmap_params.cmap,
norm=None if categorical else render_params.cmap_params.norm,
alpha=render_params.fill_alpha,
origin="lower",
)
_cax.set_transform(trans_data)
cax = ax.add_image(_cax)
_cax.set_transform(trans_data)
cax = ax.add_image(_cax)

_ = _decorate_axs(
ax=ax,
Expand Down
2 changes: 2 additions & 0 deletions src/spatialdata_plot/pl/render_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ class ShapesRenderParams:
outline_params: OutlineParams
elements: str | Sequence[str] | None = None
color: str | None = None
col_for_color: str | None = None
groups: str | Sequence[str] | None = None
contour_px: int | None = None
layer: str | None = None
Expand All @@ -89,6 +90,7 @@ class PointsRenderParams:
cmap_params: CmapParams
elements: str | Sequence[str] | None = None
color: str | None = None
col_for_color: str | None = None
groups: str | Sequence[str] | None = None
palette: ListedColormap | str | None = None
alpha: float = 1.0
Expand Down
30 changes: 16 additions & 14 deletions src/spatialdata_plot/pl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,10 @@
from spatial_image import SpatialImage
from spatialdata._core.operations.rasterize import rasterize
from spatialdata._core.query.relational_query import _locate_value, get_values
from spatialdata._logging import logger as logging
from spatialdata._types import ArrayLike
from spatialdata.models import Image2DModel, Labels2DModel, SpatialElement

from spatialdata_plot._logging import logger
from spatialdata_plot.pl.render_params import (
CmapParams,
FigParams,
Expand Down Expand Up @@ -379,7 +379,7 @@ def _set_outline(
if outline_width == 0.0:
outline = False
if outline_width < 0.0:
logging.warning(f"Negative line widths are not allowed, changing {outline_width} to {(-1)*outline_width}")
logger.warning(f"Negative line widths are not allowed, changing {outline_width} to {(-1)*outline_width}")
outline_width *= -1

# the default black and white colors can be changed using the contour_config parameter
Expand Down Expand Up @@ -561,7 +561,7 @@ def _get_colors_for_categorical_obs(
palette = default_102
else:
palette = ["grey" for _ in range(len_cat)]
logging.info("input has more than 103 categories. Uniform " "'grey' color will be used for all categories.")
logger.info("input has more than 103 categories. Uniform " "'grey' color will be used for all categories.")
else:
# raise error when user didn't provide the right number of colors in palette
if isinstance(palette, list) and len(palette) != len(categories):
Expand Down Expand Up @@ -623,7 +623,7 @@ def _set_color_source_vec(
# numerical case, return early
if not is_categorical_dtype(color_source_vector):
if palette is not None:
logging.warning(
logger.warning(
"Ignoring categorical palette which is given for a continuous variable. "
"Consider using `cmap` to pass a ColorMap."
)
Expand Down Expand Up @@ -651,7 +651,7 @@ def _set_color_source_vec(

return color_source_vector, color_vector, True

logging.warning(f"Color key '{value_to_plot}' for element '{element_name}' not been found, using default colors.")
logger.warning(f"Color key '{value_to_plot}' for element '{element_name}' not been found, using default colors.")
color = np.full(sdata.table.n_obs, to_hex(na_color))
return color, color, False

Expand Down Expand Up @@ -723,7 +723,7 @@ def _get_palette(
)
return {cat: to_hex(to_rgba(col)[:3]) for cat, col in zip(categories, palette)}
except KeyError as e:
logging.warning(e)
logger.warning(e)
return None

len_cat = len(categories)
Expand All @@ -737,7 +737,7 @@ def _get_palette(
palette = default_102
else:
palette = ["grey" for _ in range(len_cat)]
logging.info("input has more than 103 categories. Uniform " "'grey' color will be used for all categories.")
logger.info("input has more than 103 categories. Uniform " "'grey' color will be used for all categories.")
return {cat: to_hex(to_rgba(col)[:3]) for cat, col in zip(categories, palette[:len_cat])}

if isinstance(palette, str):
Expand Down Expand Up @@ -904,9 +904,9 @@ def save_fig(fig: Figure, path: str | Path, make_dir: bool = True, ext: str = "p
try:
path.parent.mkdir(parents=True, exist_ok=True)
except OSError as e:
logging.debug(f"Unable to create directory `{path.parent}`. Reason: `{e}`")
logger.debug(f"Unable to create directory `{path.parent}`. Reason: `{e}`")

logging.debug(f"Saving figure to `{path!r}`")
logger.debug(f"Saving figure to `{path!r}`")

kwargs.setdefault("bbox_inches", "tight")
kwargs.setdefault("transparent", True)
Expand Down Expand Up @@ -1070,13 +1070,13 @@ def _mpl_ax_contains_elements(ax: Axes) -> bool:

def _get_valid_cs(
sdata: sd.SpatialData,
coordinate_systems: Sequence[str],
coordinate_systems: list[str],
render_images: bool,
render_labels: bool,
render_points: bool,
render_shapes: bool,
elements: list[str],
) -> Sequence[str]:
) -> list[str]:
"""Get names of the valid coordinate systems.
Valid cs are cs that contain elements to be rendered:
Expand All @@ -1090,8 +1090,10 @@ def _get_valid_cs(
cs_mapping = _get_coordinate_system_mapping(sdata)
valid_cs = []
for cs in coordinate_systems:
if (len(elements) > 0 and any(e in elements for e in cs_mapping[cs])) or (
len(elements) == 0
if (
elements
and any(e in elements for e in cs_mapping[cs])
or not elements
and (
(len(sdata.images.keys()) > 0 and render_images)
or (len(sdata.labels.keys()) > 0 and render_labels)
Expand All @@ -1101,7 +1103,7 @@ def _get_valid_cs(
): # not nice, but ruff wants it (SIM114)
valid_cs.append(cs)
else:
logging.info(f"Dropping coordinate system '{cs}' since it doesn't have relevant elements.")
logger.info(f"Dropping coordinate system '{cs}' since it doesn't have relevant elements.")
return valid_cs


Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
3 changes: 3 additions & 0 deletions tests/pl/test_render_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,6 @@ def test_plot_can_stack_render_points(self, sdata_blobs: SpatialData):
.pl.render_points(elements="blobs_points", na_color="blue", size=10)
.pl.show()
)

def test_plot_color_recognises_actual_color_as_color(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_points(elements="blobs_points", color="red").pl.show()
3 changes: 3 additions & 0 deletions tests/pl/test_render_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,3 +261,6 @@ def test_plot_can_stack_render_shapes(self, sdata_blobs: SpatialData):
.pl.render_shapes(elements="blobs_polygons", na_color="blue", fill_alpha=0.5)
.pl.show()
)

def test_plot_color_recognises_actual_color_as_color(self, sdata_blobs: SpatialData):
(sdata_blobs.pl.render_shapes(elements="blobs_circles", color="red").pl.show())

0 comments on commit a1788c8

Please sign in to comment.