Skip to content

Commit

Permalink
Merge pull request #29 from MannLabs/improved-logging
Browse files Browse the repository at this point in the history
Major cleanup
  • Loading branch information
GeorgWa committed Jul 19, 2023
2 parents 27e5907 + 5e94345 commit 2a6837d
Show file tree
Hide file tree
Showing 34 changed files with 337 additions and 43,124 deletions.
4 changes: 1 addition & 3 deletions alphadia/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
import alphadia
from alphadia.extraction import processlogger



@click.group(
context_settings=dict(
help_option_names=['-h', '--help'],
Expand Down Expand Up @@ -174,6 +172,7 @@ def extract(**kwargs):
#config_update = eval(kwargs['config_update']) if kwargs['config_update'] else None

plan = Plan(
output_location,
files,
None,
None,
Expand All @@ -182,7 +181,6 @@ def extract(**kwargs):
plan.from_spec_lib_base(lib)

plan.run(
output_location,
keep_decoys = kwargs['keep_decoys'],
fdr = kwargs['fdr'],
figure_path = kwargs['figure_path'],
Expand Down
1,013 changes: 0 additions & 1,013 deletions alphadia/extraction/archive.py

This file was deleted.

1 change: 0 additions & 1 deletion alphadia/extraction/calibration.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# native imports
import os
import logging
from unittest.mock import DEFAULT
import yaml
import typing
import pickle
Expand Down
111 changes: 86 additions & 25 deletions alphadia/extraction/planning.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,39 +10,34 @@
from typing import Union, List, Dict, Tuple, Optional

logger = logging.getLogger()
if not 'progress' in dir(logger):
from alphadia.extraction import processlogger
processlogger.init_logging()
from alphadia.extraction import processlogger


# alphadia imports
from alphadia.extraction import data, plexscoring
from alphadia.extraction.calibration import CalibrationManager
from alphadia.extraction.scoring import fdr_correction, MS2ExtractionWorkflow
from alphadia.extraction import utils
from alphadia.extraction.quadrupole import SimpleQuadrupole
from alphadia.extraction.scoring import fdr_correction, channel_fdr_correction
from alphadia.extraction import utils, validate
from alphadia.extraction.hybridselection import HybridCandidateSelection, HybridCandidateConfig
import alphadia

# alpha family imports
import alphatims

import alphabase.psm_reader
import alphabase.peptide.precursor
from alphabase.peptide import fragment
from alphabase.spectral_library.flat import SpecLibFlat
from alphabase.spectral_library.base import SpecLibBase

# third party imports
import numpy as np
import pandas as pd
from matplotlib.style import library
import neptune.new as neptune
from neptune.new.types import File
import os, psutil

class Plan:

def __init__(self,
output_folder : str,
raw_file_list: List,
config_path : Union[str, None] = None,
config_update_path : Union[str, None] = None,
Expand All @@ -68,6 +63,10 @@ def __init__(self,
"""

self.output_folder = output_folder
processlogger.init_logging(self.output_folder)
logger = logging.getLogger()

logger.progress(' _ _ _ ___ ___ _ ')
logger.progress(' /_\ | |_ __| |_ __ _| \_ _| /_\ ')
logger.progress(' / _ \| | \'_ \\ \' \/ _` | |) | | / _ \ ')
Expand All @@ -80,7 +79,7 @@ def __init__(self,
# default config path is not defined in the function definition to account for for different path separators on different OS
if config_path is None:
# default yaml config location under /misc/config/config.yaml
config_path = os.path.join(os.path.dirname(__file__), '..','..','misc','config','default_new.yaml')
config_path = os.path.join(os.path.dirname(__file__), '..','..','misc','config','default.yaml')

# 1. load default config
with open(config_path, 'r') as f:
Expand Down Expand Up @@ -279,15 +278,6 @@ def from_spec_lib_flat(self, speclib_flat):
self.speclib.precursor_df = self.speclib.precursor_df.sort_values('elution_group_idx')
self.speclib.precursor_df = self.speclib.precursor_df.reset_index(drop=True)

if 'channel_filter' in self.config['library_loading']:
try:
channels = self.config['library_loading']['channel_filter'].split(',')
channels = [int(c) for c in channels]
self.speclib._precursor_df = self.speclib._precursor_df[self.speclib._precursor_df['channel'].isin(channels)]
logger.info(f'filtering for channels {channels}')
except:
logger.error(f'could not parse channel filter {self.config["library_loading"]["channel_filter"]}')

def log_library_stats(self):

logger.info(f'========= Library Stats =========')
Expand Down Expand Up @@ -449,7 +439,6 @@ def get_run_data(self):
yield raw.jitclass(), precursor_df, self.speclib.fragment_df

def run(self,
output_folder,
figure_path = None,
neptune_token = None,
neptune_tags = [],
Expand Down Expand Up @@ -477,8 +466,12 @@ def run(self,

workflow.calibration()

df = workflow.extraction(keep_decoys=keep_decoys)
df = workflow.extraction(keep_decoys = keep_decoys)
df = df[df['qval'] <= fdr]

if self.config['multiplexing']['multiplexed_quant']:
df = workflow.requantify(df)

df['run'] = raw_name
dataframes.append(df)

Expand All @@ -489,7 +482,7 @@ def run(self,
continue

out_df = pd.concat(dataframes)
out_df.to_csv(os.path.join(output_folder, f'alpha_psms.tsv'), sep='\t', index=False)
out_df.to_csv(os.path.join(self.output_folder, f'alpha_psms.tsv'), sep='\t', index=False)

class Workflow:
def __init__(
Expand All @@ -506,7 +499,16 @@ def __init__(
self.config = config
self.dia_data = dia_data
self.raw_name = precursors_flat.iloc[0]['raw_name']
self.precursors_flat = precursors_flat


if self.config["library_loading"]["channel_filter"] == '':
allowed_channels = precursors_flat['channel'].unique()
else:
allowed_channels = [int(c) for c in self.config["library_loading"]["channel_filter"].split(',')]
logger.progress(f'Applying channel filter using only: {allowed_channels}')

self.precursors_flat_raw = precursors_flat.copy()
self.precursors_flat = self.precursors_flat_raw[self.precursors_flat_raw['channel'].isin(allowed_channels)].copy()
self.fragments_flat = fragments_flat

self.figure_path = figure_path
Expand Down Expand Up @@ -831,7 +833,9 @@ def extract_batch(self, batch_df):

return features_df, fragments_df

def extraction(self, keep_decoys=False):
def extraction(
self,
keep_decoys=False):

if self.run is not None:
for key, value in self.progress.items():
Expand Down Expand Up @@ -865,3 +869,60 @@ def extraction(self, keep_decoys=False):
logger.progress(f'=== extraction finished, 0.05 FDR: {precursors_05:,}, 0.01 FDR: {precursors_01:,}, 0.001 FDR: {precursors_001:,} ===')

return precursor_df

def requantify(
    self,
    psm_df
):
    """Re-score confident identifications across all multiplexed channels.

    Takes the FDR-filtered PSMs from the initial extraction, expands each
    identification to every configured channel (targets, reference and decoy
    channel) and rescores them with grouped candidate scoring.

    Parameters
    ----------
    psm_df : pandas.DataFrame
        Candidate feature dataframe of confident identifications; must contain
        a 'channel' column.

    Returns
    -------
    pandas.DataFrame
        Channel-level FDR corrected features from `channel_fdr_correction`.

    Raises
    ------
    ValueError
        If no 'multiplexing' section is present in the config.
    """

    # Fail fast before the calibration predictions below mutate the
    # precursor/fragment dataframes in place.
    if 'multiplexing' not in self.config:
        raise ValueError('no multiplexing config found')

    # Apply the fitted calibration to the *unfiltered* library so every
    # channel (not just the ones kept by the channel filter) is calibrated.
    self.calibration_manager.predict(self.precursors_flat_raw, 'precursor')
    self.calibration_manager.predict(self.fragments_flat, 'fragment')

    reference_candidates = plexscoring.candidate_features_to_candidates(psm_df)

    logger.progress(f'=== Multiplexing {len(reference_candidates):,} precursors ===')

    original_channels = psm_df['channel'].unique().tolist()
    logger.progress(f'original channels: {original_channels}')

    reference_channel = self.config['multiplexing']['reference_channel']
    logger.progress(f'reference channel: {reference_channel}')

    target_channels = [int(c) for c in self.config['multiplexing']['target_channels'].split(',')]
    logger.progress(f'target channels: {target_channels}')

    decoy_channel = self.config['multiplexing']['decoy_channel']
    logger.progress(f'decoy channel: {decoy_channel}')

    # Union of every channel that should be quantified; set() deduplicates.
    channels = list(set(original_channels + [reference_channel] + target_channels + [decoy_channel]))
    multiplexed_candidates = plexscoring.multiplex_candidates(reference_candidates, self.precursors_flat_raw, channels=channels)

    channel_count_lib = self.precursors_flat_raw['channel'].value_counts()
    channel_count_multiplexed = multiplexed_candidates['channel'].value_counts()
    ## log channels with less than 100 precursors
    for channel in channels:
        if channel not in channel_count_lib:
            logger.warning(f'channel {channel} not found in library')
        if channel not in channel_count_multiplexed:
            logger.warning(f'channel {channel} could not be mapped to existing IDs.')

    logger.progress(f'=== Requantifying {len(multiplexed_candidates):,} precursors ===')

    config = plexscoring.CandidateConfig()
    config.max_cardinality = 1
    config.score_grouped = True
    # NOTE(review): reference channel is hard-coded to 0 here instead of the
    # configured `reference_channel` read above — confirm this is intended.
    config.reference_channel = 0

    multiplexed_scoring = plexscoring.CandidateScoring(
        self.dia_data,
        self.precursors_flat_raw,
        self.fragments_flat,
        config=config
    )

    # Fragment-level results are discarded; only channel-FDR corrected
    # precursor features are returned.
    multiplexed_features, fragments = multiplexed_scoring(multiplexed_candidates)

    return channel_fdr_correction(multiplexed_features)
127 changes: 6 additions & 121 deletions alphadia/extraction/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import numpy as np
import pandas as pd
from scipy.stats import gaussian_kde
from matplotlib import patches

def density_scatter(
x: typing.Union[np.ndarray, pd.Series, pd.DataFrame],
Expand Down Expand Up @@ -116,7 +117,7 @@ def _generate_slice_collection(

return slice_collection

from matplotlib import patches


def _plot_slice_collection(slice_collection, ax, alpha=0.5, **kwargs):

Expand Down Expand Up @@ -153,129 +154,13 @@ def plot_dia_cycle(cycle,ax=None, cmap_name='YlOrRd', **kwargs):
ax.set_xlabel('Quadrupole m/z')
ax.set_ylabel('Scan')

def plot_all_precursors(
    dense_precursors,
    qtf,
    template,
    isotope_intensity
):
    """Plot one diagnostic figure per precursor showing dense ion images,
    quadrupole transfer profiles (qtf) and the template, per observation.

    Each figure is a grid: rows are observations, columns alternate between
    a dense precursor image and its qtf line plot for every isotope, with a
    final template column. Figures are shown immediately via plt.show().

    Parameters
    ----------
    dense_precursors :
        Dense ion data indexed as dense_precursors[0, precursor*n_isotopes + isotope, 0]
        (flat precursor-isotope axis — see dense_p_index below).
        # assumes a 4-D-indexable array — TODO confirm exact shape
    qtf :
        Quadrupole transfer values, shape (n_precursors, n_isotopes,
        n_observations, n_scans) as read from qtf.shape below.
    template :
        Template images indexed as template[precursor, observation];
        also passed to calculate_observation_importance.
    isotope_intensity :
        Relative isotope intensities indexed as [precursor, isotope],
        rendered as percentages and used to weight the qtf curves.
    """

    # grid dimensions are taken directly from the qtf tensor
    n_precursors = qtf.shape[0]
    n_isotopes = qtf.shape[1]
    n_observations = qtf.shape[2]
    n_scans = qtf.shape[3]

    # figure parameters
    # one (image, qtf) column pair per isotope plus a trailing template column
    n_cols = n_isotopes * 2 + 1
    n_rows = n_observations
    # image columns are wider (2) than qtf columns (0.8); template column is 2
    width_ratios = np.append(np.tile([2, 0.8], n_isotopes),[2])

    # y-axis (scan) values shared by all qtf line plots
    scan_range = np.arange(n_scans)
    observation_importance = calculate_observation_importance(
        template,
    )

    # iterate over precursors
    # each precursor will be a separate figure
    for i_precursor in range(n_precursors):

        # global color limits so all panels of a figure share one scale
        v_min_dense = np.min(dense_precursors)
        v_max_dense = np.max(dense_precursors)

        v_min_template = np.min(template)
        v_max_template = np.max(template)

        fig, axs = plt.subplots(
            n_rows,
            n_cols,
            figsize = (n_cols * 1, n_rows*2),
            gridspec_kw = {'width_ratios': width_ratios},
            sharey='row'
        )
        # expand axes if there is only one row
        if len(axs.shape) == 1:
            axs = axs.reshape(1, axs.shape[0])

        # iterate over observations, observations will be rows
        for i_observation in range(n_observations):

            # iterate over isotopes, isotopes will be columns
            for i_isotope in range(n_isotopes):

                # each even column will be a dense precursor
                i_dense = 2*i_isotope
                # each odd column will be a qtf
                i_qtf = 2*i_isotope+1

                # as precursors and isotopes are stored in a flat array, we need to calculate the index
                dense_p_index = i_precursor * n_isotopes + i_isotope

                # plot dense precursor
                axs[i_observation,i_dense].imshow(
                    dense_precursors[0,dense_p_index,0],
                    vmin=v_min_dense,
                    vmax=v_max_dense,
                )
                axs[0,i_dense].set_title(f'isotope {i_isotope}')
                # add text with relative isotope intensity
                axs[i_observation,i_dense].text(
                    0.05,
                    0.95,
                    f'{isotope_intensity[i_precursor,i_isotope]*100:.2f} %',
                    horizontalalignment='left',
                    verticalalignment='top',
                    transform=axs[i_observation,i_dense].transAxes,
                    color='white'
                )

                # plot qtf and weighted qtf
                # raw transfer curve plus the same curve scaled by isotope intensity
                axs[i_observation,i_qtf].plot(
                    qtf[i_precursor,i_isotope,i_observation],
                    scan_range
                )
                axs[i_observation,i_qtf].plot(
                    qtf[i_precursor,i_isotope,i_observation] * isotope_intensity[i_precursor,i_isotope],
                    scan_range
                )
                axs[i_observation,i_qtf].set_xlim(0, 1)
                axs[-1,i_dense].set_xlabel(f'frame')

            # remove xticks from all but last row
            if i_observation < n_observations - 1:
                for ax in axs[i_observation,:].flat:
                    ax.set_xticks([])

            # bold title
            axs[0,-1].set_title(f'template', fontweight='bold')

            # last column: template image for this observation
            axs[i_observation,-1].imshow(
                template[i_precursor,i_observation],
                vmin=v_min_template,
                vmax=v_max_template,
            )

            # annotate with the observation's relative importance (percent)
            axs[i_observation,-1].text(
                0.05,
                0.95,
                f'{observation_importance[i_precursor,i_observation]*100:.2f} %',
                horizontalalignment='left',
                verticalalignment='top',
                transform=axs[i_observation,-1].transAxes,
                color='white'
            )
            axs[i_observation,0].set_ylabel(f'observation {i_observation}\nscan')

        fig.tight_layout()
        plt.show()

def plot_image_collection(
images
images: typing.List[np.ndarray],
image_width: float = 4,
image_height: float = 6
):
n_images = len(images)
fig, ax = plt.subplots(1, n_images, figsize=(n_images*4, 6))
fig, ax = plt.subplots(1, n_images, figsize=(n_images*image_width, image_height))

if n_images == 1:
ax = [ax]
Expand Down
Loading

0 comments on commit 2a6837d

Please sign in to comment.