Skip to content

Commit

Permalink
clay removal filter
Browse files Browse the repository at this point in the history
  • Loading branch information
oscarbranson committed Sep 16, 2024
1 parent 1146d16 commit 6f84d1f
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 3 deletions.
35 changes: 35 additions & 0 deletions latools/D_obj.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from .filtering import filters
from .filtering import clustering
from .filtering.clay_removal import clay_removal
from .filtering.filt_obj import filt
from .filtering.signal_optimiser import signal_optimiser, optimisation_plot

Expand Down Expand Up @@ -1194,6 +1195,40 @@ def correlation_plot(self, x_analyte, y_analyte, window=15, filt=True, recalc=Fa

return fig, axs

@_log
def filter_clay_removal(self, clay_tracers=('27Al', '55Mn'), filt=True):
    """
    Apply a filter to remove clay-rich data.

    Identifies clay-contaminated intervals from the given tracer
    analytes and records the result as a new filter on this sample.

    Parameters
    ----------
    clay_tracers : str or array-like
        The analytes used to identify clay-rich data. Must resolve to
        at least two analytes. Default is ('27Al', '55Mn').
    filt : bool or str
        Which existing filter(s) to apply to the data before
        calculating this filter. Default is True (all active filters).

    Returns
    -------
    None
    """
    # tuple default (not a list) avoids the shared-mutable-default pitfall
    if isinstance(clay_tracers, str):
        clay_tracers = [clay_tracers]
    if len(clay_tracers) < 2:
        raise ValueError('Must provide at least two clay tracers.')
    clay_tracers = list(self._analyte_checker(clay_tracers))

    # restrict the calculation to data passing the existing filters
    ind = self.filt.grab_filt(filt=filt, analyte=clay_tracers[0])

    # isolate nominal (uncertainty-stripped) tracer data
    data_dict = {k: un.nominal_values(self.focus[k][ind]) for k in clay_tracers}

    clay_filt = np.zeros_like(ind, dtype=bool)

    # only run the changepoint analysis if any data survive the existing
    # filters; otherwise the new filter is all-False (guards against
    # empty-input failures in the curve fitting)
    if np.any(ind):
        clay_filt[ind] = clay_removal(data_dict)

    self.filt.add('clay-' + '-'.join([t.split('_')[0] for t in clay_tracers]),
                  clay_filt, 'Clay-rich data filter.')

@_log
def filter_new(self, name, filt_str):
"""
Expand Down
50 changes: 50 additions & 0 deletions latools/filtering/clay_removal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import numpy as np
import pandas as pd
from scipy.optimize import curve_fit
from scipy.stats import zscore

def piecewise_linear(x, x0, y0, k1, k2):
    """Two-segment linear model hinged at (x0, y0).

    The slope is k1 for x < x0 and k2 for x >= x0; both segments pass
    through the common point (x0, y0), so the curve is continuous.
    """
    below = x < x0
    left = lambda x: k1 * x + y0 - k1 * x0
    right = lambda x: k2 * x + y0 - k2 * x0
    return np.piecewise(x, [below], [left, right])

def linear(x, m, c):
    """Straight line with slope m and intercept c."""
    return c + m * x

def clay_removal(data_dict):
    """Identify clay-rich data points from a set of clay tracer analytes.

    A composite 'clay score' (the mean z-score across all tracers) is
    computed for each data point. Points are ranked by this score, and a
    piecewise-linear changepoint is fitted to the running mean of each
    tracer against the running mean of the score. Where the two fitted
    slopes differ by more than 20%, the fitted breakpoint marks the onset
    of clay contamination.

    Parameters
    ----------
    data_dict : dict
        Mapping of tracer analyte name to 1-D array of nominal values.

    Returns
    -------
    numpy.ndarray of bool
        True where a point is kept (below the clay threshold), False
        where it is identified as clay-rich. All-True if no changepoint
        is detected in any tracer.
    """
    df = pd.DataFrame.from_dict(data_dict)
    tracers = list(df.columns)

    valid = df.copy().dropna()

    # composite clay score: mean z-score across all tracers
    valid['clay'] = zscore(valid).mean(axis=1)

    # rank data from least to most clay-like
    ordered = valid.dropna().sort_values('clay', ascending=True)

    # running (cumulative) mean of every column, in clay-score order
    counts = np.arange(1, len(ordered) + 1).reshape(-1, 1)
    running = ordered.cumsum() / counts

    # fit a piecewise-linear changepoint for each tracer
    breakpoints = []
    for tracer in tracers:
        popt, _ = curve_fit(piecewise_linear,
                            running['clay'].values, running[tracer].values,
                            p0=[running['clay'].mean(), 0, 0, 0])

        # slopes within 20% of each other: no meaningful change in this
        # tracer, so it contributes no changepoint
        slope_ratio = popt[-2] / popt[-1]
        if not (0.8 < slope_ratio < 1.2):
            breakpoints.append(popt[0])

    # no tracer shows a changepoint: keep everything
    if not breakpoints:
        return np.ones_like(df.index, dtype=bool)

    threshold = np.mean(breakpoints)

    # keep points whose running clay score is below the mean changepoint
    keep = running['clay'] < threshold

    df['filt'] = False
    df.loc[keep.index[keep], 'filt'] = True

    return df['filt'].values
32 changes: 29 additions & 3 deletions latools/latools.py
Original file line number Diff line number Diff line change
Expand Up @@ -762,7 +762,7 @@ def expfit(x, e):
ax.plot(fitx, expfit(fitx, ep - nsd_below * np.diag(ecov)**.5, ),
color='b', label='Used')
ax.text(0.95, 0.75,
('y = $e^{%.2f \pm %.2f * x}$\n$R^2$= %.2f \nCoefficient: '
('y = $e^{%.2f \\pm %.2f * x}$\n$R^2$= %.2f \nCoefficient: '
'%.2f') % (ep,
np.diag(ecov)**.5,
eeR2,
Expand Down Expand Up @@ -2575,6 +2575,32 @@ def correlation_plots(self, x_analyte, y_analyte, window=15, filt=True, recalc=F
prog.update()
return

@_log
def filter_clay_removal(self, clay_tracers=('27Al', '55Mn'), filt=True,
                        samples=None, subset=None):
    """
    Apply a filter to remove clay-rich data.

    Parameters
    ----------
    clay_tracers : str or array-like
        The analytes used to identify clay-rich data. Must resolve to
        at least two analytes. Default is ('27Al', '55Mn').
    filt : bool or str
        Whether or not to apply existing filters to the data before
        calculating this filter.
    samples : str or array-like, optional
        Which sample(s) to apply the filter to. If None, the filter is
        applied to all samples (or to `subset`, if given).
    subset : str or int, optional
        The subset of samples to apply the filter to.

    Returns
    -------
    None
    """
    if samples is not None:
        subset = self.make_subset(samples, silent=True)
    samples = self._get_samples(subset)

    with self.pbar.set(total=len(samples), desc='Removing Clays') as prog:
        # iterate only the selected sample names — previously this
        # looped over all of self.data.values(), silently ignoring the
        # samples/subset arguments (and mismatching the progress total)
        for s in samples:
            self.data[s].filter_clay_removal(clay_tracers, filt=filt)
            prog.update()

@_log
def filter_on(self, filt=None, analyte=None, samples=None, subset=None, show_status=False):
"""
Expand Down Expand Up @@ -4284,7 +4310,7 @@ def minimal_export(self, target_analytes=None, path=None):

# format sample_stats correctly
lss = [(i, l) for i, l in enumerate(self.log) if 'sample_stats' in l]
rep = re.compile("(.*'stats': )(\[.*?\])(.*)")
rep = re.compile(r"(.*'stats': )(\[.*?\])(.*)")
for i, l in lss:
self.log[i] = rep.sub(r'\1' + str(self.stats_calced) + r'\3', l)

Expand Down Expand Up @@ -4354,7 +4380,7 @@ def reproduce(past_analysis, plotting=False, data_path=None,
with open(paths['custom_stat_functions'], 'r') as f:
csf = f.read()

fname = re.compile('def (.*)\(.*')
fname = re.compile(r'def (.*)\(.*')

for c in csf.split('\n\n\n\n'):
if fname.match(c):
Expand Down

0 comments on commit 6f84d1f

Please sign in to comment.