Skip to content

Commit

Permalink
Completed cpp Pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
Chrisa142857 committed Dec 8, 2023
1 parent ed04dc4 commit 48eaecd
Show file tree
Hide file tree
Showing 3 changed files with 172 additions and 104 deletions.
61 changes: 11 additions & 50 deletions cpp/flow_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,9 @@ torch::Tensor flow_2Dto3D(
bool skip_first
) {
sim_grad_z->to(device);
// std::vector<torch::Tensor> grad3d;
flow_2d = torch::cat({torch::zeros_like(flow_2d[0]).unsqueeze(0), flow_2d}); // [4 x Z x Y x X]

for (int64_t i = 0; i < flow_2d.size(1) - 1; ++i) {
// torch::Tensor yx_flow = flow_2d.index({torch::indexing::Slice(0, 2), i});
// torch::Tensor cellprob = flow_2d.index({3, i});
// torch::Tensor next_yx_flow = flow_2d.index({torch::indexing::Slice(1, 3), i + 1});
torch::Tensor pre_yx_flow;

if (i > 0) {
Expand All @@ -24,10 +20,7 @@ torch::Tensor flow_2Dto3D(
pre_yx_flow = pre_last_second;
}
if (i == 0 & skip_first) {continue;}
// std::vector<torch::jit::IValue> inputs({yx_flow.to(device), cellprob.to(device), pre_yx_flow.to(device), next_yx_flow.to(device)});
std::vector<torch::jit::IValue> inputs({flow_2d.index({3, i}).to(device), pre_yx_flow.to(device), flow_2d.index({torch::indexing::Slice(1, 3), i + 1}).to(device)});
// torch::Tensor dP = sim_grad_z->forward(inputs).toTensor().cpu();
// grad3d.push_back(dP);
flow_2d.index_put_({0, i}, sim_grad_z->forward(inputs).toTensor().cpu());
}
if (skip_first){
Expand All @@ -36,20 +29,14 @@ torch::Tensor flow_2Dto3D(
flow_2d = flow_2d.slice(1, 0, -1); // [4 x Z-1 x Y x X]
}
return flow_2d;

// torch::Tensor output = torch::stack(grad3d, 1);
// return output;
}

torch::Tensor index_flow(
// torch::jit::script::Module meshgrider,
torch::Tensor dP,
int64_t Lz, int64_t Ly, int64_t Lx, int64_t niter = 139
) {
std::vector<int64_t> shape = {Lz, Ly, Lx};
auto p_vec = torch::meshgrid({torch::arange(shape[0]), torch::arange(shape[1]), torch::arange(shape[2])}, "ij");
// auto meshes = meshgrider(std::vector<torch::jit::IValue>({torch::arange(shape[0]), torch::arange(shape[1]), torch::arange(shape[2])})).toTuple();
// std::vector<torch::Tensor> p_vec({meshes->elements()[0].toTensor(), meshes->elements()[1].toTensor(), meshes->elements()[2].toTensor()});
torch::Tensor p = torch::stack(p_vec, 0).to(torch::kFloat);
torch::Tensor inds = torch::nonzero(dP[0].abs() > 1e-3).to(torch::kLong);
auto z = inds.select(1, 0);
Expand Down Expand Up @@ -85,9 +72,6 @@ std::vector<torch::Tensor> expand_pt_unique(std::vector<torch::Tensor> pix, int6
torch::Tensor loc = torch::zeros_like(pix[0]).to(torch::kBool);
torch::Tensor sorted_idx = std::get<1>(torch::sort(unique_idx, true, -1, false));
loc.index_put_({sorted_idx.index({torch::cat({torch::tensor({0}), unique_idx.bincount().cumsum(0).slice(0, 0, -1)})})}, true);
// for (int64_t ui=0; ui<unique_eloc.size(0); ui++){
// loc[torch::where(unique_idx==ui)[0][0].item<int64_t>()] = true;
// }
return std::vector<torch::Tensor>({pix[0].index({loc}), pix[1].index({loc}), pix[2].index({loc})});
}

Expand All @@ -104,7 +88,6 @@ std::vector<torch::Tensor> flow_3DtoNIS(
std::vector<int64_t> rpads({zpad, rpad, rpad});
int64_t iter_num = 3; // Needs to be odd
int64_t dims = 3;
// std::vector<torch::Tensor> pflows(dims);
std::vector<torch::Tensor> edges(dims);
std::vector<torch::Tensor> seeds;
int64_t Lz = p.size(1);
Expand All @@ -118,25 +101,13 @@ std::vector<torch::Tensor> flow_3DtoNIS(
p.index_put_({i, torch::logical_not(iscell)}, inds[i].index({torch::logical_not(iscell)}).to(torch::kFloat));
}
for (int64_t i = 0; i < dims; ++i) {
// pflows[i] = p[i].flatten().clone().to(torch::kLong);
edges[i] = torch::arange(-0.5 - rpads[i], shape0[i] + 0.5 + rpads[i]);
shape.push_back(edges[i].numel()-1);
}
// print_size(p);
p = p.to(torch::kLong);
// p = torch::stack({p[0].flatten(), p[1].flatten(), p[2].flatten()}, 0);
p = torch::stack({p[0].view(-1), p[1].view(-1), p[2].view(-1)}, 0);
// print_size(p);
// print_with_time("3\n");
torch::Tensor fg = torch::zeros(shape, torch::kBool);
// print_with_time("3.1\n");
// std::cout<<p[0].max() + rpads[0]<< ",";
// std::cout<<p[1].max() + rpads[1]<< ",";
// std::cout<<p[2].max() + rpads[2]<< "\n";
// print_size(fg);
// std::cout << "| " << fg.max() << "\n";
fg.index_put_({p[0] + rpads[0], p[1] + rpads[1], p[2] + rpads[2]}, torch::ones(1, torch::kBool));
// std::cout << "| " << fg.max() << "\n";

torch::Tensor h = histogramdd(p.transpose(0, 1).to(torch::kDouble).detach().clone(), edges);
torch::Tensor pix = flow_3DtoSeed(std::vector<torch::jit::IValue>({h})).toTensor();
Expand All @@ -146,7 +117,6 @@ std::vector<torch::Tensor> flow_3DtoNIS(


torch::Tensor expand = torch::nonzero(torch::ones({3, 3, 3})).transpose(0, 1);
// print_size(expand);
std::vector<std::vector<torch::Tensor>> pix_copy(pix.size(0));

for (int64_t iter = 0; iter < iter_num; ++iter) {
Expand Down Expand Up @@ -199,36 +169,30 @@ std::vector<torch::Tensor> flow_3DtoNIS(
float big = fLz * fLy * fLx * 0.001;
print_with_time("Index masks, ");
std::cout << "Ultra big mask threshold: " << big << ". Start ";
// int64_t ilabel = 0;
for (int64_t k = 0; k < pix_copy.size(); ++k) {
if (pix_copy[k][0].size(0)==0) {
// std::cout<<" [ALL BG] | ";
remove_c += 1;
continue;
}
torch::Tensor is_fg = fg.index({pix_copy[k][0], pix_copy[k][1], pix_copy[k][2]});
torch::Tensor is_bg = torch::logical_not(is_fg);
if (is_bg.all().item<bool>()) {
// std::cout<<" [ALL BG] | ";
remove_c += 1;
continue;
}
std::vector<torch::Tensor> coord({pix_copy[k][0].index({is_fg}), pix_copy[k][1].index({is_fg}), pix_copy[k][2].index({is_fg})});
if (coord[0].size(0) > big) {
// std::cout << " [BIG one] " <<coord[0].size(0)<<" | ";
remove_c += 1;
continue;
}
ilabel += 1;
// print_size(coord[0]);
vols.push_back(coord[0].size(0));
std::vector<torch::Tensor> center({
(coord[0].max() + coord[0].min()) / 2,
(coord[1].max() + coord[1].min()) / 2,
(coord[2].max() + coord[2].min()) / 2
});
M.index_put_({coord[0], coord[1], coord[2]}, torch::tensor(ilabel, torch::kLong));
// coords.push_back(torch::cat({torch::stack(coord, 1), torch::ones(coord[0].size(0), 1, coord.options()) * (1 + k - remove_c)}, 1));
coords.push_back(torch::stack(coord, 1));
labels.push_back(ilabel);
centers.push_back(torch::stack(center, 0));
Expand All @@ -237,18 +201,15 @@ std::vector<torch::Tensor> flow_3DtoNIS(
}
}
std::cout << ", Done, removed " << remove_c << " ultra big or small masks, " << pix_copy.size() << " remain" << "\n";

// torch::Tensor M0 = M.index({pflows[0], pflows[1], pflows[2]});
// torch::Tensor coord_tensor = torch::cat(coords, 0);
// torch::Tensor label_tensor = torch::tensor(labels, torch::kLong);
// torch::Tensor center_tensor = torch::stack(centers);
// torch::Tensor vol_tensor = torch::tensor(vols, torch::kLong);

return {
M.index({p[0]+rpads[0], p[1]+rpads[1], p[2]+rpads[2]}).view(shape0),
torch::cat(coords, 0),
torch::tensor(labels, torch::kLong),
torch::tensor(vols, torch::kLong),
torch::stack(centers)
};
if (coords.size() == 0){
return {torch::zeros(0)};
} else {
return {
M.index({p[0]+rpads[0], p[1]+rpads[1], p[2]+rpads[2]}).view(shape0),
torch::cat(coords, 0),
torch::tensor(labels, torch::kLong),
torch::tensor(vols, torch::kLong),
torch::stack(centers)
};
}
}
82 changes: 28 additions & 54 deletions cpp/test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,13 @@ std::vector<torch::Tensor> nis_obtain(torch::jit::script::Module flow_3DtoSeed,
(flow3d.index({3, "..."})>cellprob_threshold),
ilabel, 20
);

save_tensor(nis_outputs[0], savefn+"_seg.zip");
save_tensor(nis_outputs[1], savefn+"_contour.zip");
save_tensor(nis_outputs[2], savefn+"_instance_label.zip");
save_tensor(nis_outputs[3], savefn+"_instance_volume.zip");
save_tensor(nis_outputs[4], savefn+"_instance_center.zip");
if (nis_outputs.size()>1) {
save_tensor(nis_outputs[0], savefn+"_seg.zip");
save_tensor(nis_outputs[1], savefn+"_contour.zip");
save_tensor(nis_outputs[2], savefn+"_instance_label.zip");
save_tensor(nis_outputs[3], savefn+"_instance_volume.zip");
save_tensor(nis_outputs[4], savefn+"_instance_center.zip");
}
return nis_outputs;
}

Expand Down Expand Up @@ -236,19 +237,12 @@ int main(int argc, const char* argv[]) {
torch::Tensor last_flow = gpu_outputs[6];
// torch::Tensor masks;
std::vector<torch::Tensor> last_first_masks;
bool pre_has_nis;
for (int64_t i = chunk_depth; i < img_fns.size(); i+=chunk_depth) {
/*
Follow the 3D flow to obtain NIS (CPU)
*/
// torch::Tensor dP = flow3d.index({torch::indexing::Slice(torch::indexing::None, 3), "..."});
// torch::Tensor cellprob = flow3d.index({3, "..."});
// torch::Tensor cp_mask = cellprob > cellprob_threshold;
// if (cp_mask.any().item<bool>()) {
nis_obtainer = std::async(std::launch::async, nis_obtain, flow_3DtoSeed, gpu_outputs[0], cellprob_threshold, old_instance_n, h5fn+"_zmin"+std::to_string(zmin));
// } else {
// print_with_time("No instance, probability map is all zero, continue");
// // continue;
// }
gpu_outputs.clear();
gpu_outputs = gpu_process(
i,
Expand All @@ -263,37 +257,35 @@ int main(int argc, const char* argv[]) {
pre_last_second,
device
);
if (i > chunk_depth){
if (i > chunk_depth & pre_has_nis){
pre_last_mask = last_first_masks[0];
}
// if (cp_mask.any().item<bool>()) {
nis_outputs = nis_obtainer.get();
/*
Save NIS results (IO)
Get last and first slice of output
*/
// Save the chunk to H5 database
// print_with_time("Save NIS results to H5 database\n");

// if (i > chunk_depth){
// masks = nis_saver.get();
// }
hsize_t zmax = zmin + nis_outputs[0].size(0);
print_with_time("zmin: ");
std::cout<<zmin<<", zmax: "<<zmax<<"\n";
print_with_time("whole_brain_shape: ");
std::cout<<whole_brain_shape<<"\n";
// nis_saver = std::async(std::launch::async, save_h5data, dsetlist, nis_outputs, old_instance_n, old_contour_n, zmin, zmax, whole_brain_shape);
last_first_masks = save_h5data(h5fn, nis_outputs, old_instance_n, old_contour_n, zmin, zmax, whole_brain_shape);
first_mask = last_first_masks[1];
zmin += nis_outputs[0].size(0);
old_instance_n += nis_outputs[2].size(0);
old_contour_n += nis_outputs[1].size(0);
if (nis_outputs.size() > 1){
hsize_t zmax = zmin + nis_outputs[0].size(0);
print_with_time("zmin: ");
std::cout<<zmin<<", zmax: "<<zmax<<"\n";
print_with_time("whole_brain_shape: ");
std::cout<<whole_brain_shape<<"\n";
// nis_saver = std::async(std::launch::async, save_h5data, dsetlist, nis_outputs, old_instance_n, old_contour_n, zmin, zmax, whole_brain_shape);
last_first_masks = save_h5data(h5fn, nis_outputs, old_instance_n, old_contour_n, zmin, zmax, whole_brain_shape);
first_mask = last_first_masks[1];
zmin += nis_outputs[0].size(0);
old_instance_n += nis_outputs[2].size(0);
old_contour_n += nis_outputs[1].size(0);
pre_has_nis = true;
} else {
pre_has_nis = false;
}
nis_outputs.clear();
// }
/*
Run GNN to stitch the gap (GPU)
*/
if (i > chunk_depth){
if (i > chunk_depth & pre_has_nis){
remap_all = stitch_process(
remap_all,
&gnn_message_passing,
Expand All @@ -319,26 +311,8 @@ int main(int argc, const char* argv[]) {
first_flow = gpu_outputs[5];
last_flow = gpu_outputs[6];
}
/*
Follow the 3D flow to obtain NIS (CPU)
*/
// torch::Tensor dP = flow3d.index({torch::indexing::Slice(torch::indexing::None, 3), "..."});
// torch::Tensor cellprob = flow3d.index({3, "..."});
// torch::Tensor cp_mask = cellprob > cellprob_threshold;
// if (cp_mask.any().item<bool>()) {
// nis_obtainer = std::async(std::launch::async, nis_obtain, flow_3DtoSeed, dP, cp_mask);

nis_obtainer = std::async(std::launch::async, nis_obtain, flow_3DtoSeed, gpu_outputs[0], cellprob_threshold, old_instance_n, h5fn+"_zmin"+std::to_string(zmin));
nis_outputs = nis_obtainer.get();
/*
Save NIS results (IO)
*/
// Save the chunk to H5 database
// print_with_time("Save NIS results to H5 database\n");
hsize_t zmax = zmin + nis_outputs[0].size(0);
last_first_masks = save_h5data(h5fn, nis_outputs, old_instance_n, old_contour_n, zmin, zmax, whole_brain_shape);
// } else {
// print_with_time("No instance, probability map is all zero, continue");
// // continue;
// }
std::cout << "ok\n";
}
Loading

0 comments on commit 48eaecd

Please sign in to comment.