Installation with pip? #5

Open
alberthli opened this issue Oct 13, 2023 · 4 comments

@alberthli

Hi, great work on this package!

I'm looking to use this package in conjunction with other dependencies in a larger project, and I have no experience with nix. Since there are multiple top-level packages, it's not clear to me how to correctly install jaxngp, and installing a single subpackage from its subdirectory fails. For example, when I try to install only jax-tcnn with pip install "git+https://github.com/blurgyy/jaxngp.git#egg=jax-tcnn&subdirectory=deps/jax-tcnn", I get the following error:

Building wheels for collected packages: jax-tcnn
  Building wheel for jax-tcnn (pyproject.toml) ... error
  error: subprocess-exited-with-error
  
  × Building wheel for jax-tcnn (pyproject.toml) did not run successfully.
  │ exit code: 1
  ╰─> [70 lines of output]
      WARNING setuptools_scm._integration.setuptools pyproject.toml does not contain a tool.setuptools_scm section
      running bdist_wheel
      running build
      running build_py
      creating build
      creating build/lib.linux-x86_64-cpython-310
      creating build/lib.linux-x86_64-cpython-310/jaxtcnn
      copying src/jaxtcnn/__init__.py -> build/lib.linux-x86_64-cpython-310/jaxtcnn
      creating build/lib.linux-x86_64-cpython-310/jaxtcnn/hashgrid_tcnn
      copying src/jaxtcnn/hashgrid_tcnn/impl.py -> build/lib.linux-x86_64-cpython-310/jaxtcnn/hashgrid_tcnn
      copying src/jaxtcnn/hashgrid_tcnn/lowering.py -> build/lib.linux-x86_64-cpython-310/jaxtcnn/hashgrid_tcnn
      copying src/jaxtcnn/hashgrid_tcnn/__init__.py -> build/lib.linux-x86_64-cpython-310/jaxtcnn/hashgrid_tcnn
      copying src/jaxtcnn/hashgrid_tcnn/abstract.py -> build/lib.linux-x86_64-cpython-310/jaxtcnn/hashgrid_tcnn
      running egg_info
      writing src/jax_tcnn.egg-info/PKG-INFO
      writing dependency_links to src/jax_tcnn.egg-info/dependency_links.txt
      writing requirements to src/jax_tcnn.egg-info/requires.txt
      writing top-level names to src/jax_tcnn.egg-info/top_level.txt
      writing manifest file 'src/jax_tcnn.egg-info/SOURCES.txt'
      running build_ext
      Traceback (most recent call last):
        File "/home/albert/mambaforge/envs/pong_new/lib/python3.10/site-packages/pip/_vendor/pyproject_hooks/_in_process/_in_process.py", line 353, in <module>
          main()
        File "/home/albert/mambaforge/envs/pong_new/lib/python3.10/site-packages/pip/_vendor/pyproject_hooks/_in_process/_in_process.py", line 335, in main
          json_out['return_val'] = hook(**hook_input['kwargs'])
        File "/home/albert/mambaforge/envs/pong_new/lib/python3.10/site-packages/pip/_vendor/pyproject_hooks/_in_process/_in_process.py", line 251, in build_wheel
          return _build_backend().build_wheel(wheel_directory, config_settings,
        File "/tmp/pip-build-env-v74q8f5x/overlay/lib/python3.10/site-packages/setuptools/build_meta.py", line 434, in build_wheel
          return self._build_with_temp_dir(
        File "/tmp/pip-build-env-v74q8f5x/overlay/lib/python3.10/site-packages/setuptools/build_meta.py", line 419, in _build_with_temp_dir
          self.run_setup()
        File "/tmp/pip-build-env-v74q8f5x/overlay/lib/python3.10/site-packages/setuptools/build_meta.py", line 341, in run_setup
          exec(code, locals())
        File "<string>", line 86, in <module>
        File "/tmp/pip-build-env-v74q8f5x/overlay/lib/python3.10/site-packages/setuptools/__init__.py", line 103, in setup
          return distutils.core.setup(**attrs)
        File "/tmp/pip-build-env-v74q8f5x/overlay/lib/python3.10/site-packages/setuptools/_distutils/core.py", line 185, in setup
          return run_commands(dist)
        File "/tmp/pip-build-env-v74q8f5x/overlay/lib/python3.10/site-packages/setuptools/_distutils/core.py", line 201, in run_commands
          dist.run_commands()
        File "/tmp/pip-build-env-v74q8f5x/overlay/lib/python3.10/site-packages/setuptools/_distutils/dist.py", line 969, in run_commands
          self.run_command(cmd)
        File "/tmp/pip-build-env-v74q8f5x/overlay/lib/python3.10/site-packages/setuptools/dist.py", line 989, in run_command
          super().run_command(command)
        File "/tmp/pip-build-env-v74q8f5x/overlay/lib/python3.10/site-packages/setuptools/_distutils/dist.py", line 988, in run_command
          cmd_obj.run()
        File "/tmp/pip-build-env-v74q8f5x/overlay/lib/python3.10/site-packages/wheel/bdist_wheel.py", line 364, in run
          self.run_command("build")
        File "/tmp/pip-build-env-v74q8f5x/overlay/lib/python3.10/site-packages/setuptools/_distutils/cmd.py", line 318, in run_command
          self.distribution.run_command(command)
        File "/tmp/pip-build-env-v74q8f5x/overlay/lib/python3.10/site-packages/setuptools/dist.py", line 989, in run_command
          super().run_command(command)
        File "/tmp/pip-build-env-v74q8f5x/overlay/lib/python3.10/site-packages/setuptools/_distutils/dist.py", line 988, in run_command
          cmd_obj.run()
        File "/tmp/pip-build-env-v74q8f5x/overlay/lib/python3.10/site-packages/setuptools/_distutils/command/build.py", line 131, in run
          self.run_command(cmd_name)
        File "/tmp/pip-build-env-v74q8f5x/overlay/lib/python3.10/site-packages/setuptools/_distutils/cmd.py", line 318, in run_command
          self.distribution.run_command(command)
        File "/tmp/pip-build-env-v74q8f5x/overlay/lib/python3.10/site-packages/setuptools/dist.py", line 989, in run_command
          super().run_command(command)
        File "/tmp/pip-build-env-v74q8f5x/overlay/lib/python3.10/site-packages/setuptools/_distutils/dist.py", line 988, in run_command
          cmd_obj.run()
        File "/tmp/pip-build-env-v74q8f5x/overlay/lib/python3.10/site-packages/setuptools/command/build_ext.py", line 88, in run
          _build_ext.run(self)
        File "/tmp/pip-build-env-v74q8f5x/overlay/lib/python3.10/site-packages/setuptools/_distutils/command/build_ext.py", line 345, in run
          self.build_extensions()
        File "<string>", line 54, in build_extensions
        File "/home/albert/mambaforge/envs/pong_new/lib/python3.10/os.py", line 680, in __getitem__
          raise KeyError(key) from None
      KeyError: 'cmakeFlags'
      [end of output]
  
  note: This error originates from a subprocess, and is likely not a problem with pip.
  ERROR: Failed building wheel for jax-tcnn
Failed to build jax-tcnn
ERROR: Could not build wheels for jax-tcnn, which is required to install pyproject.toml-based projects

Some guidance on installation would be helpful! If it helps, I only really need a jax version of the TCNN hash encoder.

@blurgyy (Owner) commented Oct 17, 2023

Hi @alberthli, I think several steps are needed to install this package:

  1. First of all, obtain a local copy of this repository.
  2. The error about the environment variable can be avoided by removing the + os.environ["cmakeFlags"].split() part here (see the sketch after this list).
  3. To build the binding, you also have to copy the serde-helper.h file (here) to the directory deps/jax-tcnn/lib and include it with relative paths in deps/jax-tcnn/lib/ffi.cc and deps/jax-tcnn/lib/impl/hashgrid.cu.
  4. Build tiny-cuda-nn's library first (follow tiny-cuda-nn's instructions), and probably modify deps/jax-tcnn/CMakeLists.txt so that CMake can find it.
  5. tiny-cuda-nn's include/ directory must be on the include search path of your build environment (one way to do this is to pass -I/path/to/tiny-cuda-nn/include to the compiler).
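A minimal sketch of the step-2 edit, assuming the lookup lives in deps/jax-tcnn's setup.py and feeds a cmake_args list (both names are assumptions, not from the repo):

# Outside a Nix shell, os.environ["cmakeFlags"] raises KeyError: 'cmakeFlags'.
# Either delete that line entirely, or make the lookup defensive:
import os

cmake_args = []
cmake_args += os.environ.get("cmakeFlags", "").split()  # empty list outside Nix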

@blurgyy (Owner) commented Oct 17, 2023

Here's a modified CMakeLists.txt for pip installation:

cmake_minimum_required(VERSION 3.23)
project(volume_rendering_jax LANGUAGES CXX CUDA)
# use `cmake "-DCMAKE_CUDA_ARCHITECTURES=61;62;75"` to build for compute capabilities 61, 62, and 75
# (the quotes keep the shell from splitting the list at the semicolons)
# set(CMAKE_CUDA_ARCHITECTURES "all")
message(STATUS "Enabled CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}")

message(STATUS "Using CMake version " ${CMAKE_VERSION})

set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --extended-lambda")

find_package(Python COMPONENTS Interpreter Development REQUIRED)
find_package(pybind11 CONFIG REQUIRED)
find_package(fmt REQUIRED)

include_directories(${CMAKE_CURRENT_LIST_DIR}/lib)

include_directories(${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
pybind11_add_module(
    tcnnutils
    ${CMAKE_CURRENT_LIST_DIR}/lib/impl/hashgrid.cu
    ${CMAKE_CURRENT_LIST_DIR}/lib/ffi.cc
)

# e.g. `cmake -DTCNN_MIN_GPU_ARCH=61`
message(STATUS "TCNN_MIN_GPU_ARCH=35")
target_compile_definitions(tcnnutils PUBLIC -DTCNN_MIN_GPU_ARCH=35)

target_link_libraries(tcnnutils PRIVATE tiny-cuda-nn fmt::fmt)

install(TARGETS tcnnutils DESTINATION jaxtcnn)

To add all required headers, you can first clone tiny-cuda-nn (with all its submodules) and check out the v1.6 tag, then go to deps/jax-tcnn/lib and symlink the directories:

$ git clone https://github.com/nvlabs/tiny-cuda-nn.git --recursive
$ cd tiny-cuda-nn
$ git checkout v1.6
$ cd ..
$ git clone https://github.com/blurgyy/jaxngp.git
$ cd jaxngp/deps/jax-tcnn/lib
$ ln -s /path/to/tiny-cuda-nn/include/tiny-cuda-nn /path/to/tiny-cuda-nn/dependencies/* .

You should then build tiny-cuda-nn to obtain the static library libtiny-cuda-nn.a, put it somewhere your linker can find it, and run pip install -v /path/to/jaxngp/deps/jax-tcnn.
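As a quick smoke test after installation, you can check that the compiled binding imports. The tcnnutils module name and its install destination (the jaxtcnn package) follow the CMakeLists.txt above; treat them as assumptions and adjust if your build differs:

# Verify the pybind11 extension was built and installed into the package.
import jaxtcnn
from jaxtcnn import tcnnutils

print(jaxtcnn.__file__)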

@alberthli (Author) commented Jan 30, 2024

@blurgyy Thanks for the edits/instructions - I finally got back around to looking at this and successfully built jax-tcnn.

I have a follow-up question: my goal is to train a nerfacto model from nerfstudio using the tcnn fully fused MLP (which calls the torch bindings) and then load those weights into a jax model that only needs to do inference. The reason is that I want the NeRF density field to be usable from a large jax codebase while keeping NeRF training time low, and nerfstudio is deeply entrenched in the NeRF side of our experimental pipeline.

However, when I initialize jaxtcnn's TCNNHashGridEncoder and CoordinateBasedMLP modules, I find that the total number of parameters is exactly 24 fewer than tcnn.NetworkWithInputEncoding's (12196216 + 3072 = 12199288 parameters vs. 12199312). I have also confirmed that the entire difference is in the hash grid, not the MLP, by forcing nerfstudio to train a NeRF with the encoder and MLP separated and recounting the parameters. I have double-checked that the parameters I use to initialize the encoders are the same across the jax and torch versions.

Do you have any idea where this parameter discrepancy comes from in the hash grid implementation? Could this be resolved just by building the most recent version of tiny-cuda-nn instead of v1.6 or would there be modifications needed on the jax-tcnn side? I would be happy to provide more information and relevant files if needed.

EDIT: In my fork of this repo, I've made some minor modifications so that jax-tcnn is compatible with the latest versions of jax and jaxlib as well as the latest version of tiny-cuda-nn, in case versioning differences were causing the discrepancy. I re-built tcnn and jax-tcnn with these changes and the same parameter discrepancy exists. See the diff between your repo and my fork here.

To initialize the encoder, I'm using the parameters

L=16
F=2
T=2**19
N_min=16
N_max=2048
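For reference, a rough sketch for comparing expected per-level parameter counts. The resolution formula here is an approximation, and the round-up-to-a-multiple-of-8 step reflects tiny-cuda-nn's per-level alignment as I understand it; both are assumptions worth checking against the two implementations, since per-level padding like this is exactly the kind of thing that produces small count differences:

import math

L, F, T, N_min, N_max = 16, 2, 2**19, 16, 2048
b = math.exp((math.log(N_max) - math.log(N_min)) / (L - 1))  # per-level scale factor

naive = aligned = 0
for level in range(L):
    res = math.ceil(N_min * b**level)      # approximate grid resolution at this level
    entries = min(res**3, T)               # dense below the table size, hashed above it
    naive += F * entries
    aligned += F * (-(-entries // 8) * 8)  # round entries up to a multiple of 8

print(naive, aligned, aligned - naive)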

@blurgyy (Owner) commented May 22, 2024

Hi @alberthli,

I recently encountered a use case where I needed to align the parameter layout of the JAX HashGridEncoder (not TCNNHashGridEncoder) with tiny-cuda-nn's, and I just pushed an update to address this (commit 04bcea2).

I used the following parameters to initialize and train a HashGridEncoder on the JAX side, then used tcnn's pytorch bindings to load it; the parameter count and the per-level hashgrid outputs were checked to match (with absolute errors no larger than 1e-3).

  • initializing from jaxngp (this repo):

    from models.encoders import HashGridEncoder
    
    jax_hg = HashGridEncoder(L=16, T=2**19, F=2, N_min=32, N_max=2048, tv_scale=0.)
  • initializing using tcnn's pytorch bindings:

    import math
    
    import tinycudann as tcnn
    import torch
    
    L = 16
    F = 2
    N_min = 32
    N_max = 2048
    encoding_config = {
        "otype": "HashGrid",
        "n_levels": L,
        "n_features_per_level": F,
        "log2_hashmap_size": 19,
        "base_resolution": N_min,
        "per_level_scale": math.exp((math.log(N_max) - math.log(N_min)) / (L - 1)),
        "interpolation": "Linear",
    }
    
    tcnn_hg = tcnn.Encoding(n_input_dims=3, encoding_config=encoding_config, dtype=torch.float32)
  • loading parameters from jaxngp's hashgrid into tcnn's:

    state_dict = {
        "params": torch.as_tensor(jax_hg_params_dict["latent codes stored on grid vertices"].ravel()).to("cuda"),
    }
    tcnn_hg.load_state_dict(state_dict)

The jax_hg_params_dict holds the hashgrid encoder's parameters; it can be saved to disk with numpy.save and loaded with numpy.load(path_to_the_npy_file, allow_pickle=True).item() to recover the dict object. Note that the parameter key is literally latent codes stored on grid vertices, see here.
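For concreteness, a round-trip sketch of that save/load step (the file name is arbitrary; jax_hg_params_dict is the dict described above):

import numpy as np

# numpy.save wraps the dict in a 0-d object array, so loading it back
# requires allow_pickle=True before .item() recovers the dict.
np.save("hashgrid_params.npy", jax_hg_params_dict)
loaded = np.load("hashgrid_params.npy", allow_pickle=True).item()
codes = loaded["latent codes stored on grid vertices"]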

I hope this helps.
