Skip to content

Commit

Permalink
Merge pull request #52 from ISISComputingGroup/3_uncertainties_in_plo…
Browse files Browse the repository at this point in the history
…ts_and_fits

Use uncertainties in plots and fits
  • Loading branch information
Tom-Willemsen authored Nov 20, 2024
2 parents 8734c86 + ed4a9e0 commit ef57285
Show file tree
Hide file tree
Showing 11 changed files with 337 additions and 45 deletions.
58 changes: 58 additions & 0 deletions doc/architectural_decisions/005-variance-addition.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Variance addition to counts data

## Status

Current

## Context

For counts data, the uncertainty on counts is typically defined by poisson counting statistics, i.e. the standard deviation on `N` counts is `sqrt(N)`.

This can be problematic in cases where zero counts have been collected, as the standard deviation will then be zero, which will subsequently lead to "infinite" point weightings in downstream fitting routines for example.

A number of possible approaches were considered:

| Option | Description |
| --- | --- |
| A | Reject data with zero counts, i.e. explicitly throw an exception if any data with zero counts is seen as part of a scan. |
| B | Use a standard deviation of `NaN` for points with zero counts. |
| C | Define the standard deviation of `N` counts as `1` if counts are zero, otherwise `sqrt(N)`. This is one of the approaches available in mantid for example. |
| D | Define the standard deviation of `N` counts as `sqrt(N+0.5)` unconditionally - on the basis that "half a count" is smaller than the smallest possible actual measurement which can be taken. |
| E | No special handling, calculate std. dev. as `sqrt(N)`. |

For clarity, the following table shows the value and associated uncertainty for each option:

| Counts | Std. Dev. (A) | Std. Dev. (B) | Std. Dev. (C) | Std. Dev. (D) | Std. Dev. (E) |
| ------- | ------ | ------- | ------- | ------- | --- |
| 0 | raise exception | NaN | 1 | 0.707 | 0 |
| 1 | 1 | 1 | 1 | 1.224745 | 1 |
| 2 | 1.414214 | 1.414214 | 1.414214 | 1.581139 | 1.414214 |
| 3 | 1.732051 | 1.732051 | 1.732051 | 1.870829 | 1.732051 |
| 4 | 2 | 2 | 2 | 2.12132 | 2 |
| 5 | 2.236068 | 2.236068 | 2.236068 | 2.345208 | 2.236068 |
| 10 | 3.162278 | 3.162278 | 3.162278 | 3.24037 | 3.162278 |
| 50 | 7.071068 | 7.071068 | 7.071068 | 7.106335 | 7.071068 |
| 100 | 10 | 10 | 10 | 10.02497 | 10 |
| 500 | 22.36068 | 22.36068 | 22.36068 | 22.37186 | 22.36068 |
| 1000 | 31.62278 | 31.62278 | 31.62278 | 31.63068 | 31.62278 |
| 5000 | 70.71068 | 70.71068 | 70.71068 | 70.71421 | 70.71068 |
| 10000 | 100 | 100 | 100 | 100.0025 | 100 |

## Present

These approaches were discussed in a regular project update meeting including
- TW & FA (Experiment controls)
- CK (Reflectometry)
- JL (Muons)
- RD (SANS)

## Decision

The consensus was to go with Option D.

## Justification

- Option A will cause real-life scans to crash in low counts regions.
- Option B involves `NaN`s, which have many surprising floating-point characteristics and are highly likely to be a source of future bugs.
- Option D was preferred to option C by scientists present.
- Option E causes surprising results and/or crashes downstream, for example fitting may consider points with zero uncertainty to have "infinite" weight, therefore effectively disregarding all other data.
5 changes: 4 additions & 1 deletion doc/callbacks/plotting.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,12 @@ ax = plt.gca()
# Set the y-scale to logarithmic
ax.set_yscale("log")
# Use the above axes in a LivePlot callback
plot_callback = LivePlot(y="y_variable", x="x_variable", ax=ax)
plot_callback = LivePlot(y="y_variable", x="x_variable", ax=ax, yerr="yerr_variable")
# yerr is the uncertanties of each y value, producing error bars
```

By providing a signal name to the `yerr` argument you can pass uncertainties to LivePlot, by not providing anything for this argument means that no errorbars will be drawn. Errorbars are drawn after each point collected, displaying their standard deviation- uncertainty data is collected from Bluesky event documents and errorbars are updated after every new point added.

The `plot_callback` object can then be subscribed to the run engine, using either:
- An explicit callback when calling the run engine: `RE(some_plan(), plot_callback)`
- Be subscribed in a plan using `@subs_decorator` from bluesky **(recommended)**
Expand Down
17 changes: 10 additions & 7 deletions doc/fitting/fitting.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,17 @@ plt.figure()
ax = plt.gca()
# ax is shared by fit_callback and plot_callback

plot_callback = LivePlot(y="y_variable", x="x_variable", ax=ax)
fit_callback = LiveFit(Gaussian.fit(), y="y_variable", x="x_variable", update_every=0.5)
plot_callback = LivePlot(y="y_signal", x="x_signal", ax=ax, yerr="yerr_signal")
fit_callback = LiveFit(Gaussian.fit(), y="y_signal", x="x_signal", yerr="yerr_signal", update_every=0.5)
# Using the yerr parameter allows you to use error bars.
# update_every = in seconds, how often to recompute the fit. If `None`, do not compute until the end. Default is 1.
fit_plot_callback = LiveFitPlot(fit_callback, ax=ax, color="r")
```

**Note:** that the `LiveFit` callback doesn't directly do the plotting, it will return function parameters of the model its trying to fit to; a `LiveFit` object must be passed to `LiveFitPlot` which can then be subscribed to the `RunEngine`. See the [Bluesky Documentation](https://blueskyproject.io/bluesky/main/callbacks.html#livefitplot) for information on the various arguments that can be passed to the `LiveFitPlot` class.

Using the `yerr` argument allows you to pass uncertainties via a signal to LiveFit, so that the "weight" of each point influences the fit produced. By not providing a signal name you choose not to use uncertainties/weighting in the fitting calculation. Each weight is computed as `1/(standard deviation at point)` and is taken into account to determine how much a point affects the overall fit of the data. Same as the rest of `LiveFit`, the fit will be updated after every new point collected now taking into account the weights of each point. Uncertainty data is collected from Bluesky event documents after each new point.

The `plot_callback` and `fit_plot_callback` objects can then be subscribed to the `RunEngine`, using the same methods as described in [`LivePlot`](../callbacks/plotting.md). See the following example using `@subs_decorator`:

```py
Expand Down Expand Up @@ -79,7 +82,7 @@ from bluesky.callbacks import LiveFitPlot
from ibex_bluesky_core.callbacks.fitting.fitting_utils import [FIT]

# Pass [FIT].fit() to the first parameter of LiveFit
lf = LiveFit([FIT].fit(), y="y_variable", x="x_variable", update_every=0.5)
lf = LiveFit([FIT].fit(), y="y_signal", x="x_signal", update_every=0.5)

# Then subscribe to LiveFitPlot(lf, ...)
```
Expand All @@ -89,7 +92,7 @@ The `[FIT].fit()` function will pass the `FitMethod` object straight to the `Liv
**Note:** that for the fits in the above table that require parameters, you will need to pass value(s) to their `.fit` method. For example Polynomial fitting:

```py
lf = LiveFit(Polynomial.fit(3), y="y_variable", x="x_variable", update_every=0.5)
lf = LiveFit(Polynomial.fit(3), y="y_signal", x="x_signal", update_every=0.5)
# For a polynomial of degree 3
```

Expand Down Expand Up @@ -138,7 +141,7 @@ def guess(x: npt.NDArray[np.float64], y: npt.NDArray[np.float64]) -> dict[str, l
fit_method = FitMethod(model, guess)
#Pass the model and guess function to FitMethod

lf = LiveFit(fit_method, y="y_variable", x="x_variable", update_every=0.5)
lf = LiveFit(fit_method, y="y_signal", x="x_signal", update_every=0.5)

# Then subscribe to LiveFitPlot(lf, ...)
```
Expand All @@ -163,7 +166,7 @@ def different_model(x: float, c1: float, c0: float) -> float:
fit_method = FitMethod(different_model, Linear.guess())
# Uses the user defined model and the standard Guessing. function for linear models

lf = LiveFit(fit_method, y="y_variable", x="x_variable", update_every=0.5)
lf = LiveFit(fit_method, y="y_signal", x="x_signal", update_every=0.5)

# Then subscribe to LiveFitPlot(lf, ...)
```
Expand All @@ -188,7 +191,7 @@ def different_guess(x: float, c1: float, c0: float) -> float:
fit_method = FitMethod(Linear.model(), different_guess)
# Uses the standard linear model and the user defined Guessing. function

lf = LiveFit(fit_method, y="y_variable", x="x_variable", update_every=0.5)
lf = LiveFit(fit_method, y="y_signal", x="x_signal", update_every=0.5)

# Then subscribe to LiveFitPlot(lf, ...)
```
Expand Down
27 changes: 22 additions & 5 deletions manual_system_tests/dae_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@
import bluesky.plans as bp
import matplotlib
import matplotlib.pyplot as plt
from bluesky.callbacks import LiveTable
from bluesky.callbacks import LiveFitPlot, LiveTable
from bluesky.preprocessors import subs_decorator
from bluesky.utils import Msg
from ophyd_async.plan_stubs import ensure_connected

from ibex_bluesky_core.callbacks.file_logger import HumanReadableFileCallback
from ibex_bluesky_core.callbacks.fitting import LiveFit
from ibex_bluesky_core.callbacks.fitting.fitting_utils import Linear
from ibex_bluesky_core.callbacks.plotting import LivePlot
from ibex_bluesky_core.devices import get_pv_prefix
from ibex_bluesky_core.devices.block import block_rw_rbv
Expand All @@ -27,6 +29,8 @@
from ibex_bluesky_core.devices.simpledae.waiters import GoodFramesWaiter
from ibex_bluesky_core.run_engine import get_run_engine

NUM_POINTS: int = 3


def dae_scan_plan() -> Generator[Msg, None, None]:
"""Manual system test which moves a block and reads the DAE.
Expand Down Expand Up @@ -67,6 +71,11 @@ def dae_scan_plan() -> Generator[Msg, None, None]:
controller.run_number.set_name("run number")
reducer.intensity.set_name("normalized counts")

_, ax = plt.subplots()
lf = LiveFit(
Linear.fit(), y=reducer.intensity.name, x=block.name, yerr=reducer.intensity_stddev.name
)

yield from ensure_connected(block, dae, force_reconnect=True)

@subs_decorator(
Expand All @@ -81,7 +90,15 @@ def dae_scan_plan() -> Generator[Msg, None, None]:
dae.good_frames.name,
],
),
LivePlot(y=reducer.intensity.name, x=block.name, marker="x", linestyle="none"),
LiveFitPlot(livefit=lf, ax=ax),
LivePlot(
y=reducer.intensity.name,
x=block.name,
marker="x",
linestyle="none",
ax=ax,
yerr=reducer.intensity_stddev.name,
),
LiveTable(
[
block.name,
Expand All @@ -96,9 +113,9 @@ def dae_scan_plan() -> Generator[Msg, None, None]:
]
)
def _inner() -> Generator[Msg, None, None]:
num_points = 3
yield from bps.mv(dae.number_of_periods, num_points)
yield from bp.scan([dae], block, 0, 10, num=num_points)
yield from bps.mv(dae.number_of_periods, NUM_POINTS) # type: ignore
# Pyright does not understand as bluesky isn't typed yet
yield from bp.scan([dae], block, 0, 10, num=NUM_POINTS)

yield from _inner()

Expand Down
74 changes: 54 additions & 20 deletions src/ibex_bluesky_core/callbacks/fitting/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
"""For IBEX Bluesky scan fitting."""

import logging
import warnings
from typing import Callable

import lmfit
import numpy as np
import numpy.typing as npt
from bluesky.callbacks import LiveFit as _DefaultLiveFit
from bluesky.callbacks.core import make_class_safe
from event_model.documents.event import Event

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -49,31 +51,49 @@ class LiveFit(_DefaultLiveFit):
"""Live fit, customized for IBEX."""

def __init__(
self,
method: FitMethod,
y: str,
x: str,
*,
update_every: int = 1,
self, method: FitMethod, y: str, x: str, *, update_every: int = 1, yerr: str | None = None
) -> None:
"""Call Bluesky LiveFit with assumption that there is only one independant variable.
Args:
method (FitMethod): The FitMethod (Model & Guess) to use when fitting.
y (str): The name of the dependant variable.
x (str): The name of the independant variable.
update_every (int): How often to update the fit. (seconds)
update_every (int, optional): How often to update the fit. (seconds)
yerr (str or None, optional): Name of field in the Event document
that provides standard deviation for each Y value. None meaning
do not use uncertainties in fit.
"""
self.method = method
self.yerr = yerr
self.weight_data = []

super().__init__(
model=method.model,
y=y,
independent_vars={"x": x},
update_every=update_every,
model=method.model, y=y, independent_vars={"x": x}, update_every=update_every
)

def event(self, doc: Event) -> None:
"""When an event is received, update caches."""
weight = None
if self.yerr is not None:
try:
weight = 1 / doc["data"][self.yerr]
except ZeroDivisionError:
warnings.warn(
"standard deviation for y is 0, therefore applying weight of 0 on fit",
stacklevel=1,
)
weight = 0.0

self.update_weight(weight)
super().event(doc)

def update_weight(self, weight: float | None = 0.0) -> None:
"""Update uncertainties cache."""
if self.yerr is not None:
self.weight_data.append(weight)

def update_fit(self) -> None:
"""Use the provided guess function with the most recent x and y values after every update.
Expand All @@ -84,12 +104,26 @@ def update_fit(self) -> None:
None
"""
logger.debug("updating guess for %s ", self.method)
self.init_guess = self.method.guess(
np.array(next(iter(self.independent_vars_data.values()))),
np.array(self.ydata),
# Calls the guess function on the set of data already collected in the run
)
logger.info("new guess for %s: %s", self.method, self.init_guess)

super().update_fit()
n = len(self.model.param_names)
if len(self.ydata) < n:
warnings.warn(
f"LiveFitPlot cannot update fit until there are at least {n} data points",
stacklevel=1,
)
else:
logger.debug("updating guess for %s ", self.method)
self.init_guess = self.method.guess(
np.array(next(iter(self.independent_vars_data.values()))),
np.array(self.ydata),
# Calls the guess function on the set of data already collected in the run
)

logger.info("new guess for %s: %s", self.method, self.init_guess)

kwargs = {}
kwargs.update(self.independent_vars_data)
kwargs.update(self.init_guess)
self.result = self.model.fit(
self.ydata, weights=None if self.yerr is None else self.weight_data, **kwargs
)
self.__stale = False
1 change: 0 additions & 1 deletion src/ibex_bluesky_core/callbacks/fitting/fitting_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,6 @@ def guess(
) -> Callable[[npt.NDArray[np.float64], npt.NDArray[np.float64]], dict[str, lmfit.Parameter]]:
"""Linear Guessing."""
return Polynomial.guess(1)
# Uses polynomial guessing with a degree of 1


class Polynomial(Fit):
Expand Down
Loading

0 comments on commit ef57285

Please sign in to comment.