Skip to content

Commit

Permalink
Debug NIS profile (label, volume, center, coordinate) obtained from cpp
Browse files Browse the repository at this point in the history
  • Loading branch information
Chrisa142857 committed Mar 27, 2024
1 parent 914db67 commit 0e1a9d5
Show file tree
Hide file tree
Showing 5 changed files with 327 additions and 264 deletions.
56 changes: 51 additions & 5 deletions cpp/flow_op.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,49 @@
#include "flow_op.h"

std::vector<torch::Tensor> get_large_fg_coord(torch::Tensor seg) {
auto options = torch::TensorOptions().dtype(torch::kInt64);
auto seg_shape = seg.sizes();
std::vector<torch::Tensor> meshgrid_tensors;

for (int dimi = 0; dimi < seg_shape.size(); ++dimi) {
meshgrid_tensors.push_back(torch::arange(seg_shape[dimi], options));
}

auto meshgrid_output = torch::meshgrid(meshgrid_tensors);
torch::Tensor z = meshgrid_output[0];
torch::Tensor y = meshgrid_output[1];
torch::Tensor x = meshgrid_output[2];

torch::Tensor nis_mask = seg > 0;
auto unique_output = at::_unique2(seg.masked_select(nis_mask), true, false, true);
torch::Tensor label = std::get<0>(unique_output);
torch::Tensor vol = std::get<2>(unique_output);

torch::Tensor splits = vol.cumsum(0);
z = z.masked_select(nis_mask);
y = y.masked_select(nis_mask);
x = x.masked_select(nis_mask);

auto sorted_nis = seg.masked_select(nis_mask).argsort();
z = z.index_select(0, sorted_nis);
y = y.index_select(0, sorted_nis);
x = x.index_select(0, sorted_nis);

auto pt = torch::stack({z, y, x}, -1);
auto pt_splits = torch::tensor_split(pt.cpu(), splits);
pt_splits.pop_back();

std::vector<torch::Tensor> ct;
for (const auto& p : pt_splits) {
ct.push_back((std::get<0>(p.max(0)) + std::get<0>(p.min(0))) / 2);
}
std::vector<torch::Tensor> outputs;
outputs.push_back(torch::stack(ct));
outputs.push_back(pt);
outputs.push_back(label);
outputs.push_back(vol);
return outputs;
}

torch::Tensor flow_2Dto3D(
torch::Tensor flow_2d, // [3 x Z x Y x X]
Expand Down Expand Up @@ -204,12 +248,14 @@ std::vector<torch::Tensor> flow_3DtoNIS(
if (centers.size() == 0){
return {torch::zeros(0)};
} else {
M = M.index({p[0]+rpads[0], p[1]+rpads[1], p[2]+rpads[2]}).view(shape0);
std::vector<torch::Tensor> nis_profile = get_large_fg_coord(M);
return {
M.index({p[0]+rpads[0], p[1]+rpads[1], p[2]+rpads[2]}).view(shape0),
// torch::cat(coords, 0),
torch::tensor(labels, torch::kLong),
// torch::tensor(vols, torch::kLong),
torch::stack(centers)
M,
nis_profile[0],
nis_profile[1],
nis_profile[2],
nis_profile[3]
};
}
}
11 changes: 6 additions & 5 deletions cpp/test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,12 @@ std::vector<torch::Tensor> nis_obtain(torch::jit::script::Module flow_3DtoSeed,
);
if (nis_outputs.size()>1) {
save_tensor(nis_outputs[0], savefn+"_seg.zip");
// save_tensor(nis_outputs[1], savefn+"_contour.zip");
save_tensor(nis_outputs[1], savefn+"_instance_label.zip");
torch::Tensor vols = nis_outputs[0].reshape(-1).bincount();
save_tensor(vols, savefn+"_instance_volume.zip");
save_tensor(nis_outputs[2], savefn+"_instance_center.zip");
save_tensor(nis_outputs[1], savefn+"_instance_center.zip");
save_tensor(nis_outputs[2], savefn+"_instance_coordinate.zip");
save_tensor(nis_outputs[3], savefn+"_instance_label.zip");
// torch::Tensor vols = nis_outputs[0].reshape(-1).bincount();
// save_tensor(vols, savefn+"_instance_volume.zip");
save_tensor(nis_outputs[4], savefn+"_instance_volume.zip");
}
return nis_outputs;
}
Expand Down
11 changes: 7 additions & 4 deletions nis_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,16 @@ def nis_mask(pair_tag, brain_tag, img_r, P_tag='P4'):
z_ratio = 4/2.5
save_root = f'/cajal/ACMUSERS/ziquanw/Lightsheet/renders/{P_tag}/{pair_tag}'
tgt_z = 718+24
tgt_x = 3000
tgt_y = 6000
tgt_x = 0.25
tgt_y = 0.5
width = 200
z_range = [int((tgt_z-24)//z_ratio), int((tgt_z+24)//z_ratio)]
y_range = [tgt_y-100, tgt_y+100]
x_range = [tgt_x-100, tgt_x+100]
# stack_shape = [60,200,200]
seg = torch.load(f'/cajal/ACMUSERS/ziquanw/Lightsheet/results/{P_tag}/{pair_tag}/{brain_tag}/{brain_tag}_NIScpp_results_zmin{tgt_z-24}_seg.zip')
tgt_x = int(seg.shape[1]*tgt_x)
tgt_y = int(seg.shape[2]*tgt_y)
y_range = [tgt_y-width, tgt_y+width]
x_range = [tgt_x-width, tgt_x+width]
mask_img = seg[:, y_range[0]:y_range[1]+1,x_range[0]:x_range[1]+1]
stack_seg_onfg = mask_img>0
mask_img[stack_seg_onfg] = mask_img[stack_seg_onfg]-mask_img[stack_seg_onfg].min()+1
Expand Down
12 changes: 7 additions & 5 deletions qc_vol.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from datetime import datetime

res_r = '/cajal/ACMUSERS/ziquanw/Lightsheet/results/P4'
for pair_tag in os.listdir(res_r):
for pair_tag in os.listdir(res_r)[:8]:
if not pair_tag.startswith('pair'): continue
for brain_tag in os.listdir(f'{res_r}/{pair_tag}'):
for seg_name in os.listdir(f'{res_r}/{pair_tag}/{brain_tag}'):
Expand All @@ -12,11 +12,11 @@
# seg_path='/cajal/ACMUSERS/ziquanw/Lightsheet/results/P4/pair10/L64D804P3/L64D804P3_NIScpp_results_zmin718_seg.zip'
label_path = seg_path.replace('seg.zip', 'instance_label.zip')
vol_path = seg_path.replace('seg.zip', 'instance_volume.zip')
ct_path = seg_path.replace('seg.zip', 'instance_center.zip')
# ct_path = seg_path.replace('seg.zip', 'instance_center.zip')
# pt_path = seg_path.replace('seg.zip', 'contour.zip')
seg=torch.load(seg_path)
# vol = torch.load(vol_path)
label = torch.load(label_path)
# label = torch.load(label_path)
# ct = torch.load(ct_path)
# pt = torch.load(pt_path)
# vols = []
Expand All @@ -27,11 +27,13 @@
# indecies = [[],[],[]]
# labels = []
# splits = []
seg_bincount = seg.reshape(-1).bincount()
seg_bincount = seg[seg>0].reshape(-1).bincount()
print(datetime.now(), 'Geting vols', seg_bincount.shape)
vols = seg_bincount[label]
vols = seg_bincount
print(datetime.now(), 'Saving vols', vols.shape)
torch.save(vols, vol_path)
label = seg[seg>0].unique()
torch.save(label, label_path)
# instance_masks = torch.cat(instance_masks)#.bool()
# torch.save(instance_masks, seg_path.replace('seg.zip', 'instance_mask.zip'))
# print(vols)
Loading

0 comments on commit 0e1a9d5

Please sign in to comment.