Skip to content

Commit

Permalink
Move some values for fitting solvers into FittingSolverContext to all…
Browse files Browse the repository at this point in the history
…ow better reuse
  • Loading branch information
nathanielsherry committed Mar 19, 2024
1 parent ec9373d commit 1eecbcf
Show file tree
Hide file tree
Showing 4 changed files with 164 additions and 133 deletions.
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
package org.peakaboo.curvefit.curve.fitting.solver;

import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

import org.peakaboo.curvefit.curve.fitting.Curve;
import org.peakaboo.curvefit.curve.fitting.CurveView;
import org.peakaboo.curvefit.curve.fitting.FittingResultSetView;
import org.peakaboo.curvefit.curve.fitting.FittingSet;
import org.peakaboo.curvefit.curve.fitting.FittingSetView;
import org.peakaboo.curvefit.curve.fitting.fitter.CurveFitter;
import org.peakaboo.curvefit.peak.table.Element;
import org.peakaboo.curvefit.peak.transition.TransitionShell;
import org.peakaboo.framework.bolt.plugin.java.BoltJavaPlugin;
import org.peakaboo.framework.cyclops.spectrum.Spectrum;
import org.peakaboo.framework.cyclops.spectrum.SpectrumView;
Expand All @@ -16,8 +24,99 @@
*/
public interface FittingSolver extends BoltJavaPlugin {

public static record FittingSolverContext (SpectrumView data, FittingSetView fittings, CurveFitter fitter) {};
public static class FittingSolverContext {

// Given Values
/**
* The spectrum for which we're performing fit solving
*/
public SpectrumView data;

/**
* The fittings to be solved
*/
public FittingSetView fittings;

/**
* The curve fitter which will perform single-curve fitting
*/
public CurveFitter fitter;



// Derived Values

/**
* Sorted list of channels
*/
public List<CurveView> curves;

/**
* Subset of channels to focus on
*/
int[] channels;

public FittingSolverContext(SpectrumView data, FittingSetView fittings, CurveFitter fitter) {
this.data = data;
this.fittings = fittings;
this.fitter = fitter;

// Generate sorted list of curves from visible fittings
curves = new ArrayList<>(fittings.getVisibleCurves());
sortCurves(curves);

// Calculate a list of channels with enough curve signal to matter
channels = getIntenseChannels(curves);
}

/**
* Performs a shallow copy of another {@link FittingSolverContext}
*/
public FittingSolverContext(FittingSolverContext copy) {
this.data = copy.data;
this.fittings = copy.fittings;
this.fitter = copy.fitter;
this.curves = copy.curves;
this.channels = copy.channels;
}

}

FittingResultSetView solve(FittingSolverContext ctx);

/**
* Given a list of curves, sort them by by shell first, and then by element
*/
static void sortCurves(List<CurveView> curves) {
curves.sort((a, b) -> {
TransitionShell as, bs;
as = a.getTransitionSeries().getShell();
bs = b.getTransitionSeries().getShell();
Element ae, be;
ae = a.getTransitionSeries().getElement();
be = b.getTransitionSeries().getElement();
if (as.equals(bs)) {
return ae.compareTo(be);
} else {
return as.compareTo(bs);
}
});
}

static int[] getIntenseChannels(List<CurveView> curves) {
Set<Integer> intenseChannels = new LinkedHashSet<>();
for (CurveView curve : curves) {
intenseChannels.addAll(curve.getIntenseChannels());
}
List<Integer> asList = new ArrayList<>(intenseChannels);
asList.sort(Integer::compare);
int[] asArr = new int[asList.size()];
for (int i = 0; i < asArr.length; i++) {
asArr[i] = asList.get(i);
}
return asArr;
}



}
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,8 @@
import org.peakaboo.curvefit.curve.fitting.FittingResultSet;
import org.peakaboo.curvefit.curve.fitting.FittingResultSetView;
import org.peakaboo.curvefit.curve.fitting.FittingResultView;
import org.peakaboo.curvefit.curve.fitting.FittingSetView;
import org.peakaboo.curvefit.curve.fitting.fitter.CurveFitter;
import org.peakaboo.curvefit.curve.fitting.fitter.CurveFitter.CurveFitterContext;
import org.peakaboo.framework.cyclops.spectrum.ArraySpectrum;
import org.peakaboo.framework.cyclops.spectrum.SpectrumView;
import org.peakaboo.framework.cyclops.spectrum.Spectrum;
import org.peakaboo.framework.cyclops.spectrum.SpectrumCalculations;

Expand Down Expand Up @@ -49,23 +46,19 @@ public String pluginUUID() {
*/
@Override
public FittingResultSetView solve(FittingSolverContext ctx) {

SpectrumView data = ctx.data();
FittingSetView fittings = ctx.fittings();
CurveFitter fitter = ctx.fitter();

Spectrum resultTotalFit = new ArraySpectrum(data.size());
Spectrum resultTotalFit = new ArraySpectrum(ctx.data.size());
List<FittingResultView> resultFits = new ArrayList<>();
FittingParametersView resultParameters = fittings.getFittingParameters().copy();
FittingParametersView resultParameters = ctx.fittings.getFittingParameters().copy();

Spectrum remainder = new ArraySpectrum(data);
Spectrum scaled = new ArraySpectrum(data.size());
Spectrum remainder = new ArraySpectrum(ctx.data);
Spectrum scaled = new ArraySpectrum(ctx.data.size());

// calculate the curves
for (CurveView curve : fittings.getCurves()) {
for (CurveView curve : ctx.fittings.getCurves()) {
if (!curve.getTransitionSeries().isVisible()) { continue; }

FittingResult result = fitter.fit(new CurveFitterContext(remainder, curve));
FittingResult result = ctx.fitter.fit(new CurveFitterContext(remainder, curve));
curve.scaleInto(result.getCurveScale(), scaled);
SpectrumCalculations.subtractLists_inplace(remainder, scaled, 0.0f);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,12 @@

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

import org.apache.commons.math3.analysis.MultivariateFunction;
import org.apache.commons.math3.optim.PointValuePair;
import org.peakaboo.curvefit.curve.fitting.CurveView;
import org.peakaboo.curvefit.curve.fitting.FittingResultSetView;
import org.peakaboo.curvefit.curve.fitting.FittingSetView;
import org.peakaboo.curvefit.curve.fitting.fitter.CurveFitter;
import org.peakaboo.framework.cyclops.spectrum.SpectrumView;

public class MultisamplingOptimizingFittingSolver extends OptimizingFittingSolver {

Expand Down Expand Up @@ -41,31 +37,27 @@ public String pluginUUID() {
}

@Override
public FittingResultSetView solve(FittingSolverContext ctx) {
public FittingResultSetView solve(FittingSolverContext inputCtx) {

SpectrumView data = ctx.data();
FittingSetView fittings = ctx.fittings();
CurveFitter fitter = ctx.fitter();

int size = fittings.getVisibleCurves().size();
int size = inputCtx.fittings.getVisibleCurves().size();
if (size == 0) {
return getEmptyResult(data, fittings);
return getEmptyResult(inputCtx);
}

List<CurveView> curves = fittings.getVisibleCurves();
sortCurves(curves);
int[] intenseChannels = getIntenseChannels(curves);
// Create a shallow copy of the input context and then make a deep copy of the
// curve list so that we can permute it without impacting the original
FittingSolverContext permCtx = new FittingSolverContext(inputCtx);
permCtx.curves = new ArrayList<>(permCtx.curves);

List<CurveView> perm = new ArrayList<>(curves);
int counter = 0;
double[] scalings = new double[size];
EvaluationSpace eval = new EvaluationSpace(permCtx.data.size());
while (counter <= 10) {
Collections.shuffle(perm, new Random(12345654321l));
Collections.shuffle(permCtx.curves, new Random(12345654321l));


double[] guess = getInitialGuess(perm, fitter, data);
EvaluationContext context = new EvaluationContext(data, fittings, perm);
MultivariateFunction cost = getCostFunction(context, intenseChannels);
double[] guess = getInitialGuess(permCtx);
MultivariateFunction cost = getCostFunction(permCtx, eval);
PointValuePair result = optimizeCostFunction(cost, guess, 0.02d);
double[] permScalings = result.getPoint();

Expand All @@ -75,8 +67,9 @@ public FittingResultSetView solve(FittingSolverContext ctx) {
//guess = permScalings;

for (int i = 0; i < scalings.length; i++) {
CurveView c = perm.get(i);
int j = curves.indexOf(c);
// Map the scores back to the original permutation of the curves list
CurveView c = permCtx.curves.get(i);
int j = inputCtx.curves.indexOf(c);
scalings[j] += permScalings[i];
}
counter++;
Expand All @@ -86,10 +79,8 @@ public FittingResultSetView solve(FittingSolverContext ctx) {
for (int i = 0; i < scalings.length; i++) {
scalings[i] /= counter;
}

EvaluationContext context = new EvaluationContext(data, fittings, curves);

return evaluate(scalings, context);
return evaluate(scalings, inputCtx);


}
Expand Down
Loading

0 comments on commit 1eecbcf

Please sign in to comment.