Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Major cleanup #29

Merged
merged 3 commits into from
Jul 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions alphadia/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
import alphadia
from alphadia.extraction import processlogger



@click.group(
context_settings=dict(
help_option_names=['-h', '--help'],
Expand Down Expand Up @@ -174,6 +172,7 @@ def extract(**kwargs):
#config_update = eval(kwargs['config_update']) if kwargs['config_update'] else None

plan = Plan(
output_location,
files,
None,
None,
Expand All @@ -182,7 +181,6 @@ def extract(**kwargs):
plan.from_spec_lib_base(lib)

plan.run(
output_location,
keep_decoys = kwargs['keep_decoys'],
fdr = kwargs['fdr'],
figure_path = kwargs['figure_path'],
Expand Down
1,013 changes: 0 additions & 1,013 deletions alphadia/extraction/archive.py

This file was deleted.

1 change: 0 additions & 1 deletion alphadia/extraction/calibration.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# native imports
import os
import logging
from unittest.mock import DEFAULT
import yaml
import typing
import pickle
Expand Down
111 changes: 86 additions & 25 deletions alphadia/extraction/planning.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,39 +10,34 @@
from typing import Union, List, Dict, Tuple, Optional

logger = logging.getLogger()
if not 'progress' in dir(logger):
from alphadia.extraction import processlogger
processlogger.init_logging()
from alphadia.extraction import processlogger


# alphadia imports
from alphadia.extraction import data, plexscoring
from alphadia.extraction.calibration import CalibrationManager
from alphadia.extraction.scoring import fdr_correction, MS2ExtractionWorkflow
from alphadia.extraction import utils
from alphadia.extraction.quadrupole import SimpleQuadrupole
from alphadia.extraction.scoring import fdr_correction, channel_fdr_correction
from alphadia.extraction import utils, validate
from alphadia.extraction.hybridselection import HybridCandidateSelection, HybridCandidateConfig
import alphadia

# alpha family imports
import alphatims

import alphabase.psm_reader
import alphabase.peptide.precursor
from alphabase.peptide import fragment
from alphabase.spectral_library.flat import SpecLibFlat
from alphabase.spectral_library.base import SpecLibBase

# third party imports
import numpy as np
import pandas as pd
from matplotlib.style import library
import neptune.new as neptune
from neptune.new.types import File
import os, psutil

class Plan:

def __init__(self,
output_folder : str,
raw_file_list: List,
config_path : Union[str, None] = None,
config_update_path : Union[str, None] = None,
Expand All @@ -68,6 +63,10 @@ def __init__(self,

"""

self.output_folder = output_folder
processlogger.init_logging(self.output_folder)
logger = logging.getLogger()

logger.progress(' _ _ _ ___ ___ _ ')
logger.progress(' /_\ | |_ __| |_ __ _| \_ _| /_\ ')
logger.progress(' / _ \| | \'_ \\ \' \/ _` | |) | | / _ \ ')
Expand All @@ -80,7 +79,7 @@ def __init__(self,
# default config path is not defined in the function definition to account for for different path separators on different OS
if config_path is None:
# default yaml config location under /misc/config/config.yaml
config_path = os.path.join(os.path.dirname(__file__), '..','..','misc','config','default_new.yaml')
config_path = os.path.join(os.path.dirname(__file__), '..','..','misc','config','default.yaml')

# 1. load default config
with open(config_path, 'r') as f:
Expand Down Expand Up @@ -279,15 +278,6 @@ def from_spec_lib_flat(self, speclib_flat):
self.speclib.precursor_df = self.speclib.precursor_df.sort_values('elution_group_idx')
self.speclib.precursor_df = self.speclib.precursor_df.reset_index(drop=True)

if 'channel_filter' in self.config['library_loading']:
try:
channels = self.config['library_loading']['channel_filter'].split(',')
channels = [int(c) for c in channels]
self.speclib._precursor_df = self.speclib._precursor_df[self.speclib._precursor_df['channel'].isin(channels)]
logger.info(f'filtering for channels {channels}')
except:
logger.error(f'could not parse channel filter {self.config["library_loading"]["channel_filter"]}')

def log_library_stats(self):

logger.info(f'========= Library Stats =========')
Expand Down Expand Up @@ -449,7 +439,6 @@ def get_run_data(self):
yield raw.jitclass(), precursor_df, self.speclib.fragment_df

def run(self,
output_folder,
figure_path = None,
neptune_token = None,
neptune_tags = [],
Expand Down Expand Up @@ -477,8 +466,12 @@ def run(self,

workflow.calibration()

df = workflow.extraction(keep_decoys=keep_decoys)
df = workflow.extraction(keep_decoys = keep_decoys)
df = df[df['qval'] <= fdr]

if self.config['multiplexing']['multiplexed_quant']:
df = workflow.requantify(df)

df['run'] = raw_name
dataframes.append(df)

Expand All @@ -489,7 +482,7 @@ def run(self,
continue

out_df = pd.concat(dataframes)
out_df.to_csv(os.path.join(output_folder, f'alpha_psms.tsv'), sep='\t', index=False)
out_df.to_csv(os.path.join(self.output_folder, f'alpha_psms.tsv'), sep='\t', index=False)

class Workflow:
def __init__(
Expand All @@ -506,7 +499,16 @@ def __init__(
self.config = config
self.dia_data = dia_data
self.raw_name = precursors_flat.iloc[0]['raw_name']
self.precursors_flat = precursors_flat


if self.config["library_loading"]["channel_filter"] == '':
allowed_channels = precursors_flat['channel'].unique()
else:
allowed_channels = [int(c) for c in self.config["library_loading"]["channel_filter"].split(',')]
logger.progress(f'Applying channel filter using only: {allowed_channels}')

self.precursors_flat_raw = precursors_flat.copy()
self.precursors_flat = self.precursors_flat_raw[self.precursors_flat_raw['channel'].isin(allowed_channels)].copy()
self.fragments_flat = fragments_flat

self.figure_path = figure_path
Expand Down Expand Up @@ -831,7 +833,9 @@ def extract_batch(self, batch_df):

return features_df, fragments_df

def extraction(self, keep_decoys=False):
def extraction(
self,
keep_decoys=False):

if self.run is not None:
for key, value in self.progress.items():
Expand Down Expand Up @@ -865,3 +869,60 @@ def extraction(self, keep_decoys=False):
logger.progress(f'=== extraction finished, 0.05 FDR: {precursors_05:,}, 0.01 FDR: {precursors_01:,}, 0.001 FDR: {precursors_001:,} ===')

return precursor_df

def requantify(
    self,
    psm_df
):
    """Requantify confident PSMs across all multiplexed channels.

    Takes the FDR-filtered PSMs from the initial search, expands them to
    every configured channel (original + reference + targets + decoy),
    rescores the expanded candidate set with grouped scoring, and returns
    the channel-level FDR-corrected feature table.

    Parameters
    ----------
    psm_df : pandas.DataFrame
        Candidate features of confidently identified precursors; must
        contain a ``channel`` column.

    Returns
    -------
    pandas.DataFrame
        Requantified features after ``channel_fdr_correction``.

    Raises
    ------
    ValueError
        If the workflow config has no ``multiplexing`` section.
    """

    # Fail fast before doing any expensive calibration / scoring work.
    if 'multiplexing' not in self.config:
        raise ValueError('no multiplexing config found')

    # Apply the current calibration to the *unfiltered* library so that
    # channels excluded from the initial search also get calibrated
    # coordinates.
    self.calibration_manager.predict(self.precursors_flat_raw, 'precursor')
    self.calibration_manager.predict(self.fragments_flat, 'fragment')

    reference_candidates = plexscoring.candidate_features_to_candidates(psm_df)

    logger.progress(f'=== Multiplexing {len(reference_candidates):,} precursors ===')

    original_channels = psm_df['channel'].unique().tolist()
    logger.progress(f'original channels: {original_channels}')

    reference_channel = self.config['multiplexing']['reference_channel']
    logger.progress(f'reference channel: {reference_channel}')

    target_channels = [int(c) for c in self.config['multiplexing']['target_channels'].split(',')]
    logger.progress(f'target channels: {target_channels}')

    decoy_channel = self.config['multiplexing']['decoy_channel']
    logger.progress(f'decoy channel: {decoy_channel}')

    channels = list(set(original_channels + [reference_channel] + target_channels + [decoy_channel]))
    multiplexed_candidates = plexscoring.multiplex_candidates(reference_candidates, self.precursors_flat_raw, channels=channels)

    channel_count_lib = self.precursors_flat_raw['channel'].value_counts()
    channel_count_multiplexed = multiplexed_candidates['channel'].value_counts()

    # Warn about channels that cannot contribute any IDs. Membership on a
    # value_counts Series tests the index, i.e. the channel labels.
    for channel in channels:
        if channel not in channel_count_lib:
            logger.warning(f'channel {channel} not found in library')
        if channel not in channel_count_multiplexed:
            logger.warning(f'channel {channel} could not be mapped to existing IDs.')

    logger.progress(f'=== Requantifying {len(multiplexed_candidates):,} precursors ===')

    config = plexscoring.CandidateConfig()
    config.max_cardinality = 1
    config.score_grouped = True
    # NOTE(review): hard-coded 0 — presumably this should be the
    # configured `reference_channel`; confirm before changing.
    config.reference_channel = 0

    multiplexed_scoring = plexscoring.CandidateScoring(
        self.dia_data,
        self.precursors_flat_raw,
        self.fragments_flat,
        config=config
    )

    multiplexed_features, fragments = multiplexed_scoring(multiplexed_candidates)

    return channel_fdr_correction(multiplexed_features)
127 changes: 6 additions & 121 deletions alphadia/extraction/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import numpy as np
import pandas as pd
from scipy.stats import gaussian_kde
from matplotlib import patches

def density_scatter(
x: typing.Union[np.ndarray, pd.Series, pd.DataFrame],
Expand Down Expand Up @@ -116,7 +117,7 @@ def _generate_slice_collection(

return slice_collection

from matplotlib import patches


def _plot_slice_collection(slice_collection, ax, alpha=0.5, **kwargs):

Expand Down Expand Up @@ -153,129 +154,13 @@ def plot_dia_cycle(cycle,ax=None, cmap_name='YlOrRd', **kwargs):
ax.set_xlabel('Quadrupole m/z')
ax.set_ylabel('Scan')

def plot_all_precursors(
    dense_precursors,
    qtf,
    template,
    isotope_intensity
):
    """Render one diagnostic figure per precursor.

    Each figure has one row per observation. Columns alternate between a
    dense precursor image and its quadrupole transfer profile (one pair
    per isotope), with the template image as the final column. Relative
    isotope intensities and observation importances are annotated as
    percentages on the corresponding panels.

    Parameters
    ----------
    dense_precursors : array-like
        Dense precursor images, indexed as ``[0, flat_idx, 0]`` where
        ``flat_idx = precursor * n_isotopes + isotope``.
    qtf : array-like
        Quadrupole transfer profiles of shape
        ``(n_precursors, n_isotopes, n_observations, n_scans)``.
    template : array-like
        Template images, indexed ``[precursor, observation]``.
    isotope_intensity : array-like
        Relative isotope intensities, indexed ``[precursor, isotope]``.
    """
    num_precursors, num_isotopes, num_observations, num_scans = qtf.shape[:4]

    # Layout: an (image, profile) column pair per isotope plus one
    # trailing template column.
    total_cols = num_isotopes * 2 + 1
    total_rows = num_observations
    col_widths = np.append(np.tile([2, 0.8], num_isotopes), [2])

    scans = np.arange(num_scans)
    obs_importance = calculate_observation_importance(
        template,
    )

    # One figure per precursor.
    for prec in range(num_precursors):

        dense_lo, dense_hi = np.min(dense_precursors), np.max(dense_precursors)
        tmpl_lo, tmpl_hi = np.min(template), np.max(template)

        fig, axs = plt.subplots(
            total_rows,
            total_cols,
            figsize=(total_cols * 1, total_rows * 2),
            gridspec_kw={'width_ratios': col_widths},
            sharey='row'
        )
        if len(axs.shape) == 1:
            # A single observation yields a 1-D axes array; normalise to 2-D.
            axs = axs.reshape(1, axs.shape[0])

        # Rows are observations, columns are isotopes (plus template).
        for obs in range(num_observations):

            for iso in range(num_isotopes):

                img_col = 2 * iso
                prof_col = 2 * iso + 1

                # Precursors and isotopes share one flat storage axis.
                flat_idx = prec * num_isotopes + iso

                axs[obs, img_col].imshow(
                    dense_precursors[0, flat_idx, 0],
                    vmin=dense_lo,
                    vmax=dense_hi,
                )
                axs[0, img_col].set_title(f'isotope {iso}')

                # Annotate the relative isotope intensity.
                axs[obs, img_col].text(
                    0.05,
                    0.95,
                    f'{isotope_intensity[prec, iso] * 100:.2f} %',
                    horizontalalignment='left',
                    verticalalignment='top',
                    transform=axs[obs, img_col].transAxes,
                    color='white'
                )

                # Raw and intensity-weighted transfer profiles.
                axs[obs, prof_col].plot(
                    qtf[prec, iso, obs],
                    scans
                )
                axs[obs, prof_col].plot(
                    qtf[prec, iso, obs] * isotope_intensity[prec, iso],
                    scans
                )
                axs[obs, prof_col].set_xlim(0, 1)
                axs[-1, img_col].set_xlabel(f'frame')

            # Only the last row keeps its x tick labels.
            if obs < num_observations - 1:
                for ax in axs[obs, :].flat:
                    ax.set_xticks([])

            axs[0, -1].set_title(f'template', fontweight='bold')

            axs[obs, -1].imshow(
                template[prec, obs],
                vmin=tmpl_lo,
                vmax=tmpl_hi,
            )
            axs[obs, -1].text(
                0.05,
                0.95,
                f'{obs_importance[prec, obs] * 100:.2f} %',
                horizontalalignment='left',
                verticalalignment='top',
                transform=axs[obs, -1].transAxes,
                color='white'
            )
            axs[obs, 0].set_ylabel(f'observation {obs}\nscan')

        fig.tight_layout()
        plt.show()

def plot_image_collection(
images
images: typing.List[np.ndarray],
image_width: float = 4,
image_height: float = 6
):
n_images = len(images)
fig, ax = plt.subplots(1, n_images, figsize=(n_images*4, 6))
fig, ax = plt.subplots(1, n_images, figsize=(n_images*image_width, image_height))

if n_images == 1:
ax = [ax]
Expand Down
Loading
Loading