diff --git a/lgmcts/scripts/data_generation/gen_strdiff.py b/lgmcts/scripts/data_generation/gen_strdiff.py index 49fff96..97e8b9c 100644 --- a/lgmcts/scripts/data_generation/gen_strdiff.py +++ b/lgmcts/scripts/data_generation/gen_strdiff.py @@ -110,7 +110,8 @@ def _generate_data_for_one_task( # normalize depth to fit in structFormer depth_tensor = depth_tensor * 20.0 # depth_min & depth_max - depth_min = np.min(depth_tensor) * np.ones([2,], dtype=np.float32) + depth_min = np.min(depth_tensor) * np.ones([1,], dtype=np.float32) + depth_max = np.max(depth_tensor) * np.ones([1,], dtype=np.float32) f.create_dataset("depth_min", data=depth_min) f.create_dataset("depth_max", data=depth_max) # normalize depth