diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py index 7ff110d0..08b2094c 100644 --- a/src/spatialdata_plot/pl/render.py +++ b/src/spatialdata_plot/pl/render.py @@ -17,7 +17,7 @@ from matplotlib.cm import ScalarMappable from matplotlib.colors import ListedColormap, Normalize from scanpy._settings import settings as sc_settings -from spatialdata import get_extent +from spatialdata import get_extent, join_spatialelement_table from spatialdata.models import PointsModel, ShapesModel, get_table_keys from spatialdata.transformations import get_transformation, set_transformation from spatialdata.transformations.transformations import Identity @@ -76,13 +76,18 @@ def _render_shapes( filter_tables=bool(render_params.table_name), ) - shapes = sdata[element] - if (table_name := render_params.table_name) is None: table = None + shapes = sdata_filt[element] else: - _, region_key, _ = get_table_keys(sdata[table_name]) - table = sdata[table_name][sdata[table_name].obs[region_key].isin([element])] + element_dict, joined_table = join_spatialelement_table( + sdata, spatial_element_names=element, table_name=table_name, how="inner" + ) + sdata_filt[element] = shapes = element_dict[element] + joined_table.uns["spatialdata_attrs"]["region"] = ( + joined_table.obs[joined_table.uns["spatialdata_attrs"]["region_key"]].unique().tolist() + ) + sdata_filt[table_name] = table = joined_table if ( col_for_color is not None diff --git a/tests/_images/Shapes_can_color_two_queried_shapes_elements_by_annotation.png b/tests/_images/Shapes_can_color_two_queried_shapes_elements_by_annotation.png index 3733eff1..fd8ddc59 100644 Binary files a/tests/_images/Shapes_can_color_two_queried_shapes_elements_by_annotation.png and b/tests/_images/Shapes_can_color_two_queried_shapes_elements_by_annotation.png differ diff --git a/tests/_images/Shapes_can_do_non_matching_table.png b/tests/_images/Shapes_can_do_non_matching_table.png new file mode 100644 index 00000000..6f6f6214 Binary files /dev/null and b/tests/_images/Shapes_can_do_non_matching_table.png differ diff --git a/tests/_images/Shapes_can_plot_queried_with_annotation_despite_random_shuffling.png b/tests/_images/Shapes_can_plot_queried_with_annotation_despite_random_shuffling.png index df0f43d0..41205c15 100644 Binary files a/tests/_images/Shapes_can_plot_queried_with_annotation_despite_random_shuffling.png and b/tests/_images/Shapes_can_plot_queried_with_annotation_despite_random_shuffling.png differ diff --git a/tests/_images/Shapes_can_plot_with_annotation_despite_random_shuffling.png b/tests/_images/Shapes_can_plot_with_annotation_despite_random_shuffling.png index c36b0b9c..ac33e269 100644 Binary files a/tests/_images/Shapes_can_plot_with_annotation_despite_random_shuffling.png and b/tests/_images/Shapes_can_plot_with_annotation_despite_random_shuffling.png differ diff --git a/tests/pl/test_render_shapes.py b/tests/pl/test_render_shapes.py index d683189a..dcc24fac 100644 --- a/tests/pl/test_render_shapes.py +++ b/tests/pl/test_render_shapes.py @@ -169,7 +169,6 @@ def test_plot_can_plot_shapes_after_spatial_query(self, sdata_blobs: SpatialData cropped_blob.pl.render_shapes().pl.show() def test_plot_can_plot_with_annotation_despite_random_shuffling(self, sdata_blobs: SpatialData): - new_table = sdata_blobs["table"].copy() sdata_blobs["table"].obs["region"] = "blobs_circles" new_table = sdata_blobs["table"][:5] new_table.uns["spatialdata_attrs"]["region"] = "blobs_circles" @@ -447,3 +446,12 @@ def test_plot_datashader_can_transform_circles(self, sdata_blobs: SpatialData): _set_transformations(sdata_blobs["blobs_circles"], {"global": seq}) sdata_blobs.pl.render_shapes("blobs_circles", method="datashader", outline_alpha=1.0).pl.show() + + def test_plot_can_do_non_matching_table(self, sdata_blobs: SpatialData): + table_shapes = sdata_blobs["table"][:3].copy() + table_shapes.obs.instance_id = list(range(3)) + table_shapes.obs["region"] = "blobs_circles" + table_shapes.uns["spatialdata_attrs"]["region"] = "blobs_circles" + sdata_blobs["new_table"] = table_shapes + + sdata_blobs.pl.render_shapes("blobs_circles", color="instance_id").pl.show()