From 1eecbcf2e65daaf4d4de54e93fe79b22f913cc58 Mon Sep 17 00:00:00 2001 From: Nathaniel Sherry Date: Mon, 18 Mar 2024 20:57:23 -0400 Subject: [PATCH] Move some values for fitting solvers into FittingSolverContext to allow better reuse --- .../curve/fitting/solver/FittingSolver.java | 101 ++++++++++++- .../fitting/solver/GreedyFittingSolver.java | 19 +-- .../MultisamplingOptimizingFittingSolver.java | 39 ++--- .../solver/OptimizingFittingSolver.java | 138 ++++++------------ 4 files changed, 164 insertions(+), 133 deletions(-) diff --git a/LibPeakaboo/src/main/java/org/peakaboo/curvefit/curve/fitting/solver/FittingSolver.java b/LibPeakaboo/src/main/java/org/peakaboo/curvefit/curve/fitting/solver/FittingSolver.java index 8766e498d..a862acb1e 100644 --- a/LibPeakaboo/src/main/java/org/peakaboo/curvefit/curve/fitting/solver/FittingSolver.java +++ b/LibPeakaboo/src/main/java/org/peakaboo/curvefit/curve/fitting/solver/FittingSolver.java @@ -1,10 +1,18 @@ package org.peakaboo.curvefit.curve.fitting.solver; +import java.util.ArrayList; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Set; + import org.peakaboo.curvefit.curve.fitting.Curve; +import org.peakaboo.curvefit.curve.fitting.CurveView; import org.peakaboo.curvefit.curve.fitting.FittingResultSetView; import org.peakaboo.curvefit.curve.fitting.FittingSet; import org.peakaboo.curvefit.curve.fitting.FittingSetView; import org.peakaboo.curvefit.curve.fitting.fitter.CurveFitter; +import org.peakaboo.curvefit.peak.table.Element; +import org.peakaboo.curvefit.peak.transition.TransitionShell; import org.peakaboo.framework.bolt.plugin.java.BoltJavaPlugin; import org.peakaboo.framework.cyclops.spectrum.Spectrum; import org.peakaboo.framework.cyclops.spectrum.SpectrumView; @@ -16,8 +24,99 @@ */ public interface FittingSolver extends BoltJavaPlugin { - public static record FittingSolverContext (SpectrumView data, FittingSetView fittings, CurveFitter fitter) {}; + public static class FittingSolverContext { + + // Given Values + /** + * The spectrum for which we're performing fit solving + */ + public SpectrumView data; + + /** + * The fittings to be solved + */ + public FittingSetView fittings; + + /** + * The curve fitter which will perform single-curve fitting + */ + public CurveFitter fitter; + + + + // Derived Values + + /** + * Sorted list of channels + */ + public List curves; + + /** + * Subset of channels to focus on + */ + int[] channels; + + public FittingSolverContext(SpectrumView data, FittingSetView fittings, CurveFitter fitter) { + this.data = data; + this.fittings = fittings; + this.fitter = fitter; + + // Generate sorted list of curves from visible fittings + curves = new ArrayList<>(fittings.getVisibleCurves()); + sortCurves(curves); + + // Calculate a list of channels with enough curve signal to matter + channels = getIntenseChannels(curves); + } + + /** + * Performs a shallow copy of another {@link FittingSolverContext} + */ + public FittingSolverContext(FittingSolverContext copy) { + this.data = copy.data; + this.fittings = copy.fittings; + this.fitter = copy.fitter; + this.curves = copy.curves; + this.channels = copy.channels; + } + + } FittingResultSetView solve(FittingSolverContext ctx); + + /** + * Given a list of curves, sort them by by shell first, and then by element + */ + static void sortCurves(List curves) { + curves.sort((a, b) -> { + TransitionShell as, bs; + as = a.getTransitionSeries().getShell(); + bs = b.getTransitionSeries().getShell(); + Element ae, be; + ae = a.getTransitionSeries().getElement(); + be = b.getTransitionSeries().getElement(); + if (as.equals(bs)) { + return ae.compareTo(be); + } else { + return as.compareTo(bs); + } + }); + } + + static int[] getIntenseChannels(List curves) { + Set intenseChannels = new LinkedHashSet<>(); + for (CurveView curve : curves) { + intenseChannels.addAll(curve.getIntenseChannels()); + } + List asList = new ArrayList<>(intenseChannels); + asList.sort(Integer::compare); + int[] asArr = new int[asList.size()]; + for (int i = 0; i < asArr.length; i++) { + asArr[i] = asList.get(i); + } + return asArr; + } + + } diff --git a/LibPeakaboo/src/main/java/org/peakaboo/curvefit/curve/fitting/solver/GreedyFittingSolver.java b/LibPeakaboo/src/main/java/org/peakaboo/curvefit/curve/fitting/solver/GreedyFittingSolver.java index ec827afb7..ba2b08ba2 100644 --- a/LibPeakaboo/src/main/java/org/peakaboo/curvefit/curve/fitting/solver/GreedyFittingSolver.java +++ b/LibPeakaboo/src/main/java/org/peakaboo/curvefit/curve/fitting/solver/GreedyFittingSolver.java @@ -9,11 +9,8 @@ import org.peakaboo.curvefit.curve.fitting.FittingResultSet; import org.peakaboo.curvefit.curve.fitting.FittingResultSetView; import org.peakaboo.curvefit.curve.fitting.FittingResultView; -import org.peakaboo.curvefit.curve.fitting.FittingSetView; -import org.peakaboo.curvefit.curve.fitting.fitter.CurveFitter; import org.peakaboo.curvefit.curve.fitting.fitter.CurveFitter.CurveFitterContext; import org.peakaboo.framework.cyclops.spectrum.ArraySpectrum; -import org.peakaboo.framework.cyclops.spectrum.SpectrumView; import org.peakaboo.framework.cyclops.spectrum.Spectrum; import org.peakaboo.framework.cyclops.spectrum.SpectrumCalculations; @@ -49,23 +46,19 @@ public String pluginUUID() { */ @Override public FittingResultSetView solve(FittingSolverContext ctx) { - - SpectrumView data = ctx.data(); - FittingSetView fittings = ctx.fittings(); - CurveFitter fitter = ctx.fitter(); - Spectrum resultTotalFit = new ArraySpectrum(data.size()); + Spectrum resultTotalFit = new ArraySpectrum(ctx.data.size()); List resultFits = new ArrayList<>(); - FittingParametersView resultParameters = fittings.getFittingParameters().copy(); + FittingParametersView resultParameters = ctx.fittings.getFittingParameters().copy(); - Spectrum remainder = new ArraySpectrum(data); - Spectrum scaled = new ArraySpectrum(data.size()); + Spectrum remainder = new ArraySpectrum(ctx.data); + Spectrum scaled = new ArraySpectrum(ctx.data.size()); // calculate the curves - for (CurveView curve : fittings.getCurves()) { + for (CurveView curve : ctx.fittings.getCurves()) { if (!curve.getTransitionSeries().isVisible()) { continue; } - FittingResult result = fitter.fit(new CurveFitterContext(remainder, curve)); + FittingResult result = ctx.fitter.fit(new CurveFitterContext(remainder, curve)); curve.scaleInto(result.getCurveScale(), scaled); SpectrumCalculations.subtractLists_inplace(remainder, scaled, 0.0f); diff --git a/LibPeakaboo/src/main/java/org/peakaboo/curvefit/curve/fitting/solver/MultisamplingOptimizingFittingSolver.java b/LibPeakaboo/src/main/java/org/peakaboo/curvefit/curve/fitting/solver/MultisamplingOptimizingFittingSolver.java index c3c4424f9..01840087d 100644 --- a/LibPeakaboo/src/main/java/org/peakaboo/curvefit/curve/fitting/solver/MultisamplingOptimizingFittingSolver.java +++ b/LibPeakaboo/src/main/java/org/peakaboo/curvefit/curve/fitting/solver/MultisamplingOptimizingFittingSolver.java @@ -2,16 +2,12 @@ import java.util.ArrayList; import java.util.Collections; -import java.util.List; import java.util.Random; import org.apache.commons.math3.analysis.MultivariateFunction; import org.apache.commons.math3.optim.PointValuePair; import org.peakaboo.curvefit.curve.fitting.CurveView; import org.peakaboo.curvefit.curve.fitting.FittingResultSetView; -import org.peakaboo.curvefit.curve.fitting.FittingSetView; -import org.peakaboo.curvefit.curve.fitting.fitter.CurveFitter; -import org.peakaboo.framework.cyclops.spectrum.SpectrumView; public class MultisamplingOptimizingFittingSolver extends OptimizingFittingSolver { @@ -41,31 +37,27 @@ public String pluginUUID() { } @Override - public FittingResultSetView solve(FittingSolverContext ctx) { + public FittingResultSetView solve(FittingSolverContext inputCtx) { - SpectrumView data = ctx.data(); - FittingSetView fittings = ctx.fittings(); - CurveFitter fitter = ctx.fitter(); - - int size = fittings.getVisibleCurves().size(); + int size = inputCtx.fittings.getVisibleCurves().size(); if (size == 0) { - return getEmptyResult(data, fittings); + return getEmptyResult(inputCtx); } - List curves = fittings.getVisibleCurves(); - sortCurves(curves); - int[] intenseChannels = getIntenseChannels(curves); + // Create a shallow copy of the input context and then make a deep copy of the + // curve list so that we can permute it without impacting the original + FittingSolverContext permCtx = new FittingSolverContext(inputCtx); + permCtx.curves = new ArrayList<>(permCtx.curves); - List perm = new ArrayList<>(curves); int counter = 0; double[] scalings = new double[size]; + EvaluationSpace eval = new EvaluationSpace(permCtx.data.size()); while (counter <= 10) { - Collections.shuffle(perm, new Random(12345654321l)); + Collections.shuffle(permCtx.curves, new Random(12345654321l)); - double[] guess = getInitialGuess(perm, fitter, data); - EvaluationContext context = new EvaluationContext(data, fittings, perm); - MultivariateFunction cost = getCostFunction(context, intenseChannels); + double[] guess = getInitialGuess(permCtx); + MultivariateFunction cost = getCostFunction(permCtx, eval); PointValuePair result = optimizeCostFunction(cost, guess, 0.02d); double[] permScalings = result.getPoint(); @@ -75,8 +67,9 @@ public FittingResultSetView solve(FittingSolverContext ctx) { //guess = permScalings; for (int i = 0; i < scalings.length; i++) { - CurveView c = perm.get(i); - int j = curves.indexOf(c); + // Map the scores back to the original permutation of the curves list + CurveView c = permCtx.curves.get(i); + int j = inputCtx.curves.indexOf(c); scalings[j] += permScalings[i]; } counter++; @@ -86,10 +79,8 @@ public FittingResultSetView solve(FittingSolverContext ctx) { for (int i = 0; i < scalings.length; i++) { scalings[i] /= counter; } - - EvaluationContext context = new EvaluationContext(data, fittings, curves); - return evaluate(scalings, context); + return evaluate(scalings, inputCtx); } diff --git a/LibPeakaboo/src/main/java/org/peakaboo/curvefit/curve/fitting/solver/OptimizingFittingSolver.java b/LibPeakaboo/src/main/java/org/peakaboo/curvefit/curve/fitting/solver/OptimizingFittingSolver.java index 678ab926a..839e4468a 100644 --- a/LibPeakaboo/src/main/java/org/peakaboo/curvefit/curve/fitting/solver/OptimizingFittingSolver.java +++ b/LibPeakaboo/src/main/java/org/peakaboo/curvefit/curve/fitting/solver/OptimizingFittingSolver.java @@ -2,9 +2,7 @@ import java.util.ArrayList; import java.util.Collections; -import java.util.LinkedHashSet; import java.util.List; -import java.util.Set; import org.apache.commons.math3.analysis.MultivariateFunction; import org.apache.commons.math3.optim.InitialGuess; @@ -20,15 +18,10 @@ import org.peakaboo.curvefit.curve.fitting.FittingResultSet; import org.peakaboo.curvefit.curve.fitting.FittingResultSetView; import org.peakaboo.curvefit.curve.fitting.FittingResultView; -import org.peakaboo.curvefit.curve.fitting.FittingSetView; -import org.peakaboo.curvefit.curve.fitting.fitter.CurveFitter; import org.peakaboo.curvefit.curve.fitting.fitter.CurveFitter.CurveFitterContext; -import org.peakaboo.curvefit.peak.table.Element; -import org.peakaboo.curvefit.peak.transition.TransitionShell; import org.peakaboo.framework.cyclops.spectrum.ArraySpectrum; import org.peakaboo.framework.cyclops.spectrum.Spectrum; import org.peakaboo.framework.cyclops.spectrum.SpectrumCalculations; -import org.peakaboo.framework.cyclops.spectrum.SpectrumView; public class OptimizingFittingSolver implements FittingSolver { @@ -62,38 +55,28 @@ public String pluginUUID() { @Override public FittingResultSetView solve(FittingSolverContext ctx) { - SpectrumView data = ctx.data(); - FittingSetView fittings = ctx.fittings(); - CurveFitter fitter = ctx.fitter(); - - int size = fittings.getVisibleCurves().size(); + int size = ctx.fittings.getVisibleCurves().size(); if (size == 0) { - return getEmptyResult(data, fittings); + return getEmptyResult(ctx); } - - List curves = new ArrayList<>(fittings.getVisibleCurves()); - sortCurves(curves); - int[] intenseChannels = getIntenseChannels(curves); - EvaluationContext context = new EvaluationContext(data, fittings, curves); - MultivariateFunction cost = getCostFunction(context, intenseChannels); - double[] guess = getInitialGuess(curves, fitter, data); - + + EvaluationSpace eval = new EvaluationSpace(ctx.data.size()); + MultivariateFunction cost = getCostFunction(ctx, eval); + double[] guess = getInitialGuess(ctx); PointValuePair result = optimizeCostFunction(cost, guess, costFnPrecision); - double[] scalings = result.getPoint(); - return evaluate(scalings, context); - + return evaluate(scalings, ctx); } - protected FittingResultSet getEmptyResult(SpectrumView data, FittingSetView fittings) { + protected FittingResultSet getEmptyResult(FittingSolverContext ctx) { return new FittingResultSet( - new ArraySpectrum(data.size()), - new ArraySpectrum(data), + new ArraySpectrum(ctx.data.size()), + new ArraySpectrum(ctx.data), Collections.emptyList(), - fittings.getFittingParameters().copy() + ctx.fittings.getFittingParameters().copy() ); } @@ -128,12 +111,12 @@ protected PointValuePair optimizeCostFunction(MultivariateFunction cost, double[ } - protected double[] getInitialGuess(List curves, CurveFitter fitter, SpectrumView data) { - int curveCount = curves.size(); + public static double[] getInitialGuess(FittingSolverContext ctx) { + int curveCount = ctx.curves.size(); double[] guess = new double[curveCount]; for (int i = 0; i < curveCount; i++) { - CurveView curve = curves.get(i); - FittingResultView guessFittingResult = fitter.fit(new CurveFitterContext(data, curve)); + CurveView curve = ctx.curves.get(i); + FittingResultView guessFittingResult = ctx.fitter.fit(new CurveFitterContext(ctx.data, curve)); //there will usually be some overlap between elements, so //we use 80% of the independently fitted guess. @@ -147,22 +130,10 @@ protected double[] getInitialGuess(List curves, CurveFitter fitter, S return guess; } - protected int[] getIntenseChannels(List curves) { - Set intenseChannels = new LinkedHashSet<>(); - for (CurveView curve : curves) { - intenseChannels.addAll(curve.getIntenseChannels()); - } - List asList = new ArrayList<>(intenseChannels); - asList.sort(Integer::compare); - int[] asArr = new int[asList.size()]; - for (int i = 0; i < asArr.length; i++) { - asArr[i] = asList.get(i); - } - return asArr; - } + - protected MultivariateFunction getCostFunction(EvaluationContext context, int[] intenseChannels) { + protected MultivariateFunction getCostFunction(FittingSolverContext ctx, EvaluationSpace eval) { return new MultivariateFunction() { @Override @@ -176,8 +147,8 @@ public double value(double[] point) { } } - test(point, intenseChannels, context); - float score = score(point, intenseChannels, context.residual); + test(point, ctx, eval); + float score = score(point, ctx, eval.residual); if (containsNegatives > 0) { return score * (1f+containsNegatives); } @@ -188,46 +159,29 @@ public double value(double[] point) { } - /** - * Given a list of curves, sort them by by shell first, and then by element - */ - protected void sortCurves(List curves) { - curves.sort((a, b) -> { - TransitionShell as, bs; - as = a.getTransitionSeries().getShell(); - bs = b.getTransitionSeries().getShell(); - Element ae, be; - ae = a.getTransitionSeries().getElement(); - be = b.getTransitionSeries().getElement(); - if (as.equals(bs)) { - return ae.compareTo(be); - } else { - return as.compareTo(bs); - } - }); - } + /** * Calculate the residual from data (signal) and total (fittings). Store the result in residual */ - private void test(double[] weights, int[] channels, EvaluationContext context) { - Spectrum total = context.total; + private void test(double[] weights, FittingSolverContext ctx, EvaluationSpace eval) { + Spectrum total = eval.total; total.zero(); //When there are no intense channels to consider, the residual will be equal to the data - if (channels.length == 0) { - SpectrumCalculations.subtractFromList_target(context.data, context.residual, 0f); + if (ctx.channels.length == 0) { + SpectrumCalculations.subtractFromList_target(ctx.data, eval.residual, 0f); return; } - List curves = context.curves; + List curves = ctx.curves; int curvesLength = weights.length; - int first = channels[0]; - int last = channels[channels.length-1]; + int first = ctx.channels[0]; + int last = ctx.channels[ctx.channels.length-1]; for (int i = 0; i < curvesLength; i++) { - curves.get(i).scaleOnto((float) weights[i], total, first, last); + curves.get(i).scaleOnto((float) weights[i], total, first, last); } - SpectrumCalculations.subtractLists_target(context.data, context.total, context.residual, first, last); + SpectrumCalculations.subtractLists_target(ctx.data, eval.total, eval.residual, first, last); } @@ -236,12 +190,12 @@ private void test(double[] weights, int[] channels, EvaluationContext context) { * Score the context's residual spectrum */ // NB: Bytecode-optimized function. Take care making changes - private float score(double[] point, int[] channels, Spectrum residual) { + private float score(double[] point, FittingSolverContext ctx, Spectrum residual) { float[] ra = residual.backingArray(); float score = 0; - int length = channels.length; + int length = ctx.channels.length; for (int i = 0; i < length; i++) { - float value = ra[channels[i]]; + float value = ra[ctx.channels[i]]; //Negative values mean that we've fit more signal than exists //We penalize this to prevent making up data where none exists. @@ -260,36 +214,30 @@ private float score(double[] point, int[] channels, Spectrum residual) { * and a context. Scales the context.curves by the weights. Returns a new * FittingResultSet containing the fitted curves and other totals. */ - protected FittingResultSetView evaluate(double[] point, EvaluationContext context) { + public static FittingResultSetView evaluate(double[] point, FittingSolverContext ctx) { int index = 0; List fits = new ArrayList<>(); - Spectrum total = new ArraySpectrum(context.data.size()); - Spectrum scaled = new ArraySpectrum(context.data.size()); - for (CurveView curve : context.curves) { + Spectrum total = new ArraySpectrum(ctx.data.size()); + Spectrum scaled = new ArraySpectrum(ctx.data.size()); + for (CurveView curve : ctx.curves) { float scale = (float) point[index++]; curve.scaleInto(scale, scaled); fits.add(new FittingResult(curve, scale)); SpectrumCalculations.addLists_inplace(total, scaled); } - Spectrum residual = SpectrumCalculations.subtractLists(context.data, total); + Spectrum residual = SpectrumCalculations.subtractLists(ctx.data, total); - return new FittingResultSet(total, residual, fits, context.fittings.getFittingParameters().copy()); + return new FittingResultSet(total, residual, fits, ctx.fittings.getFittingParameters().copy()); } - public static class EvaluationContext { - public SpectrumView data; - public FittingSetView fittings; - public List curves; + public static class EvaluationSpace { public Spectrum scratch; public Spectrum total; public Spectrum residual; - public EvaluationContext(SpectrumView data, FittingSetView fittings, List curves) { - this.data = data; - this.fittings = fittings; - this.curves = curves; - this.scratch = new ArraySpectrum(data.size()); - this.total = new ArraySpectrum(data.size()); - this.residual = new ArraySpectrum(data.size()); + public EvaluationSpace(int size) { + this.scratch = new ArraySpectrum(size); + this.total = new ArraySpectrum(size); + this.residual = new ArraySpectrum(size); } }