-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathshow_data.py
116 lines (91 loc) · 4.56 KB
/
show_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import argparse
import os
import pickle
from typing import Tuple
import dask.array as da
import neuroglancer
import numpy as np
from dask.array import clip
from dask_image.ndfilters import gaussian
from networkx import connected_components, convert_node_labels_to_integers, subgraph
from neuroglancer import CoordinateSpace, LocalVolume, Viewer, SegmentationLayer
class SkeletonSource(neuroglancer.skeleton.SkeletonSource):
def __init__(self, dimensions, skels):
super().__init__(dimensions)
self.skels = skels
self.id_to_comp = {
skels.nodes[next(iter(c))]["id"]: c for c in connected_components(skels)
}
def get_skeleton(self, i):
comp = self.id_to_comp.get(i)
if not comp:
return None
comp = convert_node_labels_to_integers(subgraph(self.skels, comp))
vertices = [comp.nodes[v]["index_position"] for v in range(len(comp))]
return neuroglancer.skeleton.Skeleton(vertex_positions=vertices, edges=list(comp.edges))
# Coordinate spaces
COORDS = {
"standard": CoordinateSpace(names=['x', 'y', 'z'], units=['nm', 'nm', 'nm'], scales=[9, 9, 20]),
"standard_c": CoordinateSpace(names=["x", "y", "z", "c^"], units=["nm", "nm", "nm", ""], scales=[9, 9, 20, 1]),
"liconn": CoordinateSpace(names=['x', 'y', 'z'], units=['nm', 'nm', 'nm'], scales=[9, 9, 12]),
"liconn_c": CoordinateSpace(names=["x", "y", "z", "c^"], units=["nm", "nm", "nm", ""], scales=[9, 9, 12, 1])
}
def load_data(base_path: str) -> Tuple[da.Array, da.Array, dict]:
"""Load image, segmentation, and skeleton data."""
seg = da.from_zarr(os.path.join(base_path, "data.zarr/seg")).astype(np.uint32)
img = da.from_zarr(os.path.join(base_path, "data.zarr/img"))
with open(os.path.join(base_path, "skeleton.pkl"), 'rb') as f:
skels = pickle.load(f)
return img, seg, skels
def add_image_layer(s, name: str, img: da.Array, c_res: CoordinateSpace, multichannel=False):
"""Add an image layer to the viewer."""
layer = LocalVolume(img[:, :, :, :3] if multichannel else img, dimensions=c_res)
shader = """
void main() {
emitRGB(vec3(
toNormalized(getDataValue(0)),
toNormalized(getDataValue(1)),
toNormalized(getDataValue(2))
));
}
""" if multichannel else None
s.layers.append(name=f'img_{name}', layer=layer, shader=shader)
if multichannel:
blurred_img = (clip(gaussian(img[:, :, :, :3], sigma=(2, 2, 2, 0)), 112, 144) - 112) * 8
s.layers.append(name=f'img_blurred_{name}', layer=LocalVolume(blurred_img, dimensions=c_res), shader=shader)
def add_segmentation_layer(s, name: str, seg: da.Array, skels: dict, res: CoordinateSpace):
"""Add a segmentation layer to the viewer."""
s.layers.append(
name=f'seg_{name}',
layer=SegmentationLayer(
source=[LocalVolume(seg, dimensions=res, volume_type="segmentation"), SkeletonSource(res, skels)],
skeleton_shader='void main() { emitRGB(colormapJet(affinity)); }',
mesh_silhouette_rendering=2.0
)
)
def create_viewer(base_path: str, port: int) -> Viewer:
"""Create and configure the Neuroglancer viewer."""
neuroglancer.set_server_bind_address('localhost', port)
viewer = Viewer()
with viewer.txn() as s:
for name in sorted(os.listdir(base_path)):
print(f"Adding {name}")
data_path = os.path.join(base_path, name, "val", os.listdir(os.path.join(base_path, name, "val"))[0])
img, seg, skels = load_data(data_path)
coord_space = COORDS["liconn_c"] if name == "liconn" else COORDS["standard_c"]
add_image_layer(s, name, img, coord_space, multichannel=(name == "multichannel"))
seg_space = COORDS["liconn"] if name == "liconn" else COORDS["standard"]
add_segmentation_layer(s, name, seg, skels, seg_space)
for layer in s.layers[2:]:
layer.visible = False
print("If on a remote server, remember port forwarding. Meshes may take time to load.")
return viewer
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Neuroglancer Viewer for NISB project")
parser.add_argument("--base_path", type=str, default="/cajal/nvmescratch/projects/NISB/",
help="Base path for NISB data")
parser.add_argument("--port", type=int, default=8085, help="Port to run the viewer")
args = parser.parse_args()
viewer = create_viewer(args.base_path, args.port)
print(viewer.get_viewer_url())
input("Press Enter to quit")