Skip to content

Commit

Permalink
Merge pull request #77 from CHIMEFRB/new_def_scat_time
Browse files Browse the repository at this point in the history
New def scat time
  • Loading branch information
emmanuelfonseca authored Nov 10, 2023
2 parents f48a428 + 7eda71a commit 77b0136
Show file tree
Hide file tree
Showing 6 changed files with 162 additions and 111 deletions.
22 changes: 18 additions & 4 deletions fitburst/analysis/fitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ class LSFitter:
least-squares fitting of radio dynamic spectra.
"""

def __init__(self, data: float, model: object, good_freq: bool, weighted_fit: bool = True):
def __init__(self, data: float, model: object, good_freq: bool, weighted_fit: bool = True,
weight_range: list = None):
"""
Initializes object with methods and attributes defined in
the model.SpectrumModeler() class.
Expand All @@ -37,6 +38,12 @@ def __init__(self, data: float, model: object, good_freq: bool, weighted_fit: bo
If set to true, then each channel will be weighted by its standard deviation (or,
equivalently, the goodness of fit statistic will be a weighted chi-squared value.)
weight_range : list, optional
If set, then per-channel weights will be computed using a time range specified
by indices of the time array (i.e., this option should receive a list of length
two containing integers within the range 0 ... [num_time-1].) If not set, then
the full range will be used.
Returns
-------
None : NoneType
Expand Down Expand Up @@ -66,6 +73,7 @@ def __init__(self, data: float, model: object, good_freq: bool, weighted_fit: bo
self.success = None
self.weights = None
self.weighted_fit = weighted_fit
self.weight_range = weight_range

# before running fit, determine per-channel weights.
self._set_weights()
Expand Down Expand Up @@ -471,9 +479,9 @@ def _compute_fit_statistics(self, spectrum_observed: float, fit_result: object)

try:
hessian_approx = fit_result.jac.T.dot(fit_result.jac)
covariance_approx = np.linalg.inv(hessian_approx) * chisq_final_reduced
covariance_approx = np.linalg.inv(hessian_approx)
hessian, par_labels = self.compute_hessian(self.data, self.fit_parameters)
covariance = np.linalg.inv(0.5 * hessian) * chisq_final_reduced
covariance = np.linalg.inv(0.5 * hessian)
uncertainties = [float(x) for x in np.sqrt(np.diag(covariance)).tolist()]

self.covariance_approx = covariance_approx
Expand Down Expand Up @@ -506,8 +514,14 @@ def _set_weights(self) -> None:
Two object attributes are defined and used for masking and weighting data during fit.
"""

# before calculating, determine if the range must be set to the full timespan.
idx_range = self.weight_range

if self.weight_range is None:
idx_range = [0, len(self.data[0, :]) - 1]

# compute RMS deviation for each channel.
variance = np.mean(self.data**2, axis=1)
variance = np.mean(self.data[:, idx_range[0] : idx_range[1]] ** 2, axis=1)
std = np.sqrt(variance)
bad_freq = np.logical_not(self.good_freq)

Expand Down
119 changes: 72 additions & 47 deletions fitburst/analysis/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,12 @@ def __init__(self, freqs: float, times: float, dm_incoherent: float = 0.,
for current_parameter in self.parameters:
setattr(self, current_parameter, None)

# now instantiate the structures for per-component models and time differences.
# now instantiate the structures for per-component models, time differences, and temporal profiles.
# (the following are used for computing derivatives and/or per-channel amplitudes.)
self.amplitude_per_component = np.zeros(
(self.num_freq, self.num_time, self.num_components), dtype=float
)

self.spectrum_per_component = np.zeros(
(self.num_freq, self.num_time, self.num_components), dtype=float
)
Expand All @@ -114,6 +119,10 @@ def __init__(self, freqs: float, times: float, dm_incoherent: float = 0.,
(self.num_freq, self.num_time, self.num_components), dtype=float
)

self.timeprof_per_component = np.zeros(
(self.num_freq, self.num_time, self.num_components), dtype=float
)

def compute_model(self, data: float = None) -> float:
"""
Computes the model dynamic spectrum based on model parameters (set as class
Expand All @@ -135,30 +144,32 @@ def compute_model(self, data: float = None) -> float:
num_window_bins = self.num_time // 2

# loop over all components.
for current_component in range(self.num_components):

# extract parameter values for current component.
current_amplitude = self.amplitude[current_component]
current_arrival_time = self.arrival_time[current_component]
current_dm = self.dm[0]
current_dm_index = self.dm_index[0]
current_ref_freq = self.ref_freq[current_component]
current_sc_idx = self.scattering_index[0]
current_sc_time = self.scattering_timescale[0]
current_width = self.burst_width[current_component]

if self.verbose:
if self.scintillation:
print(
f"{current_dm:.5f} {current_arrival_time:.5f} ",
f"{current_sc_idx:.5f} {current_sc_time:.5f} {current_width:.5f}", end=" ")
else:
print(
f"{current_dm:.5f} {current_amplitude:.5f} {current_arrival_time:.5f} ",
f"{current_sc_idx:.5f} {current_sc_time:.5f} {current_width:.5f}", end=" ")
for current_freq in range(self.num_freq):

# now loop over bandpass.
for current_freq in range(self.num_freq):
for current_component in range(self.num_components):

# extract parameter values for current component.
current_amplitude = self.amplitude[current_component]
current_arrival_time = self.arrival_time[current_component]
current_dm = self.dm[0]
current_dm_index = self.dm_index[0]
current_ref_freq = self.ref_freq[current_component]
current_sc_idx = self.scattering_index[0]
current_sc_time = self.scattering_timescale[0]
current_sp_idx = self.spectral_index[current_component]
current_sp_run = self.spectral_running[current_component]
current_width = self.burst_width[current_component]

if self.verbose and current_freq == 0:
if self.scintillation:
print(
f"{current_dm:.5f} {current_arrival_time:.5f} ",
f"{current_sc_idx:.5f} {current_sc_time:.5f} {current_width:.5f}", end=" ")
else:
print(
f"{current_dm:.5f} {current_amplitude:.5f} {current_arrival_time:.5f} ",
f"{current_sc_idx:.5f} {current_sc_time:.5f} {current_width:.5f}", end=" ")

# create an upsampled version of the current frequency label.
# even if no upsampling is desired, this will return an array
Expand Down Expand Up @@ -224,41 +235,55 @@ def compute_model(self, data: float = None) -> float:
is_folded = self.is_folded,
)

# third, compute and scale profile by spectral energy distribution.
if not self.scintillation:
current_sp_idx = self.spectral_index[current_component]
current_sp_run = self.spectral_running[current_component]
self.timeprof_per_component[current_freq, :, current_component] = rt.manipulate.downsample_1d(
current_profile.mean(axis=0),
self.factor_time_upsample
)

current_profile *= rt.spectrum.compute_spectrum_rpl(
current_freq_arr,
current_ref_freq,
current_sp_idx,
current_sp_run,
)[:, None]
# third, compute and scale profile by spectral energy distribution.
current_profile *= rt.spectrum.compute_spectrum_rpl(
current_freq_arr,
current_ref_freq,
current_sp_idx,
current_sp_run,
)[:, None]

# before writing, downsize upsampled array to original size.
current_profile = rt.manipulate.downsample_1d(
current_profile.mean(axis=0),
self.factor_time_upsample
)

# finally, add to appropriate slice of model-spectrum matrix.
if self.scintillation:
current_amplitude = rt.ism.compute_amplitude_per_channel(
data[current_freq], current_profile
)
self.spectrum_per_component[current_freq, :, current_component] = current_amplitude * current_profile
# before exiting the loop, save different snapshots of the model.
self.amplitude_per_component[current_freq, :, current_component] = rt.spectrum.compute_spectrum_rpl(
self.freqs[current_freq],
current_ref_freq,
current_sp_idx,
current_sp_run
) * (10 ** current_amplitude)
self.spectrum_per_component[current_freq, :, current_component] = (10 ** current_amplitude) * current_profile

# print spectral index/running for current component.
if current_freq == 0:
if self.verbose and not self.scintillation:
print(f"{current_sp_idx:.5f} {current_sp_run:.5f}")

else:
self.spectrum_per_component[current_freq, :, current_component] = (10**current_amplitude) * current_profile
else:
print()

# print spectral index/running for current component.
if self.verbose and not self.scintillation:
print(f"{current_sp_idx:.5f} {current_sp_run:.5f}")
# if desired, then compute per-channel amplitudes in cases where scintillation is significant.
if self.scintillation:

for freq in range(self.num_freq):
current_amplitudes = rt.ism.compute_amplitude_per_channel(
data[freq], self.timeprof_per_component[freq, :, :]
)
# now compute model with per-channel amplitudes determined.
for component in range(self.num_components):
current_profile = self.timeprof_per_component[freq, :, component]
self.amplitude_per_component[freq, :, component] = current_amplitudes[component]
self.spectrum_per_component[freq, :, component] = current_amplitudes[component] * current_profile

else:
print()

return np.sum(self.spectrum_per_component, axis=2)

def compute_profile(self, times: float, arrival_time: float, sc_time_ref: float, sc_index: float,
Expand Down
15 changes: 7 additions & 8 deletions fitburst/pipelines/fitburst_example_chimefrb.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,9 +304,9 @@
"scattering_timescale" not in parameters_to_fix
):
initial_parameters["scattering_timescale"] = copy.deepcopy(
(np.fabs(np.array(initial_parameters["burst_width"])) * 3.).tolist()
(np.fabs(np.array(initial_parameters["burst_width"])) * 1.).tolist()
)
initial_parameters["burst_width"] = (np.array(initial_parameters["burst_width"]) / 1.5).tolist()
initial_parameters["burst_width"] = (np.array(initial_parameters["burst_width"]) / 1.).tolist()

# if guesses are provided through CLI, overload them into the initial-guess dictionary.
initial_parameters["dm"][0] += offset_dm
Expand Down Expand Up @@ -433,11 +433,12 @@
verbose=verbose,
)

print(initial_parameters)
model.update_parameters(initial_parameters)
bestfit_model = model.compute_model(data=data_windowed) * data.good_freq[:, None]
bestfit_params = model.get_parameters_dict()
bestfit_params["dm"] = [params["dm"][0] + x for x in bestfit_params["dm"] * model.num_components]
bestfit_params["dm"] = [params["dm"][0] + x for x in bestfit_params["dm"]]
#print(bestfit_params["dm"])
#sys.exit()
bestfit_residuals = data_windowed - bestfit_model
fit_is_successful = False
fit_statistics = None
Expand All @@ -459,13 +460,11 @@
model.update_parameters(fitter.fit_statistics["bestfit_parameters"])
bestfit_model = model.compute_model(data=data_windowed) * data.good_freq[:, None]
bestfit_params = model.get_parameters_dict()
bestfit_params["dm"] = [params["dm"][0] + x for x in bestfit_params["dm"] * model.num_components]
bestfit_params["dm"] = [params["dm"][0] + x for x in bestfit_params["dm"]]
bestfit_residuals = data_windowed - bestfit_model
fit_is_successful = True
fit_statistics = fitter.fit_statistics
plt.pcolormesh(bestfit_model)
plt.savefig("test2.png")


# TODO: for now, stash covariance data for offline comparison; remove at some point.
np.savez(
f"covariance_matrices_{current_event_id}.npz",
Expand Down
16 changes: 13 additions & 3 deletions fitburst/pipelines/fitburst_example_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,16 @@
"fit parameters and fitting."
)

parser.add_argument(
"--weight_range",
action="store",
dest="weight_range",
default=None,
nargs=2,
type=int,
help="Indeces for timestamp array that represent region to evaluate RMS data weights."
)

parser.add_argument(
"--width",
action="store",
Expand Down Expand Up @@ -313,6 +323,7 @@
solution_file = args.solution_file
variance_range = args.variance_range
verbose = args.verbose
weight_range = args.weight_range
width = args.width
window = args.window

Expand Down Expand Up @@ -361,7 +372,7 @@
data.good_freq[idx_freq] = False

if preprocess_data:
data.preprocess_data(normalize_variance=False, variance_range=variance_range)
data.preprocess_data(normalize_variance=True, variance_range=variance_range)

print(f"There are {data.good_freq.sum()} good frequencies...")

Expand Down Expand Up @@ -497,10 +508,9 @@

# now set up fitter and execute least-squares fitting
for current_iteration in range(num_iterations):
fitter = LSFitter(data_windowed, model, data.good_freq, weighted_fit=True)
fitter = LSFitter(data_windowed, model, data.good_freq, weighted_fit=True, weight_range=weight_range)
fitter.fix_parameter(parameters_to_fix)
fitter.fit(exact_jacobian=True)

print(fitter.results)

# extract best-fit data for next loop.
Expand Down
Loading

0 comments on commit 77b0136

Please sign in to comment.