Skip to content

Commit

Permalink
Merge pull request #76 from CHIMEFRB/new_def_scat_time
Browse files Browse the repository at this point in the history
New def scat time
  • Loading branch information
emmanuelfonseca authored Oct 18, 2023
2 parents e778cf1 + 248e7dc commit f48a428
Show file tree
Hide file tree
Showing 5 changed files with 799 additions and 382 deletions.
37 changes: 26 additions & 11 deletions fitburst/analysis/fitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from scipy.optimize import least_squares
import fitburst.routines.derivative as deriv
import numpy as np
import traceback
import sys

class LSFitter:
Expand Down Expand Up @@ -100,28 +101,30 @@ def compute_hessian(self, data: float, parameter_list: list) -> float:

# define the scale of the Hessian matrix and its labels.
par_labels_output = []
par_labels_idxs = []
par_labels = []
num_par = 0
print("here are the fit parameters:", self.fit_parameters)

for current_par in self.fit_parameters:
if np.any([current_par == x for x in self.global_parameters]):
par_labels_output += [current_par]
par_labels_idxs += [0]
par_labels += [current_par]
num_par += 1

else:
par_labels_output += [f"{current_par}{idx+1}" for idx in range(self.model.num_components)]
par_labels_idxs += [idx for idx in range(self.model.num_components)]
par_labels += ([current_par] * self.model.num_components)
num_par += (self.model.num_components)

# now loop over all fit parameters and compute derivatives.
hessian = np.zeros((num_par, num_par), dtype=float)

# now loop over all fit parameters and compute derivatives.
for current_par_idx_1, current_par_1 in zip(range(num_par), par_labels):
current_par_deriv_1 = getattr(deriv, f"deriv_model_wrt_{current_par_1}")
current_deriv_1 = current_par_deriv_1(
parameter_dict, self.model, component=(current_par_idx_1 % self.model.num_components)
parameter_dict, self.model, component=par_labels_idxs[current_par_idx_1]
)

# for efficient calculation, only compute one half of the matrix and fill the other half appropriately.
Expand All @@ -130,7 +133,7 @@ def compute_hessian(self, data: float, parameter_list: list) -> float:
# also compute a given derivative for all burst components.
current_par_deriv_2 = getattr(deriv, f"deriv_model_wrt_{current_par_2}")
current_deriv_2 = current_par_deriv_2(
parameter_dict, self.model, component=(current_par_idx_2 % self.model.num_components)
parameter_dict, self.model, component=par_labels_idxs[current_par_idx_2]
)

# correct name ordering of mixed partial derivative, if necessary.
Expand All @@ -140,17 +143,28 @@ def compute_hessian(self, data: float, parameter_list: list) -> float:
except AttributeError:
current_mixed_deriv = getattr(deriv, f"deriv2_model_wrt_{current_par_2}_{current_par_1}")

# only computed mixed derivative for parameters that describe the same component.
# only computed mixed derivative for *non-global* parameters that describe the same component,
# or any mixture of global/non-global or global-global pairs.
current_deriv_mixed = 0

if (current_par_idx_1 % self.model.num_components) == (current_par_idx_2 % self.model.num_components):
if (current_par_1 not in self.global_parameters) and (current_par_2 not in self.global_parameters) and \
(par_labels_idxs[current_par_idx_1] != par_labels_idxs[current_par_idx_2]):
pass

else:
# suss out correct pulse component and take mixed-partial derivative.
component = par_labels_idxs[current_par_idx_2]

if current_par_2 in self.global_parameters:
component = par_labels_idxs[current_par_idx_1]

current_deriv_mixed = current_mixed_deriv(
parameter_dict, self.model, component=(current_par_idx_2 % self.model.num_components)
parameter_dict, self.model, component=component
)

# finally, compute the hessian here.
current_hes = 2 * current_deriv_1 * current_deriv_2 - residual * current_deriv_mixed
hessian[current_par_idx_1, current_par_idx_2] = np.sum(current_hes * self.weights[:, None])
current_hes = 2 * (current_deriv_1 * current_deriv_2 - residual * current_deriv_mixed)
hessian[current_par_idx_1, current_par_idx_2] = np.sum(current_hes * self.weights[:, None] ** 2)
hessian[current_par_idx_2, current_par_idx_1] = hessian[current_par_idx_1, current_par_idx_2]

return hessian, par_labels_output
Expand Down Expand Up @@ -297,7 +311,7 @@ def fit(self, exact_jacobian: bool = True) -> None:

except Exception as exc:
print("ERROR: solver encountered a failure! Debug!")
print(sys.exc_info())
print(traceback.format_exc())

def fix_parameter(self, parameter_list: list) -> None:
"""
Expand Down Expand Up @@ -459,7 +473,7 @@ def _compute_fit_statistics(self, spectrum_observed: float, fit_result: object)
hessian_approx = fit_result.jac.T.dot(fit_result.jac)
covariance_approx = np.linalg.inv(hessian_approx) * chisq_final_reduced
hessian, par_labels = self.compute_hessian(self.data, self.fit_parameters)
covariance = np.linalg.inv(hessian) * chisq_final_reduced
covariance = np.linalg.inv(0.5 * hessian) * chisq_final_reduced
uncertainties = [float(x) for x in np.sqrt(np.diag(covariance)).tolist()]

self.covariance_approx = covariance_approx
Expand All @@ -474,6 +488,7 @@ def _compute_fit_statistics(self, spectrum_observed: float, fit_result: object)

except Exception as exc:
print(f"ERROR: {exc}; designating fit as unsuccessful...")
print(traceback.format_exc())
self.success = False

def _set_weights(self) -> None:
Expand Down
25 changes: 20 additions & 5 deletions fitburst/analysis/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,8 +216,11 @@ def compute_model(self, data: float = None) -> float:
current_profile = self.compute_profile(
current_times_arr,
0.0, # since 'current_times' is already corrected for DM.
current_sc_time_scaled,
current_sc_time,
current_sc_idx,
current_width,
current_freq_arr[:, None],
current_ref_freq,
is_folded = self.is_folded,
)

Expand Down Expand Up @@ -258,8 +261,8 @@ def compute_model(self, data: float = None) -> float:

return np.sum(self.spectrum_per_component, axis=2)

def compute_profile(self, times: float, arrival_time: float, sc_time: float,
width: float, is_folded: bool = False) -> float:
def compute_profile(self, times: float, arrival_time: float, sc_time_ref: float, sc_index: float,
width: float, freqs: float, ref_freq: float, is_folded: bool = False) -> float:
"""
Returns the temporal profile, depending on input values of width
and scattering timescale.
Expand All @@ -272,12 +275,21 @@ def compute_profile(self, times: float, arrival_time: float, sc_time: float,
arrival_time : float
The arrival time of the burst
sc_time : float
sc_time_ref : float
The scattering timescale of the burst (which depends on frequency label)
sc_index : float
The index of frequency dependence on the scattering timescale
width : float
The intrinsic temporal width of the burst
freqs : float
The index of frequency dependence on the scattering timescale
ref_freq : float
The index of frequency dependence on the scattering timescale
is_folded : bool, optional
If true, then the temporal profile is computed over two realizations and then
averaged down to one (in order to allow for wrapping of a folded pulse shape)
Expand All @@ -304,6 +316,7 @@ def compute_profile(self, times: float, arrival_time: float, sc_time: float,

# compute either Gaussian or pulse-broadening function, depending on inputs.
profile = np.zeros(times_copy.shape, dtype=float)
sc_time = sc_time_ref * (freqs / ref_freq) ** sc_index

if np.any(sc_time < np.fabs(0.15 * width)):
profile = rt.profile.compute_profile_gaussian(times_copy, arrival_time, width)
Expand All @@ -313,7 +326,9 @@ def compute_profile(self, times: float, arrival_time: float, sc_time: float,
# floating-point overlow in the exp((-times - toa) / sc_time) term in the
# PBF call. TODO: use a better, more transparent method for avoiding this.
times_copy[times_copy < -5 * width] = -5 * width
profile = rt.profile.compute_profile_pbf(times_copy, arrival_time, width, sc_time)
profile = rt.profile.compute_profile_pbf(
times_copy, arrival_time, width, freqs, ref_freq, sc_time_ref, sc_index=sc_index
)

# if data are folded and time/profile data contain two realizations, then
# average along the appropriate axis to obtain a single realization.
Expand Down
4 changes: 2 additions & 2 deletions fitburst/pipelines/fitburst_example_chimefrb.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
)

parser.add_argument(
"--downsample_freq", action="store", dest="factor_freq_downsample", default=64, type=int,
"--downsample_freq", action="store", dest="factor_freq_downsample", default=1, type=int,
help="Downsample the raw spectrum along the frequency axis by a specified integer."
)

Expand Down Expand Up @@ -341,7 +341,7 @@
)

# if desired, downsample data prior to extraction.
data.downsample(factor_freq_downsample, factor_time_upsample)
#data.downsample(factor_freq_downsample, factor_time_upsample)
log.info(f"downsampled raw data by factors of (ds_freq, ds_time) = ({factor_freq_downsample}, {factor_time_downsample})")

# if the number of RFI-flagged channels is "too large", skip this event altogether.
Expand Down
Loading

0 comments on commit f48a428

Please sign in to comment.