From 6f84d1fb3a4eb3c4ccb0b8ea0c62655b6eee3525 Mon Sep 17 00:00:00 2001 From: Oscar Branson Date: Mon, 16 Sep 2024 17:44:20 +0100 Subject: [PATCH] clay removal filter --- latools/D_obj.py | 35 ++++++++++++++++++++++ latools/filtering/clay_removal.py | 50 +++++++++++++++++++++++++++++++ latools/latools.py | 32 ++++++++++++++++++-- 3 files changed, 114 insertions(+), 3 deletions(-) create mode 100644 latools/filtering/clay_removal.py diff --git a/latools/D_obj.py b/latools/D_obj.py index 7879f82..8767128 100644 --- a/latools/D_obj.py +++ b/latools/D_obj.py @@ -22,6 +22,7 @@ from .filtering import filters from .filtering import clustering +from .filtering.clay_removal import clay_removal from .filtering.filt_obj import filt from .filtering.signal_optimiser import signal_optimiser, optimisation_plot @@ -1194,6 +1195,40 @@ def correlation_plot(self, x_analyte, y_analyte, window=15, filt=True, recalc=Fa return fig, axs + @_log + def filter_clay_removal(self, clay_tracers=['27Al', '55Mn'], filt=True): + """ + Apply a filter to remove clay-rich data. + + Parameters + ---------- + clay_tracers : array-like + The analytes used to identify clay-rich data. Default is + ['27Al', '55Mn']. + + Returns + ------- + None + """ + if isinstance(clay_tracers, str): + clay_tracers = [clay_tracers] + if len(clay_tracers) < 2: + raise ValueError('Must provide at least two clay tracers.') + clay_tracers = list(self._analyte_checker(clay_tracers)) + + # get existing filter + ind = self.filt.grab_filt(filt=filt, analyte=clay_tracers[0]) + + # isolate data + data_dict = {k: un.nominal_values(self.focus[k][ind]) for k in clay_tracers} + + clay_filt = np.zeros_like(ind, dtype=bool) + + if np.any(ind): + clay_filt[ind] = clay_removal(data_dict) + + self.filt.add('clay-' + '-'.join([t.split('_')[0] for t in clay_tracers]), clay_filt, 'Clay-rich data filter.') + @_log def filter_new(self, name, filt_str): """ diff --git a/latools/filtering/clay_removal.py b/latools/filtering/clay_removal.py new file mode 100644 index 0000000..5880b68 --- /dev/null +++ b/latools/filtering/clay_removal.py @@ -0,0 +1,50 @@ +import numpy as np +import pandas as pd +from scipy.optimize import curve_fit +from scipy.stats import zscore + +def piecewise_linear(x, x0, y0, k1, k2): + return np.piecewise(x, [x < x0], [lambda x:k1*x + y0-k1*x0, lambda x:k2*x + y0-k2*x0]) + +def linear(x, m, c): + return m * x + c + +def clay_removal(data_dict): + + dat = pd.DataFrame.from_dict(data_dict) + clay_tracers = list(dat.columns) + + sub = dat.copy().dropna() + + # calculate clay score + sub['clay'] = zscore(sub).mean(axis=1) + + # sort by clay score + ssub = sub.dropna().sort_values('clay', ascending=True) + + # calculate cumulative mean + msub = ssub.cumsum() / np.arange(1, len(ssub)+1).reshape(-1,1) + + # fit piecewise linear + changepoints = [] + for c in clay_tracers: + mp, mcov = curve_fit(piecewise_linear, msub['clay'].values, msub[c].values, p0=[msub['clay'].mean(), 0, 0, 0]) + + # check that slopes are sufficiently different + if 0.8 < mp[-2] / mp[-1] < 1.2: + continue + + changepoints.append(mp[0]) + + if len(changepoints) == 0: + return np.ones_like(dat.index, dtype=bool) + + # calculate average changepoint + changepoint = np.mean(changepoints) + + ind = msub['clay'] < changepoint + + dat['filt'] = False + dat.loc[ind.index[ind], 'filt'] = True + + return dat['filt'].values \ No newline at end of file diff --git a/latools/latools.py b/latools/latools.py index 65855ba..059a2ef 100644 --- a/latools/latools.py +++ b/latools/latools.py @@ -762,7 +762,7 @@ def expfit(x, e): ax.plot(fitx, expfit(fitx, ep - nsd_below * np.diag(ecov)**.5, ), color='b', label='Used') ax.text(0.95, 0.75, - ('y = $e^{%.2f \pm %.2f * x}$\n$R^2$= %.2f \nCoefficient: ' + ('y = $e^{%.2f \\pm %.2f * x}$\n$R^2$= %.2f \nCoefficient: ' '%.2f') % (ep, np.diag(ecov)**.5, eeR2, @@ -2575,6 +2575,32 @@ def correlation_plots(self, x_analyte, y_analyte, window=15, filt=True, recalc=F prog.update() return + @_log + def filter_clay_removal(self, clay_tracers=['27Al', '55Mn'], filt=True, samples=None, subset=None): + """ + Apply a filter to remove clay-rich samples. + + Parameters + ---------- + clay_tracers : list + A list of analytes that are indicative of clay content. + filt : bool + Whether or not to apply existing filters to the data before + calculating this filter. + + Returns + ------- + None + """ + if samples is not None: + subset = self.make_subset(samples, silent=True) + samples = self._get_samples(subset) + + with self.pbar.set(total=len(samples), desc='Removing Clays') as prog: + for s in self.data.values(): + s.filter_clay_removal(clay_tracers, filt=filt) + prog.update() + @_log def filter_on(self, filt=None, analyte=None, samples=None, subset=None, show_status=False): """ @@ -4284,7 +4310,7 @@ def minimal_export(self, target_analytes=None, path=None): # format sample_stats correctly lss = [(i, l) for i, l in enumerate(self.log) if 'sample_stats' in l] - rep = re.compile("(.*'stats': )(\[.*?\])(.*)") + rep = re.compile(r"(.*'stats': )(\[.*?\])(.*)") for i, l in lss: self.log[i] = rep.sub(r'\1' + str(self.stats_calced) + r'\3', l) @@ -4354,7 +4380,7 @@ def reproduce(past_analysis, plotting=False, data_path=None, with open(paths['custom_stat_functions'], 'r') as f: csf = f.read() - fname = re.compile('def (.*)\(.*') + fname = re.compile(r'def (.*)\(.*') for c in csf.split('\n\n\n\n'): if fname.match(c):