diff --git a/README.md b/README.md index ff1b1e4..a5b5cdf 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@ -# Swarm Visualizer -This is a plotting visualizer packaged developed by UT Austin Swarm Lab. Please use these plotting package for all papers, plots in lab slides, etc. If you find errors or need a new utility, feel free to push new functions. +# Swarm Visualization +This is a plotting and visualization package developed by the UT Austin Swarm Lab. Please use this plotting package for all papers, plots in lab slides, etc. If you find errors or need a new utility, feel free to create a pull request for new features. ## Usage For example usage, see the code in the `tests` folder. If you want to see example plots, run `pytest` in your terminal. All example plots will be available in `tests/example_plots` @@ -14,6 +14,7 @@ Option 2: pip install -e . ``` +## Development If you are a developer: ``` python -m venv .venv @@ -21,13 +22,12 @@ source .venv/bin/activate python -m pip install --upgrade pip build python -m pip install --editable ."[dev, test]" ``` -Please make sure to write unit test for every method that you are developing. +Please make sure to write unit tests for every method that you are developing. 
-Here's example import +Here's an example code to import functions from the package: ```python from swarm_visualizer import plot_grouped_violinplot from swarm_visualizer import plot_grouped_boxplot from swarm_visualizer import plot_grid ``` - diff --git a/swarm_visualizer/__init__.py b/swarm_visualizer/__init__.py index 220e099..237023d 100644 --- a/swarm_visualizer/__init__.py +++ b/swarm_visualizer/__init__.py @@ -1,13 +1,25 @@ -from .barplot import plot_grouped_barplot, plot_stacked_barplot, plot_sns_grouped_barplot +from .barplot import ( + plot_grouped_barplot, + plot_stacked_barplot, + plot_sns_grouped_barplot, +) from .boxplot import plot_grouped_boxplot, plot_paired_boxplot from .lineplot import plot_basic_lineplot, plot_overlaid_lineplot from .violinplot import plot_grouped_violinplot, plot_paired_violinplot from .scatterplot import plot_basic_scatterplot, plot_scatter_pdf_plot from .gridplot import plot_grid -__all__ = [ "plot_grouped_barplot", "plot_stacked_barplot", "plot_sns_grouped_barplot", - "plot_grouped_boxplot", "plot_paired_boxplot", - "plot_basic_lineplot", "plot_overlaid_lineplot", - "plot_grouped_violinplot", "plot_paired_violinplot", - "plot_basic_scatterplot", "plot_scatter_pdf_plot", - "plot_grid" ] \ No newline at end of file +__all__ = [ + "plot_grouped_barplot", + "plot_stacked_barplot", + "plot_sns_grouped_barplot", + "plot_grouped_boxplot", + "plot_paired_boxplot", + "plot_basic_lineplot", + "plot_overlaid_lineplot", + "plot_grouped_violinplot", + "plot_paired_violinplot", + "plot_basic_scatterplot", + "plot_scatter_pdf_plot", + "plot_grid", +] diff --git a/swarm_visualizer/barplot.py b/swarm_visualizer/barplot.py index d6303ad..c417305 100644 --- a/swarm_visualizer/barplot.py +++ b/swarm_visualizer/barplot.py @@ -2,6 +2,7 @@ from swarm_visualizer.utility import set_axis_infos + def plot_grouped_barplot( ax, df=None, @@ -11,7 +12,7 @@ def plot_grouped_barplot( title_str=None, pal=None, y_label=None, - **kwargs + **kwargs, ) -> None: 
"""Plots a grouped barplot. In this case, there are multiple y-var for each x-var. @@ -30,7 +31,13 @@ def plot_grouped_barplot( if pal: colors = [pal(i) for i in range(len(x_var))] df.plot( - kind="bar", stacked=False, ax=ax, x=x_var, y=y_var, colors=colors, **kwargs + kind="bar", + stacked=False, + ax=ax, + x=x_var, + y=y_var, + colors=colors, + **kwargs, ) ### set y label @@ -40,6 +47,7 @@ def plot_grouped_barplot( # Set axis infos set_axis_infos(ax, ylim=ylim, title_str=title_str) + def plot_sns_grouped_barplot( ax, df=None, @@ -50,7 +58,7 @@ def plot_sns_grouped_barplot( title_str=None, pal=None, y_label=None, - **kwargs + **kwargs, ) -> None: """Plots a grouped barplot with sns. hue specifies the group. @@ -81,7 +89,6 @@ def plot_sns_grouped_barplot( set_axis_infos(ax, ylim=ylim, title_str=title_str) - def plot_stacked_barplot( ax, df=None, @@ -91,7 +98,7 @@ def plot_stacked_barplot( title_str=None, pal=None, y_label=None, - **kwargs + **kwargs, ) -> None: """Plots a grouped barplot. @@ -110,7 +117,13 @@ def plot_stacked_barplot( if pal: colors = [pal(i) for i in range(len(x_var))] df.plot( - kind="bar", stacked=True, ax=ax, x=x_var, y=y_var, colors=colors, **kwargs + kind="bar", + stacked=True, + ax=ax, + x=x_var, + y=y_var, + colors=colors, + **kwargs, ) ### set y label diff --git a/swarm_visualizer/boxplot.py b/swarm_visualizer/boxplot.py index 1786b68..c167e1e 100644 --- a/swarm_visualizer/boxplot.py +++ b/swarm_visualizer/boxplot.py @@ -17,7 +17,7 @@ def plot_paired_boxplot( order_list=None, pal=None, hue=None, - **kwargs + **kwargs, ) -> None: """Plots a paired boxplot. 
@@ -37,11 +37,23 @@ def plot_paired_boxplot( # Plots a boxplot with order if order_list: sns.boxplot( - x=x_var, y=y_var, data=df, order=order_list, hue=hue, ax=ax, **kwargs + x=x_var, + y=y_var, + data=df, + order=order_list, + hue=hue, + ax=ax, + **kwargs, ) else: sns.boxplot( - x=x_var, y=y_var, data=df, order=order_list, hue=hue, ax=ax, **kwargs + x=x_var, + y=y_var, + data=df, + order=order_list, + hue=hue, + ax=ax, + **kwargs, ) # Plots a boxplot with palette @@ -55,7 +67,7 @@ def plot_paired_boxplot( palette=pal, hue=hue, ax=ax, - **kwargs + **kwargs, ) else: sns.boxplot( @@ -66,7 +78,7 @@ def plot_paired_boxplot( palette=pal, hue=hue, ax=ax, - **kwargs + **kwargs, ) # Set axis infos @@ -87,7 +99,7 @@ def plot_grouped_boxplot( title_str=None, order_list=None, pal=None, - **kwargs + **kwargs, ) -> None: """Plots a grouped boxplot. @@ -103,18 +115,34 @@ def plot_grouped_boxplot( """ if not pal: if order_list: - sns.boxplot(x=x_var, y=y_var, data=df, order=order_list, ax=ax, **kwargs) + sns.boxplot( + x=x_var, y=y_var, data=df, order=order_list, ax=ax, **kwargs + ) else: - sns.boxplot(x=x_var, y=y_var, data=df, order=order_list, ax=ax, **kwargs) + sns.boxplot( + x=x_var, y=y_var, data=df, order=order_list, ax=ax, **kwargs + ) if pal: if order_list: sns.boxplot( - x=x_var, y=y_var, data=df, order=order_list, palette=pal, ax=ax, **kwargs + x=x_var, + y=y_var, + data=df, + order=order_list, + palette=pal, + ax=ax, + **kwargs, ) else: sns.boxplot( - x=x_var, y=y_var, data=df, order=order_list, palette=pal, ax=ax, **kwargs + x=x_var, + y=y_var, + data=df, + order=order_list, + palette=pal, + ax=ax, + **kwargs, ) # Set axis infos diff --git a/swarm_visualizer/gridplot.py b/swarm_visualizer/gridplot.py index 5b18e22..302463f 100644 --- a/swarm_visualizer/gridplot.py +++ b/swarm_visualizer/gridplot.py @@ -10,7 +10,7 @@ def plot_grid( plot_file: str = None, lw: float = 3.0, xlabel: str = None, - **kwargs + **kwargs, ) -> None: """Plot grid of time series. 
diff --git a/swarm_visualizer/histogram.py b/swarm_visualizer/histogram.py index 73fad92..3f71e0a 100644 --- a/swarm_visualizer/histogram.py +++ b/swarm_visualizer/histogram.py @@ -5,7 +5,7 @@ def plot_pdf( - ax, data=None, xlabel: str = None, title_str: str = None, **kwargs + ax, data=None, xlabel: str = None, title_str: str = None, **kwargs ) -> None: """Plot PDF of a data. @@ -30,7 +30,7 @@ def plot_pdf( alpha=0.4, edgecolor=(1, 1, 1, 0.4), ax=ax, - **kwargs + **kwargs, ) # Set axis infos @@ -39,16 +39,16 @@ def plot_pdf( def plot_several_pdf( ax, - data_list: list[np.ndarray]=None, + data_list: list[np.ndarray] = None, xlabel: str = None, title_str: str = None, legend=None, ylabel: str = None, xlim=None, kde: bool = False, - bins = "auto", - binwidth = None, - **kwargs + bins="auto", + binwidth=None, + **kwargs, ) -> None: """Plot PDF of a data list. @@ -75,7 +75,7 @@ def plot_several_pdf( ax=ax, bins=bins, binwidth=binwidth, - **kwargs + **kwargs, ) # Set axis infos diff --git a/swarm_visualizer/lineplot.py b/swarm_visualizer/lineplot.py index adedf0c..0bf53ba 100644 --- a/swarm_visualizer/lineplot.py +++ b/swarm_visualizer/lineplot.py @@ -13,7 +13,7 @@ def plot_basic_lineplot( lw: float = 3.0, ylim=None, xlabel: str = "x", - **kwargs + **kwargs, ) -> None: """Basic lineplot. @@ -30,7 +30,12 @@ def plot_basic_lineplot( ax.plot(y, lw=lw) set_axis_infos( - ax, xlabel=xlabel, ylabel=ylabel, ylim=ylim, title_str=title_str, **kwargs + ax, + xlabel=xlabel, + ylabel=ylabel, + ylim=ylim, + title_str=title_str, + **kwargs, ) @@ -46,10 +51,10 @@ def plot_overlaid_lineplot( legend_present: bool = True, DEFAULT_MARKERSIZE: float = 15, delete_yticks: bool = False, - **kwargs + **kwargs, ) -> None: """Overlaid line plot. 
import numpy as np


def group_values_by_bound(
    values: np.ndarray,
    step: float = 0.1,
    min_value: float = 0.0,
    max_value: float = 8.0,
    use_upper_bound: bool = True,
    decimals: int = 1,
) -> np.ndarray:
    """Snap every value to the upper or lower edge of the bin it falls in.

    Bin edges run from ``min_value`` to ``max_value`` in increments of
    ``step`` and are rounded to ``decimals`` places *before* any comparison,
    so floating-point drift in the generated edges (for example
    ``0.30000000000000004``) cannot push a value lying exactly on an edge
    into the neighbouring bin.

    Args:
        values (list or np.ndarray): Numeric values to be grouped.
        step (float): Distance between consecutive bin edges. Must be > 0.
        min_value (float): First bin edge.
        max_value (float): Largest allowed bin edge; no returned value
            exceeds it.
        use_upper_bound (bool): If True, each value maps to the first edge
            strictly greater than it; otherwise to the last edge less than
            or equal to it.
        decimals (int): Number of decimal places the edges are rounded to.
            Pick it to match ``step`` (e.g. ``2`` for ``step=0.25``).

    Returns:
        np.ndarray: Float array with the same shape as ``values``. In
        upper-bound mode, values at or beyond the last edge map to
        ``max_value``; in lower-bound mode, values below the first edge map
        to ``min_value``.

    Raises:
        ValueError: If ``step`` is not positive.
        TypeError: If ``values`` is not numeric (e.g. strings).
    """
    if step <= 0:
        raise ValueError("step must be positive.")

    values = np.asarray(values)
    # Only bool / integer / float data can be compared against the edges.
    if values.dtype.kind not in "biuf":
        raise TypeError(f"values must be numeric, got dtype {values.dtype}.")

    edges = np.round(np.arange(min_value, max_value + step, step), decimals)
    # np.arange may overshoot by one element because of float error.
    edges = edges[edges <= max_value]

    fallback = float(max_value if use_upper_bound else min_value)
    if edges.size == 0:
        return np.full(values.shape, fallback)

    # side="right" yields the index of the first edge strictly greater than
    # the value; the edge just before it is the last one <= the value.
    positions = np.searchsorted(edges, values, side="right")
    if not use_upper_bound:
        positions = positions - 1

    in_range = (positions >= 0) & (positions < edges.size)
    snapped = edges[np.clip(positions, 0, edges.size - 1)]
    return np.where(in_range, snapped, fallback)


def average_by_group(
    values: np.ndarray | list, group_keys: np.ndarray | list
) -> dict:
    """Average ``values`` separately for every distinct group key.

    Args:
        values (list or np.ndarray): Values to average.
        group_keys (list or np.ndarray): Group key of each entry of
            ``values``; must have the same length as ``values``.

    Returns:
        dict: Maps each unique key (in the sorted order produced by
        ``np.unique``) to the mean of the values carrying that key.

    Raises:
        ValueError: If ``values`` and ``group_keys`` differ in length.
    """
    values = np.asarray(values)
    group_keys = np.asarray(group_keys)

    if len(values) != len(group_keys):
        raise ValueError("values and group_keys must have the same length.")

    return {
        group: np.mean(values[group_keys == group])
        for group in np.unique(group_keys)
    }


def sort_array_based_on_reference_array(
    list_of_arrays: list[np.ndarray],
    reference_array: np.ndarray,
    order: str = "ascending",
) -> list[np.ndarray]:
    """Reorder several arrays by the sort order of one reference array.

    Args:
        list_of_arrays (list[np.ndarray]): Arrays to reorder. Do not include
            the reference array here; it is added automatically.
        reference_array (np.ndarray): Array whose values dictate the order.
        order (str): ``"ascending"`` (default) or ``"descending"``.

    Returns:
        list[np.ndarray]: The sorted reference array first, followed by the
        reordered arrays in their input order. Entries with equal reference
        values keep their original relative order in both directions.

    Raises:
        ValueError: If ``order`` is not a supported value, or if any array's
            length differs from the reference array's length.
    """
    if order not in ("ascending", "descending"):
        raise ValueError("order must be either 'ascending' or 'descending'.")

    reference = np.asarray(reference_array)
    # Unpacking (instead of list concatenation) also accepts tuples and
    # other iterables of arrays.
    arrays = [reference, *(np.asarray(array) for array in list_of_arrays)]

    if any(len(array) != len(reference) for array in arrays):
        raise ValueError(
            "All arrays must have the same length as the reference array."
        )

    if order == "descending":
        # Stable-sort the reversed reference and map the indexes back so
        # that ties stay in input order, like sorted(..., reverse=True).
        reversed_positions = np.argsort(reference[::-1], kind="stable")
        positions = len(reference) - 1 - reversed_positions[::-1]
    else:
        positions = np.argsort(reference, kind="stable")

    return [array[positions] for array in arrays]
""" # Joint plot - fig = sns.jointplot( - x=x, - y=y, - **kwargs - ) + fig = sns.jointplot(x=x, y=y, **kwargs) # Set labels fig.set_axis_labels(xlabel, ylabel) diff --git a/swarm_visualizer/utility/__init__.py b/swarm_visualizer/utility/__init__.py index 37f61ce..625a5c3 100644 --- a/swarm_visualizer/utility/__init__.py +++ b/swarm_visualizer/utility/__init__.py @@ -1,7 +1,12 @@ -from .general_utils import set_axis_infos,set_plot_properties,save_fig +from .general_utils import set_axis_infos, set_plot_properties, save_fig from .statistics_utils import add_wilcoxon_value -from .legendplot_utils import create_seperate_legend +from .legendplot_utils import create_seperate_legend, create_colorbar -__all__ = [ "set_axis_infos", "set_plot_properties", "save_fig", - "add_wilcoxon_value", - "create_seperate_legend" ] \ No newline at end of file +__all__ = [ + "set_axis_infos", + "set_plot_properties", + "save_fig", + "add_wilcoxon_value", + "create_seperate_legend", + "create_colorbar", +] diff --git a/swarm_visualizer/utility/calculation_utils.py b/swarm_visualizer/utility/calculation_utils.py index ab03881..13d0cc8 100644 --- a/swarm_visualizer/utility/calculation_utils.py +++ b/swarm_visualizer/utility/calculation_utils.py @@ -4,7 +4,6 @@ """ - import matplotlib import numpy as np diff --git a/swarm_visualizer/utility/general_utils.py b/swarm_visualizer/utility/general_utils.py index 30c5913..f755970 100644 --- a/swarm_visualizer/utility/general_utils.py +++ b/swarm_visualizer/utility/general_utils.py @@ -13,7 +13,7 @@ def set_plot_properties( ytick_label_size: float = 14, markersize: float = 10, usetex: bool = True, - autolayout = True, + autolayout=True, ) -> None: """Sets plot properties. @@ -54,7 +54,9 @@ def set_plot_properties( sns.set_style(style="darkgrid") -def save_fig(fig, save_loc: str = None, dpi: int = 600, tight_layout: bool=True) -> None: +def save_fig( + fig, save_loc: str = None, dpi: int = 600, tight_layout: bool = True +) -> None: """Save figure. 
def create_colorbar(
    ax,
    all_labels,
    title=None,
    visible_labels=None,
    colors=None,
    palette=None,
    discrete=False,
    orientation="vertical",
):
    """Creates a custom colorbar on a dedicated axis.

    Args:
        ax: Axis the colorbar is drawn into (used as ``cax``).
        all_labels: All labels in the dataset. For a discrete colorbar every
            entry gets its own color segment; for a continuous colorbar the
            minimum and maximum define the color range.
        title: Optional label placed along the colorbar.
        visible_labels: Optional subset of labels that get a tick. For a
            discrete colorbar each entry must occur in ``all_labels``; for a
            continuous colorbar the entries are tick positions.
        colors: Explicit colors, one per label (discrete colorbars only).
        palette: Palette / colormap name used when ``colors`` is not given.
        discrete: If True, draw one flat color segment per label.
        orientation: ``"vertical"`` or ``"horizontal"``.

    Returns:
        The created matplotlib colorbar.

    Raises:
        ValueError: If both ``colors`` and ``palette`` are given, if
            ``colors`` is given for a continuous colorbar, if
            ``all_labels`` is empty, or if fewer colors than labels are
            supplied.
    """
    # Explicit None/length checks so numpy arrays and pandas Series work.
    has_colors = colors is not None and len(colors) > 0
    has_visible = visible_labels is not None and len(visible_labels) > 0

    # If colors and palette is given at the same time raise an error
    if has_colors and palette is not None:
        raise ValueError(
            "Both colors and palette cannot be given at the same time."
        )

    # If discrete is False and colors are given raise an error
    if has_colors and not discrete:
        raise ValueError("Colors can only be given if discrete is True.")

    labels = list(all_labels)
    if not labels:
        raise ValueError("all_labels must not be empty.")

    if discrete:
        if has_colors:
            if len(colors) < len(labels):
                raise ValueError(
                    "At least one color per label is required "
                    f"({len(colors)} colors for {len(labels)} labels)."
                )
            segment_colors = list(colors[: len(labels)])
        else:
            # Request exactly one color per label; the palette's default
            # size may be smaller than the number of labels.
            segment_colors = sns.color_palette(palette, n_colors=len(labels))

        cmap = mpl.colors.ListedColormap(segment_colors, name="Custom cmap")
        # Segment i spans [i - 0.5, i + 0.5] so integer ticks sit mid-segment
        bounds = np.arange(len(labels) + 1) - 0.5
        norm = mpl.colors.BoundaryNorm(bounds, cmap.N)
        sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
        sm.set_array([])
        cbar = ax.figure.colorbar(
            sm,
            cax=ax,
            ticks=np.arange(len(labels)),
            orientation=orientation,
        )

        if has_visible:
            shown_labels = list(visible_labels)
            # Location of the visible labels
            tick_locations = [labels.index(label) for label in shown_labels]
            cbar.set_ticks(ticks=tick_locations, labels=shown_labels)
        else:
            cbar.set_ticks(ticks=np.arange(len(labels)), labels=labels)

    else:
        cmap = plt.get_cmap(palette)
        norm = plt.Normalize(min(labels), max(labels))
        sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
        sm.set_array([])
        cbar = ax.figure.colorbar(sm, cax=ax, orientation=orientation)
        if has_visible:
            cbar.set_ticks(ticks=list(visible_labels))

    if title:
        cbar.set_label(title)

    return cbar
pvalue_format_string=pvalue_format_string, show_test_name=show_test_name, **kwargs) + annotator = Annotator( + ax, box_pairs, data=df, x=x_var, y=y_var, hue=hue, order=order_list + ) + annotator.configure( + test=test_type, + text_format=text_format, + loc=loc, + verbose=verbose, + fontsize=fontsize, + pvalue_format_string=pvalue_format_string, + show_test_name=show_test_name, + **kwargs + ) annotator.apply_and_annotate() diff --git a/swarm_visualizer/utility/textfile_utils.py b/swarm_visualizer/utility/textfile_utils.py index 7715cec..a3d924a 100644 --- a/swarm_visualizer/utility/textfile_utils.py +++ b/swarm_visualizer/utility/textfile_utils.py @@ -5,7 +5,6 @@ """ - import matplotlib import numpy as np diff --git a/swarm_visualizer/violinplot.py b/swarm_visualizer/violinplot.py index de6278b..ab80384 100644 --- a/swarm_visualizer/violinplot.py +++ b/swarm_visualizer/violinplot.py @@ -12,7 +12,7 @@ def plot_grouped_violinplot( title_str=None, order_list=None, pal=None, - **kwargs + **kwargs, ) -> None: """Plots a grouped violinplot. 
@@ -29,19 +29,35 @@ def plot_grouped_violinplot( # Plots a violinplot if not pal: if order_list: - sns.violinplot(x=x_var, y=y_var, data=df, order=order_list, ax=ax, **kwargs) + sns.violinplot( + x=x_var, y=y_var, data=df, order=order_list, ax=ax, **kwargs + ) else: - sns.violinplot(x=x_var, y=y_var, data=df, order=order_list, ax=ax, **kwargs) + sns.violinplot( + x=x_var, y=y_var, data=df, order=order_list, ax=ax, **kwargs + ) # Plots a violinplot with palette if pal: if order_list: sns.violinplot( - x=x_var, y=y_var, data=df, order=order_list, palette=pal, ax=ax, **kwargs + x=x_var, + y=y_var, + data=df, + order=order_list, + palette=pal, + ax=ax, + **kwargs, ) else: sns.violinplot( - x=x_var, y=y_var, data=df, order=order_list, palette=pal, ax=ax, **kwargs + x=x_var, + y=y_var, + data=df, + order=order_list, + palette=pal, + ax=ax, + **kwargs, ) # Set axis infos @@ -58,7 +74,7 @@ def plot_paired_violinplot( order_list=None, pal=None, hue=None, - **kwargs + **kwargs, ) -> None: """Plots a paired boxplot. 
import numpy as np
import pytest

from swarm_visualizer.processing import group_values_by_bound


def test_group_values_by_upper_bound():
    """Values inside a bin are replaced by the bin's upper edge."""
    values = np.array([0.05, 0.15, 0.25, 0.8])
    expected = np.array([0.1, 0.2, 0.3, 0.9])

    result = group_values_by_bound(values, step=0.1, min_value=0, max_value=1)

    # Tolerance-based comparison: the edges are computed floats.
    np.testing.assert_allclose(result, expected)


def test_group_values_by_lower_bound():
    """Values inside a bin are replaced by the bin's lower edge."""
    values = np.array([0.05, 0.15, 0.25, 0.8])
    expected = np.array([0.0, 0.1, 0.2, 0.8])

    result = group_values_by_bound(
        values, step=0.1, min_value=0, max_value=1, use_upper_bound=False
    )

    np.testing.assert_allclose(result, expected)


def test_empty_input():
    """An empty input yields an empty output."""
    result = group_values_by_bound(
        np.array([]), step=0.1, min_value=0, max_value=1
    )

    np.testing.assert_array_equal(result, np.array([]))


def test_non_numeric_input():
    """Non-numeric values are rejected with a TypeError."""
    with pytest.raises(TypeError):
        group_values_by_bound(
            ["a", "b", "c"], step=0.1, min_value=0, max_value=1
        )
import numpy as np
import pytest

from swarm_visualizer.processing import sort_array_based_on_reference_array


@pytest.fixture
def sample_data():
    """Three equally long arrays describing the same four products."""
    return {
        "prices": np.array([20.95, 15.99, 22.50, 18.00]),
        "ratings": np.array([4.5, 3.8, 4.9, 4.0]),
        "stocks": np.array([34, 42, 25, 50]),
    }


def _assert_sorted_by_ratings(data, order, ratings, prices, stocks):
    """Sort prices and stocks by ratings and compare with the expectation.

    The function under test returns the sorted reference array first,
    followed by the other arrays in the order they were passed in.
    """
    sorted_arrays = sort_array_based_on_reference_array(
        [data["prices"], data["stocks"]],
        data["ratings"],
        order,
    )

    np.testing.assert_array_equal(sorted_arrays[0], np.array(ratings))
    np.testing.assert_array_equal(sorted_arrays[1], np.array(prices))
    np.testing.assert_array_equal(sorted_arrays[2], np.array(stocks))


def test_sort_ascending(sample_data):
    """Test sorting arrays in ascending order based on ratings."""
    _assert_sorted_by_ratings(
        sample_data,
        "ascending",
        ratings=[3.8, 4.0, 4.5, 4.9],
        prices=[15.99, 18.00, 20.95, 22.50],
        stocks=[42, 50, 34, 25],
    )


def test_sort_descending(sample_data):
    """Test sorting arrays in descending order based on ratings."""
    _assert_sorted_by_ratings(
        sample_data,
        "descending",
        ratings=[4.9, 4.5, 4.0, 3.8],
        prices=[22.50, 20.95, 18.00, 15.99],
        stocks=[25, 34, 50, 42],
    )


def test_length_mismatch(sample_data):
    """Test error raising when input arrays do not have the same length."""
    with pytest.raises(ValueError):
        sort_array_based_on_reference_array(
            [sample_data["prices"], np.array([1, 2])],
            sample_data["ratings"],
            "ascending",
        )
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pytest
import seaborn as sns

from swarm_visualizer import plot_basic_scatterplot

from swarm_visualizer.utility import (
    save_fig,
    set_plot_properties,
    create_colorbar,
)

# Seeded generator so the example plots are identical between test runs
_RNG = np.random.default_rng(0)

_X_DATA = np.arange(0, 10, 0.1) + _RNG.normal(0, 0.1, 100)
_Y_DATA = np.arange(0, 10, 0.1) + _RNG.normal(0, 0.1, 100)

_DISCRETE_NUMERICAL_LABEL_DATA = [i // 10 for i in range(100)]
_DISCRETE_CATEGORY_LABEL_DATA = ["Group " + str(i // 10) for i in range(100)]
_CONTINUOUS_LABEL_DATA = [_RNG.random() + i for i in range(100)]

_DATA_FRAME = pd.DataFrame(
    {
        "x": _X_DATA,
        "y": _Y_DATA,
        "discrete_numerical_label": _DISCRETE_NUMERICAL_LABEL_DATA,
        "discrete_category_label": _DISCRETE_CATEGORY_LABEL_DATA,
        "continuous_label": _CONTINUOUS_LABEL_DATA,
    }
)

_DISCRETE_CATEGORY_VISIBLE_LABEL_DATA = [
    "Group 1",
    "Group 5",
    "Group 6",
    "Group 9",
]
_CONTINUOUS_VISIBLE_LABEL_DATA = [2.5, 45, 88, 50]

_CONTINUOUS_PALETTE = "magma"
_DISCRETE_PALETTE = "tab20"
_DISCRETE_COLORS = [
    "red",
    "blue",
    "green",
    "yellow",
    "black",
    "purple",
    "orange",
    "brown",
    "pink",
    "gray",
]

# Example Plots location
_SAVE_LOC = os.path.abspath(
    os.path.join(os.path.dirname(__file__), "example_plots")
)


def _make_figure(horizontal_colorbar: bool = False):
    """Create a figure with one plot axis and one thin colorbar axis.

    :param horizontal_colorbar: if True the colorbar axis is placed below
        the plot, otherwise to its right
    :return: ``(fig, ax)`` where ``ax[0]`` is the plot axis and ``ax[1]``
        the colorbar axis
    """
    if horizontal_colorbar:
        return plt.subplots(
            figsize=(10, 12),
            nrows=2,
            ncols=1,
            dpi=600,
            gridspec_kw={"height_ratios": [1, 0.03]},
        )
    return plt.subplots(
        figsize=(12, 10),
        nrows=1,
        ncols=2,
        dpi=600,
        gridspec_kw={"width_ratios": [1, 0.03]},
    )


def _scatter_with_colors(ax, x_data, y_data, colors) -> None:
    """Draw the scatter plot shared by the discrete colorbar tests."""
    plot_basic_scatterplot(
        ax,
        x=x_data,
        y=y_data,
        color=colors,
        title_str="Basic Scatter Plot with Discrete Colorbar",
        ylabel="$y$",
        lw=3.0,
        ylim=None,
        xlabel="$x$",
        xlim=None,
        ms=10.0,
    )


def _scatter_with_hue(ax, data_frame, x_var, y_var, hue, palette) -> None:
    """Draw the hue-colored scatter plot shared by the continuous tests."""
    sns.scatterplot(
        data=data_frame,
        x=x_var,
        y=y_var,
        hue=hue,
        palette=palette,
        ax=ax,
        s=100,
        legend=False,
    )


def _save_example(fig, file_name: str) -> None:
    """Save ``fig`` into the ``colorbars`` example plot folder."""
    save_fig(fig, os.path.join(_SAVE_LOC, "colorbars", file_name), dpi=600)


@pytest.mark.parametrize(
    ("x_data", "y_data", "discrete_numerical_label_data", "discrete_palette"),
    [(_X_DATA, _Y_DATA, _DISCRETE_NUMERICAL_LABEL_DATA, _DISCRETE_PALETTE)],
)
def test_basic_discrete_colorbar(
    x_data, y_data, discrete_numerical_label_data, discrete_palette
) -> None:
    """Tests a discrete colorbar that shows every label.

    :param x_data: x-axis data
    :param y_data: y-axis data
    :param discrete_numerical_label_data: discrete numerical label data
    :param discrete_palette: discrete palette
    :return: None
    """
    # Sets plot style
    set_plot_properties()

    fig, ax = _make_figure()

    # Unique values in the discrete numerical label data
    unique_values = np.unique(discrete_numerical_label_data)

    # One palette color per unique value
    palette_colors = sns.color_palette(discrete_palette, len(unique_values))
    color_of_value = dict(zip(unique_values.tolist(), palette_colors))

    # Find colors for each value in the discrete numerical label data
    colors = [color_of_value[value] for value in discrete_numerical_label_data]

    _scatter_with_colors(ax[0], x_data, y_data, colors)

    create_colorbar(
        ax[1], unique_values, palette=discrete_palette, discrete=True
    )

    # Save the plot
    _save_example(fig, "discrete_colorbar_with_all_labels.png")


@pytest.mark.parametrize(
    (
        "x_data",
        "y_data",
        "discrete_category_label_data",
        "discrete_category_visible_label_data",
        "discrete_colors",
    ),
    [
        (
            _X_DATA,
            _Y_DATA,
            _DISCRETE_CATEGORY_LABEL_DATA,
            _DISCRETE_CATEGORY_VISIBLE_LABEL_DATA,
            _DISCRETE_COLORS,
        )
    ],
)
def test_discrete_colorbar_w_limited_visible(
    x_data,
    y_data,
    discrete_category_label_data,
    discrete_category_visible_label_data,
    discrete_colors,
) -> None:
    """Tests a discrete colorbar that shows only some of the labels.

    :param x_data: x-axis data
    :param y_data: y-axis data
    :param discrete_category_label_data: discrete categorical label data
    :param discrete_category_visible_label_data: labels that get a tick
    :param discrete_colors: explicit colors, one per unique label
    :return: None
    """
    # Sets plot style
    set_plot_properties()

    fig, ax = _make_figure()

    # Unique values in the discrete categorical label data
    unique_values = np.unique(discrete_category_label_data).tolist()

    # Find colors for each value in the discrete categorical label data
    colors = [
        discrete_colors[unique_values.index(label)]
        for label in discrete_category_label_data
    ]

    _scatter_with_colors(ax[0], x_data, y_data, colors)

    create_colorbar(
        ax[1],
        unique_values,
        visible_labels=discrete_category_visible_label_data,
        colors=discrete_colors,
        discrete=True,
    )

    # Save the plot
    _save_example(fig, "discrete_colorbar_with_some_labels.png")


@pytest.mark.parametrize(
    ("data_frame", "x_var", "y_var", "hue", "continuous_palette"),
    [(_DATA_FRAME, "x", "y", "continuous_label", _CONTINUOUS_PALETTE)],
)
def test_basic_continous_colorbar(
    data_frame, x_var, y_var, hue, continuous_palette
) -> None:
    """Tests a basic continuous colorbar.

    :param data_frame: data frame
    :param x_var: x-axis variable
    :param y_var: y-axis variable
    :param hue: hue variable
    :param continuous_palette: continuous palette
    :return: None
    """
    # Sets plot style
    set_plot_properties()

    fig, ax = _make_figure()

    _scatter_with_hue(
        ax[0], data_frame, x_var, y_var, hue, continuous_palette
    )

    create_colorbar(
        ax[1],
        data_frame[hue],
        title="Continuous Colorbar",
        palette=continuous_palette,
    )

    # Save the plot
    _save_example(fig, "continous_colorbar.png")


@pytest.mark.parametrize(
    (
        "data_frame",
        "x_var",
        "y_var",
        "hue",
        "continuous_visible",
        "continuous_palette",
    ),
    [
        (
            _DATA_FRAME,
            "x",
            "y",
            "continuous_label",
            _CONTINUOUS_VISIBLE_LABEL_DATA,
            _CONTINUOUS_PALETTE,
        )
    ],
)
def test_continous_colorbar_w_limited_visible(
    data_frame, x_var, y_var, hue, continuous_visible, continuous_palette
) -> None:
    """Tests a horizontal continuous colorbar with hand-picked ticks.

    :param data_frame: data frame
    :param x_var: x-axis variable
    :param y_var: y-axis variable
    :param hue: hue variable
    :param continuous_visible: continuous visible labels
    :param continuous_palette: continuous palette
    :return: None
    """
    # Sets plot style
    set_plot_properties()

    fig, ax = _make_figure(horizontal_colorbar=True)

    _scatter_with_hue(
        ax[0], data_frame, x_var, y_var, hue, continuous_palette
    )

    create_colorbar(
        ax[1],
        data_frame[hue],
        title="Continuous Colorbar",
        visible_labels=continuous_visible,
        palette=continuous_palette,
        orientation="horizontal",
    )

    # Save the plot
    plt.tight_layout()
    _save_example(fig, "continous_limited_data_colorbar.png")