diff --git a/python/lvmdrp/functions/imageMethod.py b/python/lvmdrp/functions/imageMethod.py index decef6fb..b28b18ab 100644 --- a/python/lvmdrp/functions/imageMethod.py +++ b/python/lvmdrp/functions/imageMethod.py @@ -478,7 +478,7 @@ def select_lines_2d(in_images, out_mask, in_cent_traces, in_waves, lines_list=No def fix_pixel_shifts(in_images, out_pixshift, ref_images, in_mask, max_shift=10, threshold_spikes=0.6, flat_spikes=11, - fill_gaps=20, dry_run=False, undo_correction=False, display_plots=False): + fill_gaps=20, shift_rows=None, dry_run=False, undo_correction=False, display_plots=False): """Corrects pixel shifts in raw frames based on reference frames and a selection of spectral regions Given a set of raw frames, reference frames and a mask, this function corrects pixel shifts @@ -551,38 +551,47 @@ def fix_pixel_shifts(in_images, out_pixshift, ref_images, in_mask, # load input images into output array to apply corrections if needed images_out = [loadImage(in_image) for in_image in in_images] - # calculate pixel shifts - log.info("running row-by-row cross-correlation") - shifts, corrs = [], [] - for irow in range(rdata.shape[0]): - cimg_row = cdata[irow] - rimg_row = rdata[irow] - if numpy.all(cimg_row == 0) or numpy.all(rimg_row == 0): - shifts.append(0) - corrs.append(0) - continue + # calculate pixel shifts or use provided ones + if shift_rows is None: + log.info("running row-by-row cross-correlation") + shifts, corrs = [], [] + for irow in range(rdata.shape[0]): + cimg_row = cdata[irow] + rimg_row = rdata[irow] + if numpy.all(cimg_row == 0) or numpy.all(rimg_row == 0): + shifts.append(0) + corrs.append(0) + continue - shift = signal.correlation_lags(cimg_row.size, rimg_row.size, mode="same") - corr = signal.correlate(cimg_row, rimg_row, mode="same") + shift = signal.correlation_lags(cimg_row.size, rimg_row.size, mode="same") + corr = signal.correlate(cimg_row, rimg_row, mode="same") - mask = (numpy.abs(shift) <= max_shift) - shift = shift[mask] - corr = corr[mask] + mask = (numpy.abs(shift) <= max_shift) + shift = shift[mask] + corr = corr[mask] - max_corr = numpy.argmax(corr) - shifts.append(shift[max_corr]) - corrs.append(corr[max_corr]) - shifts = numpy.asarray(shifts) - corrs = numpy.asarray(corrs) + max_corr = numpy.argmax(corr) + shifts.append(shift[max_corr]) + corrs.append(corr[max_corr]) + shifts = numpy.asarray(shifts) + corrs = numpy.asarray(corrs) - raw_shifts = copy(shifts) - shifts = _remove_spikes(shifts, width=flat_spikes, threshold=threshold_spikes) - shifts = _fillin_valleys(shifts, width=fill_gaps) - shifts = _no_stepdowns(shifts) + raw_shifts = copy(shifts) + shifts = _remove_spikes(shifts, width=flat_spikes, threshold=threshold_spikes) + shifts = _fillin_valleys(shifts, width=fill_gaps) + shifts = _no_stepdowns(shifts) + else: + log.info("using user provided pixel shifts") + shifts = numpy.zeros(cdata.shape[0]) + for irow in shift_rows: + shifts[irow:] += 2 + raw_shifts = copy(shifts) + corrs = numpy.zeros_like(shifts) apply_shift = numpy.any(numpy.abs(shifts)>0) if apply_shift: - log.info(f"found {numpy.sum(numpy.abs(shifts)>0)} rows with pixel shifts") + shifted_rows = numpy.where(numpy.gradient(shifts) > 0)[0][1::2].tolist() + log.info(f"applying shifts to {shifted_rows = } ({numpy.sum(numpy.abs(shifts)>0)}) rows") for image_out, in_image, out_image, ori_image in zip(images_out, in_images, out_images, ori_images): image = copy(image_out) mjd = image._header.get("SMJD", image._header["MJD"]) diff --git a/python/lvmdrp/functions/run_calseq.py b/python/lvmdrp/functions/run_calseq.py index dd1bf44d..72a6613f 100644 --- a/python/lvmdrp/functions/run_calseq.py +++ b/python/lvmdrp/functions/run_calseq.py @@ -281,7 +281,7 @@ def messup_frame(mjd, expnum, spec="1", shifts=[1500, 2000, 3500], shift_size=-2 def fix_raw_pixel_shifts(mjd, expnums=None, ref_expnums=None, specs="123", y_widths=5, wave_list=None, wave_widths=0.6*5, max_shift=10, flat_spikes=11, - threshold_spikes=np.inf, create_mask_always=False, dry_run=False, + threshold_spikes=np.inf, shift_rows=None, create_mask_always=False, dry_run=False, undo_corrections=False, display_plots=False): """Attempts to fix pixel shifts in a list of raw frames @@ -311,6 +311,8 @@ def fix_raw_pixel_shifts(mjd, expnums=None, ref_expnums=None, specs="123", Number of flat spikes, by default 11 threshold_spikes : float Threshold for spikes, by default np.inf + shift_rows : dict + Rows to shift, by default None create_mask_always : bool Create mask always, by default False dry_run : bool @@ -328,6 +330,11 @@ def fix_raw_pixel_shifts(mjd, expnums=None, ref_expnums=None, specs="123", else: raise ValueError("no valid reference exposure number given") + if shift_rows is None: + shift_rows = {} + elif not isinstance(shift_rows, dict): + raise ValueError("shift_rows must be a dictionary with keys (spec, expnum) and values a list of rows to shift") + frames, _ = get_sequence_metadata(mjd, target_mjd=None, expnums=expnums) masters_mjd = get_master_mjd(sci_mjd=mjd) masters_path = os.path.join(MASTERS_DIR, str(masters_mjd)) @@ -360,7 +367,7 @@ def fix_raw_pixel_shifts(mjd, expnums=None, ref_expnums=None, specs="123", image_tasks.fix_pixel_shifts(in_images=rframe_paths, out_pixshift=pixshift_path, ref_images=cframe_paths, in_mask=mask_2d_path, flat_spikes=flat_spikes, - threshold_spikes=threshold_spikes, max_shift=max_shift, + threshold_spikes=threshold_spikes, max_shift=max_shift, shift_rows=shift_rows.get((spec, expnum), None), dry_run=dry_run, undo_correction=undo_corrections, display_plots=display_plots)