diff --git a/environment.yml b/environment.yml index cc47e9b..33e3fd0 100644 --- a/environment.yml +++ b/environment.yml @@ -12,3 +12,4 @@ dependencies: # streamlit dependencies - streamlit>=1.38.0 - captcha==0.5.0 + - pyopenms_viz>=0.1.2 diff --git a/requirements.txt b/requirements.txt index a590255..20a7ee9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,8 @@ # note that it is much more restricted in terms of installing third-parties / etc. # preferably use the batteries included or simple docker file for local hosting streamlit>=1.38.0 -pyopenms==3.1.0 +pyopenms==3.2.0 numpy==1.26.4 # pandas and numpy are dependencies of pyopenms, however, pyopenms needs numpy<=1.26.4 plotly==5.22.0 -captcha==0.5.0 \ No newline at end of file +captcha==0.5.0 +pyopenms_viz>=0.1.2 \ No newline at end of file diff --git a/src/common/common.py b/src/common/common.py index 222cbb2..4c4321e 100644 --- a/src/common/common.py +++ b/src/common/common.py @@ -299,6 +299,14 @@ def change_workspace(): img_formats.index(params["image-format"]), key="image-format", ) + st.markdown("## Spectrum Plotting") + st.selectbox("Bin Peaks", ["auto", True, False], key="spectrum_bin_peaks") + if st.session_state["spectrum_bin_peaks"] == True: + st.number_input( + "Number of Bins (m/z)", 1, 10000, 50, key="spectrum_num_bins" + ) + else: + st.session_state["spectrum_num_bins"] = 50 return params @@ -321,7 +329,7 @@ def v_space(n: int, col=None) -> None: def display_large_dataframe( - df, chunk_sizes: list[int] = [100, 1_000, 10_000], **kwargs + df, chunk_sizes: list[int] = [10, 100, 1_000, 10_000], **kwargs ): """ Displays a large DataFrame in chunks with pagination controls and row selection. @@ -332,23 +340,20 @@ def display_large_dataframe( ...: Additional keyword arguments to pass to the `st.dataframe` function. See: https://docs.streamlit.io/develop/api-reference/data/st.dataframe Returns: - Selected rows from the current chunk. + Index of selected row. """ - def update_on_change(): - # Initialize session state for pagination - if "current_chunk" not in st.session_state: - st.session_state.current_chunk = 0 - st.session_state.current_chunk = 0 - # Dropdown for selecting chunk size - chunk_size = st.selectbox( - "Select Number of Rows to Display", chunk_sizes, on_change=update_on_change - ) + chunk_size = st.selectbox("Select Number of Rows to Display", chunk_sizes) # Calculate total number of chunks total_chunks = (len(df) + chunk_size - 1) // chunk_size + if total_chunks > 1: + page = int(st.number_input("Select Page", 1, total_chunks, 1, step=1)) + else: + page = 1 + # Function to get the current chunk of the DataFrame def get_current_chunk(df, chunk_size, chunk_index): start = chunk_index * chunk_size @@ -358,9 +363,7 @@ def get_current_chunk(df, chunk_size, chunk_index): return df.iloc[start:end], start, end # Display the current chunk - current_chunk_df, start_row, end_row = get_current_chunk( - df, chunk_size, st.session_state.current_chunk - ) + current_chunk_df, start_row, end_row = get_current_chunk(df, chunk_size, page - 1) event = st.dataframe(current_chunk_df, **kwargs) @@ -368,20 +371,13 @@ def get_current_chunk(df, chunk_size, chunk_index): f"Showing rows {start_row + 1} to {end_row} of {len(df)} ({get_dataframe_mem_useage(current_chunk_df):.2f} MB)" ) - # Pagination buttons - col1, col2, col3 = st.columns([1, 2, 1]) - - with col1: - if st.button("Previous") and st.session_state.current_chunk > 0: - st.session_state.current_chunk -= 1 - - with col3: - if st.button("Next") and st.session_state.current_chunk < total_chunks - 1: - st.session_state.current_chunk += 1 + rows = event["selection"]["rows"] + if not rows: + return None + # Calculate the index based on the current page and chunk size + base_index = (page - 1) * chunk_size + return base_index + rows[0] - if event is not None: - return event - return None def show_table(df: pd.DataFrame, download_name: str = "") -> None: diff --git a/src/plotting/BasePlotter.py b/src/plotting/BasePlotter.py deleted file mode 100644 index 12a30f0..0000000 --- a/src/plotting/BasePlotter.py +++ /dev/null @@ -1,58 +0,0 @@ -from abc import ABC, abstractmethod -from dataclasses import dataclass -from enum import Enum -from typing import Literal, List -import numpy as np - -# A colorset suitable for color blindness -class Colors(str, Enum): - BLUE = "#4575B4" - RED = "#D73027" - LIGHTBLUE = "#91BFDB" - ORANGE = "#FC8D59" - PURPLE = "#7B2C65" - YELLOW = "#FCCF53" - DARKGRAY = "#555555" - LIGHTGRAY = "#BBBBBB" - - -@dataclass(kw_only=True) -class _BasePlotterConfig(ABC): - title: str = "1D Plot" - xlabel: str = "X-axis" - ylabel: str = "Y-axis" - height: int = 500 - width: int = 500 - relative_intensity: bool = False - show_legend: bool = True - - -# Abstract Class for Plotting -class _BasePlotter(ABC): - def __init__(self, config: _BasePlotterConfig) -> None: - self.config = config - self.fig = None # holds the figure object - - def updateConfig(self, **kwargs): - for key, value in kwargs.items(): - if hasattr(self.config, key): - setattr(self.config, key, value) - else: - raise ValueError(f"Invalid config setting: {key}") - - def _get_n_grayscale_colors(self, n: int) -> List[str]: - """Returns n evenly spaced grayscale colors in hex format.""" - hex_list = [] - for v in np.linspace(50, 200, n): - hex = "#" - for _ in range(3): - hex += f"{int(round(v)):02x}" - hex_list.append(hex) - return hex_list - - def plot(self, data, **kwargs): - return self._plot(data, **kwargs) - - @abstractmethod - def _plot(self, data, **kwargs): - pass \ No newline at end of file diff --git a/src/plotting/MSExperimentPlotter.py b/src/plotting/MSExperimentPlotter.py deleted file mode 100644 index c42df2c..0000000 --- a/src/plotting/MSExperimentPlotter.py +++ /dev/null @@ -1,221 +0,0 @@ -from dataclasses import dataclass -from typing import Literal, Union - -import matplotlib.pyplot as plt -import pandas as pd -import numpy as np -import plotly.graph_objects as go - -from src.plotting.BasePlotter import Colors, _BasePlotter, _BasePlotterConfig - - -@dataclass(kw_only=True) -class MSExperimentPlotterConfig(_BasePlotterConfig): - bin_peaks: Union[Literal["auto"], bool] = "auto" - num_RT_bins: int = 50 - num_mz_bins: int = 50 - plot3D: bool = False - title: str = "Peak Map" - xlabel: str = "RT (s)" - ylabel: str = "m/z" - height: int = 500 - width: int = 750 - relative_intensity: bool = False - show_legend: bool = True - - -class MSExperimentPlotter(_BasePlotter): - def __init__(self, config: MSExperimentPlotterConfig, **kwargs) -> None: - """ - Initialize the MSExperimentPlotter with a given configuration and optional parameters. - - Args: - config (MSExperimentPlotterConfig): Configuration settings for the spectrum plotter. - **kwargs: Additional keyword arguments for customization. - """ - super().__init__(config=config, **kwargs) - - def _prepare_data(self, exp: pd.DataFrame) -> pd.DataFrame: - """Prepares data for plotting based on configuration (binning, relative intensity, hover text).""" - if self.config.bin_peaks == True or ( - exp.shape[0] > self.config.num_mz_bins * self.config.num_RT_bins - and self.config.bin_peaks == "auto" - ): - exp["mz"] = pd.cut(exp["mz"], bins=self.config.num_mz_bins) - exp["RT"] = pd.cut(exp["RT"], bins=self.config.num_RT_bins) - - # Group by x and y bins and calculate the mean intensity within each bin - exp = ( - exp.groupby(["mz", "RT"], observed=True) - .agg({"inty": "mean"}) - .reset_index() - ) - exp["mz"] = exp["mz"].apply(lambda interval: interval.mid).astype(float) - exp["RT"] = exp["RT"].apply(lambda interval: interval.mid).astype(float) - exp = exp.fillna(0) - else: - self.config.bin_peaks = False - - if self.config.relative_intensity: - exp["inty"] = exp["inty"] / max(exp["inty"]) * 100 - - exp["hover_text"] = exp.apply( - lambda x: f"m/z: {round(x['mz'], 6)}
RT: {round(x['RT'], 2)}
intensity: {int(x['inty'])}", - axis=1, - ) - - return exp.sort_values("inty") - - def _plotMatplotlib3D( - self, - exp: pd.DataFrame, - ) -> plt.Figure: - """Plot 3D peak map with mz, RT and intensity dimensions. Colored peaks based on intensity.""" - fig = plt.figure( - figsize=(self.config.width / 100, self.config.height / 100), - layout="constrained", - ) - ax = fig.add_subplot(111, projection="3d") - - if self.config.title: - ax.set_title(self.config.title, fontsize=12, loc="left") - ax.set_xlabel( - self.config.ylabel, - fontsize=9, - labelpad=-2, - color=Colors["DARKGRAY"], - style="italic", - ) - ax.set_ylabel( - self.config.xlabel, - fontsize=9, - labelpad=-2, - color=Colors["DARKGRAY"], - ) - ax.set_zlabel("intensity", fontsize=10, color=Colors["DARKGRAY"], labelpad=-2) - for axis in ("x", "y", "z"): - ax.tick_params(axis=axis, labelsize=8, pad=-2, colors=Colors["DARKGRAY"]) - - ax.set_box_aspect(aspect=None, zoom=0.88) - ax.ticklabel_format(axis="z", style="sci", useMathText=True, scilimits=(0,0)) - ax.grid(color="#FF0000", linewidth=0.8) - ax.xaxis.pane.fill = False - ax.yaxis.pane.fill = False - ax.zaxis.pane.fill = False - ax.view_init(elev=25, azim=-45, roll=0) - - # Plot lines to the bottom with colored based on inty - for i in range(len(exp)): - ax.plot( - [exp["RT"].iloc[i], exp["RT"].iloc[i]], - [exp["inty"].iloc[i], 0], - [exp["mz"].iloc[i], exp["mz"].iloc[i]], - zdir="x", - color=plt.cm.magma_r((exp["inty"].iloc[i] / exp["inty"].max())), - ) - return fig - - def _plotPlotly2D( - self, - exp: pd.DataFrame, - ) -> go.Figure: - """Plot 2D peak map with mz and RT dimensions. Colored peaks based on intensity.""" - layout = go.Layout( - title=dict(text=self.config.title), - xaxis=dict(title=self.config.xlabel), - yaxis=dict(title=self.config.ylabel), - showlegend=self.config.show_legend, - template="simple_white", - dragmode="select", - height=self.config.height, - width=self.config.width, - ) - fig = go.Figure(layout=layout) - fig.add_trace( - go.Scattergl( - name="peaks", - x=exp["RT"], - y=exp["mz"], - mode="markers", - marker=dict( - color=exp["inty"].apply(lambda x: np.log(x)), - colorscale="sunset", - size=8, - symbol="square", - colorbar=( - dict(thickness=8, outlinewidth=0, tickformat=".0e") - if self.config.show_legend - else None - ), - ), - hovertext=exp["hover_text"] if not self.config.bin_peaks else None, - hoverinfo="text", - showlegend=False, - ) - ) - return fig - - def _plot( - self, - exp: pd.DataFrame, - ) -> go.Figure: - """Prepares data and returns Plotly 2D plot or Matplotlib 3D plot.""" - exp = self._prepare_data(exp) - if self.config.plot3D: - return self._plotMatplotlib3D(exp) - return self._plotPlotly2D(exp) - -# ============================================================================= # -## FUNCTIONAL API ## -# ============================================================================= # - - -def plotMSExperiment( - exp: pd.DataFrame, - plot3D: bool = False, - relative_intensity: bool = False, - bin_peaks: Union[Literal["auto"], bool] = "auto", - num_RT_bins: int = 50, - num_mz_bins: int = 50, - width: int = 750, - height: int = 500, - title: str = "Peak Map", - xlabel: str = "RT (s)", - ylabel: str = "m/z", - show_legend: bool = False, -): - """ - Plots a Spectrum from an MSSpectrum object - - Args: - spectrum (pd.DataFrame): OpenMS MSSpectrum Object - plot3D: (bool = False, optional): Plot peak map 3D with peaks colored based on intensity. Disables colorbar legend. Works with "MATPLOTLIB" engine only. Defaults to False. - relative_intensity (bool, optional): If true, plot relative intensity values. Defaults to False. - bin_peaks: (Union[Literal["auto"], bool], optional): Bin peaks to reduce complexity and improve plotting speed. Hovertext disabled if activated. If set to "auto" any MSExperiment with more then num_RT_bins x num_mz_bins peaks will be binned. Defaults to "auto". - num_RT_bins: (int, optional): Number of bins in RT dimension. Defaults to 50. - num_mz_bins: (int, optional): Number of bins in m/z dimension. Defaults to 50. - width (int, optional): Width of plot. Defaults to 500px. - height (int, optional): Height of plot. Defaults to 500px. - title (str, optional): Plot title. Defaults to "Spectrum Plot". - xlabel (str, optional): X-axis label. Defaults to "m/z". - ylabel (str, optional): Y-axis label. Defaults to "intensity" or "ion mobility". - show_legend (int, optional): Show legend. Defaults to False. - - Returns: - Plot: The generated plot using the specified engine. - """ - config = MSExperimentPlotterConfig( - plot3D=plot3D, - relative_intensity=relative_intensity, - bin_peaks=bin_peaks, - num_RT_bins=num_RT_bins, - num_mz_bins=num_mz_bins, - width=width, - height=height, - title=title, - xlabel=xlabel, - ylabel=ylabel, - show_legend=show_legend, - ) - plotter = MSExperimentPlotter(config) - return plotter.plot(exp.copy()) \ No newline at end of file diff --git a/src/view.py b/src/view.py index ad35a64..1b89d5c 100644 --- a/src/view.py +++ b/src/view.py @@ -5,10 +5,7 @@ import plotly.graph_objects as go import streamlit as st import pyopenms as poms - -from src.plotting.MSExperimentPlotter import plotMSExperiment from src.common.common import show_fig, display_large_dataframe - from typing import Union @@ -61,6 +58,7 @@ def get_df(file: Union[str, Path]) -> pd.DataFrame: else: st.session_state["view_ms2"] = pd.DataFrame() + def plot_bpc_tic() -> go.Figure: """Plot the base peak and total ion chromatogram (TIC). @@ -68,26 +66,41 @@ def plot_bpc_tic() -> go.Figure: A plotly Figure object containing the BPC and TIC plot. """ fig = go.Figure() + max_int = 0 if st.session_state.view_tic: df = st.session_state.view_ms1.groupby("RT").sum().reset_index() - fig.add_scatter( - x=df["RT"], - y=df["inty"], - mode="lines", - line=dict(color="#f24c5c", width=3), # OpenMS red - name="TIC", - showlegend=True, + df["type"] = "TIC" + if df["inty"].max() > max_int: + max_int = df["inty"].max() + fig = df.plot( + backend="ms_plotly", + kind="chromatogram", + fig=fig, + x="RT", + y="inty", + by="type", + line_color="#f24c5c", + show_plot=False, + grid=False, ) + fig = fig.fig if st.session_state.view_bpc: df = st.session_state.view_ms1.groupby("RT").max().reset_index() - fig.add_scatter( - x=df["RT"], - y=df["inty"], - mode="lines", - line=dict(color="#2d3a9d", width=3), # OpenMS blue - name="BPC", - showlegend=True, + df["type"] = "BPC" + if df["inty"].max() > max_int: + max_int = df["inty"].max() + fig = df.plot( + backend="ms_plotly", + kind="chromatogram", + fig=fig, + x="RT", + y="inty", + by="type", + line_color="#2d3a9d", + show_plot=False, + grid=False, ) + fig = fig.fig if st.session_state.view_eic: df = st.session_state.view_ms1 target_value = st.session_state.view_eic_mz.strip().replace(",", ".") @@ -97,19 +110,30 @@ def plot_bpc_tic() -> go.Figure: tolerance = (target_value * ppm_tolerance) / 1e6 # Filter the DataFrame - df_eic = df[(df['mz'] >= target_value - tolerance) & (df['mz'] <= target_value + tolerance)] + df_eic = df[ + (df["mz"] >= target_value - tolerance) + & (df["mz"] <= target_value + tolerance) + ] if not df_eic.empty: - fig.add_scatter( - x=df_eic["RT"], - y=df_eic["inty"], - mode="lines", - line=dict(color="#f6bf26", width=3), - name="XIC", - showlegend=True, + df_eic["type"] = "XIC" + if df_eic["inty"].max() > max_int: + max_int = df_eic["inty"].max() + fig = df_eic.plot( + backend="ms_plotly", + kind="chromatogram", + fig=fig, + x="RT", + y="inty", + by="type", + line_color="#f6bf26", + show_plot=False, + grid=False, ) + fig = fig.fig except ValueError: st.error("Invalid m/z value for XIC provided. Please enter a valid number.") + fig.update_yaxes(range=[0, max_int]) fig.update_layout( title=f"{st.session_state.view_selected_file}", xaxis_title="retention time (s)", @@ -122,68 +146,22 @@ def plot_bpc_tic() -> go.Figure: @st.cache_resource -def plot_ms_spectrum(spec, title, color): - """ - Takes a pandas Series (spec) and generates a needle plot with m/z and intensity dimension. - - Args: - spec: Pandas Series representing the mass spectrum with "mzarray" and "intarray" columns. - title: Title of the plot. - color: Color of the line in the plot. - - Returns: - A Plotly Figure object representing the needle plot of the mass spectrum. - """ - - # Every Peak is represented by three dots in the line plot: (x, 0), (x, y), (x, 0) - def create_spectra(x, y, zero=0): - x = np.repeat(x, 3) - y = np.repeat(y, 3) - y[::3] = y[2::3] = zero - return pd.DataFrame({"mz": x, "intensity": y}) - - df = create_spectra(spec["mzarray"], spec["intarray"]) - fig = px.line(df, x="mz", y="intensity") - fig.update_traces(line_color=color) - fig.add_hline(0, line=dict(color="#DDDDDD"), line_width=3) - fig.update_layout( - showlegend=False, - title_text=title, - xaxis_title="m/z", - yaxis_title="intensity", - plot_bgcolor="rgb(255,255,255)", - dragmode="select", +def plot_ms_spectrum(df, title, bin_peaks, num_x_bins): + fig = df.plot( + kind="spectrum", + backend="ms_plotly", + x="mz", + y="intensity", + line_color="#2d3a9d", + title=title, + show_plot=False, + grid=False, + bin_peaks=bin_peaks, + num_x_bins=num_x_bins, ) - # add annotations - top_indices = np.argsort(spec["intarray"])[-5:][::-1] - for index in top_indices: - mz = spec["mzarray"][index] - i = spec["intarray"][index] - fig.add_annotation( - dict( - x=mz, - y=i, - text=str(round(mz, 5)), - showarrow=False, - xanchor="left", - font=dict( - family="Open Sans Mono, monospace", - size=12, - color=color, - ), - ) - ) - fig.layout.template = "plotly_white" - # adjust x-axis limits to not cut peaks and annotations - x_values = [trace.x for trace in fig.data] - xmin = min([min(values) for values in x_values]) - xmax = max([max(values) for values in x_values]) - padding = 0.15 * (xmax - xmin) + fig = fig.fig fig.update_layout( - xaxis_range=[ - xmin - padding, - xmax + padding, - ] + template="plotly_white", dragmode="select", plot_bgcolor="rgb(255,255,255)" ) return fig @@ -199,9 +177,19 @@ def view_peak_map(): df = df[df["mz"] > box[0]["y"][1]] df = df[df["mz"] < box[0]["y"][0]] df = df[df["RT"] < box[0]["x"][1]] - peak_map = plotMSExperiment( - df, plot3D=False, title=st.session_state.view_selected_file + peak_map = df.plot( + kind="peakmap", + x="RT", + y="mz", + z="inty", + title=st.session_state.view_selected_file, + grid=False, + show_plot=False, + bin_peaks=True, + backend="ms_plotly", ) + peak_map.fig.update_layout(template="simple_white", dragmode="select") + peak_map = peak_map.fig c1, c2 = st.columns(2) with c1: st.info( @@ -214,8 +202,24 @@ def view_peak_map(): ) with c2: if df.shape[0] < 2500: - peak_map_3D = plotMSExperiment(df, plot3D=True, title="") - st.pyplot(peak_map_3D, use_container_width=True) + peak_map_3D = df.plot( + kind="peakmap", + plot_3d=True, + backend="ms_plotly", + x="RT", + y="mz", + z="inty", + zlabel="Intensity", + title="", + show_plot=False, + grid=False, + bin_peaks=st.session_state.spectrum_bin_peaks, + num_x_bins=st.session_state.spectrum_num_bins, + height=650, + width=900, + ) + peak_map_3D = peak_map_3D.fig + st.plotly_chart(peak_map_3D, use_container_width=True) @st.fragment @@ -224,7 +228,7 @@ def view_spectrum(): with cols[0]: df = st.session_state.view_spectra.copy() df["spectrum ID"] = df.index + 1 - event = display_large_dataframe( + index = display_large_dataframe( df, column_order=[ "spectrum ID", @@ -238,10 +242,9 @@ def view_spectrum(): use_container_width=True, hide_index=True, ) - rows = event.selection.rows with cols[1]: - if rows: - df = st.session_state.view_spectra.iloc[rows[0]] + if index is not None: + df = st.session_state.view_spectra.iloc[index] if "view_spectrum_selection" in st.session_state: box = st.session_state.view_spectrum_selection.selection.box if box: @@ -251,10 +254,28 @@ def view_spectrum(): df["mzarray"] = df["mzarray"][mask] if df["mzarray"].size > 0: - title = f"{st.session_state.view_selected_file} spec={rows[0]+1} mslevel={df['MS level']}" + title = f"{st.session_state.view_selected_file} spec={index+1} mslevel={df['MS level']}" if df["precursor m/z"] > 0: title += f" precursor m/z: {round(df['precursor m/z'], 4)}" - fig = plot_ms_spectrum(df, title, "#2d3a9d") + + df_selected = pd.DataFrame( + { + "mz": df["mzarray"], + "intensity": df["intarray"], + } + ) + df_selected["RT"] = df["RT"] + df_selected["MS level"] = df["MS level"] + df_selected["precursor m/z"] = df["precursor m/z"] + df_selected["max intensity m/z"] = df["max intensity m/z"] + + fig = plot_ms_spectrum( + df_selected, + title, + st.session_state.spectrum_bin_peaks, + st.session_state.spectrum_num_bins, + ) + show_fig(fig, title.replace(" ", "_"), True, "view_spectrum_selection") else: st.session_state.pop("view_spectrum_selection") @@ -273,7 +294,10 @@ def view_bpc_tic(): "Base Peak Chromatogram (BPC)", True, key="view_bpc", help="Plot BPC." ) cols[2].checkbox( - "Extracted Ion Chromatogram (EIC/XIC)", True, key="view_eic", help="Plot extracted ion chromatogram with specified m/z." + "Extracted Ion Chromatogram (EIC/XIC)", + True, + key="view_eic", + help="Plot extracted ion chromatogram with specified m/z.", ) cols[3].text_input( "XIC m/z", @@ -283,9 +307,12 @@ def view_bpc_tic(): ) cols[4].number_input( "XIC ppm tolerance", - 0.1, 50.0, 10.0, 1.0, + 0.1, + 50.0, + 10.0, + 1.0, help="Tolerance for XIC calculation (ppm).", - key="view_eic_ppm" + key="view_eic_ppm", ) fig = plot_bpc_tic() show_fig(fig, f"BPC-TIC-{st.session_state.view_selected_file}")