Skip to content

Commit

Permalink
implementing by-hand shift locations
Browse files Browse the repository at this point in the history
  • Loading branch information
ajmejia committed Apr 11, 2024
1 parent a1f0e37 commit df08e18
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 28 deletions.
61 changes: 35 additions & 26 deletions python/lvmdrp/functions/imageMethod.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,7 @@ def select_lines_2d(in_images, out_mask, in_cent_traces, in_waves, lines_list=No

def fix_pixel_shifts(in_images, out_pixshift, ref_images, in_mask,
max_shift=10, threshold_spikes=0.6, flat_spikes=11,
fill_gaps=20, dry_run=False, undo_correction=False, display_plots=False):
fill_gaps=20, shift_rows=None, dry_run=False, undo_correction=False, display_plots=False):
"""Corrects pixel shifts in raw frames based on reference frames and a selection of spectral regions
Given a set of raw frames, reference frames and a mask, this function corrects pixel shifts
Expand Down Expand Up @@ -551,38 +551,47 @@ def fix_pixel_shifts(in_images, out_pixshift, ref_images, in_mask,
# load input images into output array to apply corrections if needed
images_out = [loadImage(in_image) for in_image in in_images]

# calculate pixel shifts
log.info("running row-by-row cross-correlation")
shifts, corrs = [], []
for irow in range(rdata.shape[0]):
cimg_row = cdata[irow]
rimg_row = rdata[irow]
if numpy.all(cimg_row == 0) or numpy.all(rimg_row == 0):
shifts.append(0)
corrs.append(0)
continue
# calculate pixel shifts or use provided ones
if shift_rows is None:
log.info("running row-by-row cross-correlation")
shifts, corrs = [], []
for irow in range(rdata.shape[0]):
cimg_row = cdata[irow]
rimg_row = rdata[irow]
if numpy.all(cimg_row == 0) or numpy.all(rimg_row == 0):
shifts.append(0)
corrs.append(0)
continue

shift = signal.correlation_lags(cimg_row.size, rimg_row.size, mode="same")
corr = signal.correlate(cimg_row, rimg_row, mode="same")
shift = signal.correlation_lags(cimg_row.size, rimg_row.size, mode="same")
corr = signal.correlate(cimg_row, rimg_row, mode="same")

mask = (numpy.abs(shift) <= max_shift)
shift = shift[mask]
corr = corr[mask]
mask = (numpy.abs(shift) <= max_shift)
shift = shift[mask]
corr = corr[mask]

max_corr = numpy.argmax(corr)
shifts.append(shift[max_corr])
corrs.append(corr[max_corr])
shifts = numpy.asarray(shifts)
corrs = numpy.asarray(corrs)
max_corr = numpy.argmax(corr)
shifts.append(shift[max_corr])
corrs.append(corr[max_corr])
shifts = numpy.asarray(shifts)
corrs = numpy.asarray(corrs)

raw_shifts = copy(shifts)
shifts = _remove_spikes(shifts, width=flat_spikes, threshold=threshold_spikes)
shifts = _fillin_valleys(shifts, width=fill_gaps)
shifts = _no_stepdowns(shifts)
raw_shifts = copy(shifts)
shifts = _remove_spikes(shifts, width=flat_spikes, threshold=threshold_spikes)
shifts = _fillin_valleys(shifts, width=fill_gaps)
shifts = _no_stepdowns(shifts)
else:
log.info("using user provided pixel shifts")
shifts = numpy.zeros(cdata.shape[0])
for irow in shift_rows:
shifts[irow:] += 2
raw_shifts = copy(shifts)
corrs = numpy.zeros_like(shifts)

apply_shift = numpy.any(numpy.abs(shifts)>0)
if apply_shift:
log.info(f"found {numpy.sum(numpy.abs(shifts)>0)} rows with pixel shifts")
shifted_rows = numpy.where(numpy.gradient(shifts) > 0)[0][1::2].tolist()
log.info(f"applying shifts to {shifted_rows = } ({numpy.sum(numpy.abs(shifts)>0)}) rows")
for image_out, in_image, out_image, ori_image in zip(images_out, in_images, out_images, ori_images):
image = copy(image_out)
mjd = image._header.get("SMJD", image._header["MJD"])
Expand Down
11 changes: 9 additions & 2 deletions python/lvmdrp/functions/run_calseq.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ def messup_frame(mjd, expnum, spec="1", shifts=[1500, 2000, 3500], shift_size=-2

def fix_raw_pixel_shifts(mjd, expnums=None, ref_expnums=None, specs="123",
y_widths=5, wave_list=None, wave_widths=0.6*5, max_shift=10, flat_spikes=11,
threshold_spikes=np.inf, create_mask_always=False, dry_run=False,
threshold_spikes=np.inf, shift_rows=None, create_mask_always=False, dry_run=False,
undo_corrections=False, display_plots=False):
"""Attempts to fix pixel shifts in a list of raw frames
Expand Down Expand Up @@ -311,6 +311,8 @@ def fix_raw_pixel_shifts(mjd, expnums=None, ref_expnums=None, specs="123",
Number of flat spikes, by default 11
threshold_spikes : float
Threshold for spikes, by default np.inf
shift_rows : dict
Rows to shift, by default None
create_mask_always : bool
Create mask always, by default False
dry_run : bool
Expand All @@ -328,6 +330,11 @@ def fix_raw_pixel_shifts(mjd, expnums=None, ref_expnums=None, specs="123",
else:
raise ValueError("no valid reference exposure number given")

if shift_rows is None:
shift_rows = {}
elif not isinstance(shift_rows, dict):
raise ValueError("shift_rows must be a dictionary with keys (spec, expnum) and values a list of rows to shift")

frames, _ = get_sequence_metadata(mjd, target_mjd=None, expnums=expnums)
masters_mjd = get_master_mjd(sci_mjd=mjd)
masters_path = os.path.join(MASTERS_DIR, str(masters_mjd))
Expand Down Expand Up @@ -360,7 +367,7 @@ def fix_raw_pixel_shifts(mjd, expnums=None, ref_expnums=None, specs="123",

image_tasks.fix_pixel_shifts(in_images=rframe_paths, out_pixshift=pixshift_path,
ref_images=cframe_paths, in_mask=mask_2d_path, flat_spikes=flat_spikes,
threshold_spikes=threshold_spikes, max_shift=max_shift,
threshold_spikes=threshold_spikes, max_shift=max_shift, shift_rows=shift_rows.get((spec, expnum), None),
dry_run=dry_run, undo_correction=undo_corrections, display_plots=display_plots)


Expand Down

0 comments on commit df08e18

Please sign in to comment.