Skip to content

Commit

Permalink
fixing tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ajmejia committed Apr 11, 2024
1 parent df08e18 commit 6368ff9
Showing 1 changed file with 40 additions and 23 deletions.
63 changes: 40 additions & 23 deletions tests/functions/test_imageMethod.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,53 @@

import os
import numpy as np
import pytest

from lvmdrp import path
from lvmdrp.core.image import Image
from lvmdrp.core.image import Image, loadImage
from lvmdrp.functions import imageMethod


# def test_fix_pixel_shifts_noshift(make_fits):
# make_fits(mjd=61231, cameras=['b1'], expnum=3, leak=False, shift_rows=[])
# make_fits(mjd=61231, cameras=['b1'], expnum=4, leak=False, shift_rows=[])
# rpath = path.full("lvm_raw", hemi="s", camspec="b1", mjd=61231, expnum=3)
# ipath = path.full("lvm_raw", hemi="s", camspec="b1", mjd=61231, expnum=4)
@pytest.fixture
def mask_2d():
mask_2d = Image(data=np.ones((4080, 3*4120), dtype=int))
yield mask_2d

# image_ori = Image()
# image_ori.loadFitsData(ipath)
# shift_columns, image_fixed = imageMethod.fix_pixel_shifts(in_image=ipath, ref_image=rpath)
def test_fix_pixel_shifts_noshift(make_fits, mask_2d):
make_fits(mjd=61231, cameras=['b1', 'r1', 'z1'], expnum=3, leak=False, shift_rows=[])
make_fits(mjd=61231, cameras=['b1', 'r1', 'z1'], expnum=4, leak=False, shift_rows=[])
rpaths = sorted(path.expand("lvm_raw", hemi="s", camspec="?1", mjd=61231, expnum=3))
ipaths = sorted(path.expand("lvm_raw", hemi="s", camspec="?1", mjd=61231, expnum=4))
pixshift_path = path.full("lvm_anc", drpver="test", imagetype="pixshift", tileid=11111, mjd=61231, camera="sp1", expnum=4, kind="")
mask_2d_path = path.full("lvm_anc", drpver="test", imagetype="mask2d", tileid=11111, mjd=61231, camera="sp1", expnum=0, kind="")
os.makedirs(os.path.dirname(pixshift_path), exist_ok=True)

# assert (shift_columns == 0).all()
# assert (image_fixed._data == image_ori._data).all()
mask_2d.writeFitsData(mask_2d_path)

images_ori = [loadImage(rpath) for rpath in rpaths]
shift_columns, corrs, images_fixed = imageMethod.fix_pixel_shifts(in_images=ipaths, out_pixshift=pixshift_path, ref_images=rpaths, in_mask=mask_2d_path)

# def test_fix_pixel_shifts(make_fits):
# make_fits(mjd=61231, cameras=['b1'], expnum=5, leak=False, shift_rows=[])
# make_fits(mjd=61231, cameras=['b1'], expnum=6, leak=False, shift_rows=[1500])
# rpath = path.full("lvm_raw", hemi="s", camspec="b1", mjd=61231, expnum=5)
# ipath = path.full("lvm_raw", hemi="s", camspec="b1", mjd=61231, expnum=6)
for image_fixed, image_ori in zip(images_fixed, images_ori):
assert (shift_columns == 0).all()
assert (image_fixed._data == image_ori._data).all()

# image_ori = Image()
# image_ori.loadFitsData(rpath)
# shift_columns, image_fixed = imageMethod.fix_pixel_shifts(in_image=ipath, ref_image=rpath)
# expected_shifts = np.zeros_like(shift_columns)
# expected_shifts[1500:] = 2

# assert (shift_columns == expected_shifts).all()
# assert (image_fixed._data == image_ori._data).all()
def test_fix_pixel_shifts(make_fits, mask_2d):
make_fits(mjd=61231, cameras=['b1', 'r1', 'z1'], expnum=5, leak=False, shift_rows=[])
make_fits(mjd=61231, cameras=['b1', 'r1', 'z1'], expnum=6, leak=False, shift_rows=[1500])
rpaths = sorted(path.expand("lvm_raw", hemi="s", camspec="?1", mjd=61231, expnum=5))
ipaths = sorted(path.expand("lvm_raw", hemi="s", camspec="?1", mjd=61231, expnum=6))
pixshift_path = path.full("lvm_anc", drpver="test", imagetype="pixshift", tileid=11111, mjd=61231, camera="sp1", expnum=4, kind="")
mask_2d_path = path.full("lvm_anc", drpver="test", imagetype="mask2d", tileid=11111, mjd=61231, camera="sp1", expnum=0, kind="")
os.makedirs(os.path.dirname(pixshift_path), exist_ok=True)

mask_2d.writeFitsData(mask_2d_path)

images_ori = [loadImage(rpath) for rpath in rpaths]
shift_columns, corrs, images_fixed = imageMethod.fix_pixel_shifts(in_images=ipaths, out_pixshift=pixshift_path, ref_images=rpaths, in_mask=mask_2d_path)
expected_shifts = np.zeros_like(shift_columns)
expected_shifts[1500:] = 2

assert (shift_columns == expected_shifts).all()
for image_fixed, image_ori in zip(images_fixed, images_ori):
assert (image_fixed._data == image_ori._data).all()

0 comments on commit 6368ff9

Please sign in to comment.