Skip to content

Commit

Permalink
Merge pull request #59 from IHPSystems/feature/update_julia_wrapper_g…
Browse files Browse the repository at this point in the history
…enerator

Updated Julia wrapper generator
  • Loading branch information
ToucheSir authored Nov 11, 2023
2 parents a09c4f6 + c87bd0d commit c6e0070
Show file tree
Hide file tree
Showing 14 changed files with 4,411 additions and 4,275 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ authors = ["Dhairya Gandhi <dhairyagandhi96@gmail.com>"]
version = "0.1.2"

[deps]
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand All @@ -14,6 +15,7 @@ Torch_jll = "c12fb04c-f5e9-5c82-b5d6-b53f8f8d9a32"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[compat]
CEnum = "0.4, 0.5"
FillArrays = "0.8, 0.11, 0.13"
Flux = "0.11"
NNlib = "0.6, 0.7.0 - 0.7.24"
Expand Down
12 changes: 12 additions & 0 deletions deps/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,15 @@ cmake --build .
```

Post this, adding the path to the project via the `LD_LIBRARY_PATH` (and also the CUDNN) binary path might be needed.

# Julia Wrapper

The Julia wrapper is generated from the C wrapper using [Clang.jl](https://github.com/JuliaInterop/Clang.jl).

## Generating

The Julia wrapper can be generated by running:
```sh
cd julia_wrapper_generator
julia --project generator.jl
```
166 changes: 166 additions & 0 deletions deps/julia_wrapper_generator/Manifest.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
# This file is machine-generated - editing it directly is not advised

julia_version = "1.9.2"
manifest_format = "2.0"
project_hash = "f8ce466db935028989b7e9632ad6b3e067667257"

[[deps.ArgTools]]
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"
version = "1.1.1"

[[deps.Artifacts]]
uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"

[[deps.Base64]]
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"

[[deps.CEnum]]
git-tree-sha1 = "eb4cb44a499229b3b8426dcfb5dd85333951ff90"
uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82"
version = "0.4.2"

[[deps.Clang]]
deps = ["CEnum", "Clang_jll", "Downloads", "Pkg", "TOML"]
git-tree-sha1 = "d78c2973d7a752be377fe173bc9ff2dc2d9c3ed6"
uuid = "40e3b903-d033-50b4-a0cc-940c62c95e31"
version = "0.17.6"

[[deps.Clang_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "TOML", "Zlib_jll", "libLLVM_jll"]
git-tree-sha1 = "124bb00d4ceace456054f17c7cb01e5c8195c609"
uuid = "0ee61d77-7f21-5576-8119-9fcc46b10100"
version = "14.0.6+4"

[[deps.Dates]]
deps = ["Printf"]
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"

[[deps.Downloads]]
deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"]
uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
version = "1.6.0"

[[deps.FileWatching]]
uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee"

[[deps.InteractiveUtils]]
deps = ["Markdown"]
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"

[[deps.JLLWrappers]]
deps = ["Preferences"]
git-tree-sha1 = "abc9885a7ca2052a736a600f7fa66209f96506e1"
uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210"
version = "1.4.1"

[[deps.LibCURL]]
deps = ["LibCURL_jll", "MozillaCACerts_jll"]
uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21"
version = "0.6.3"

[[deps.LibCURL_jll]]
deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"]
uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0"
version = "7.84.0+0"

[[deps.LibGit2]]
deps = ["Base64", "NetworkOptions", "Printf", "SHA"]
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"

[[deps.LibSSH2_jll]]
deps = ["Artifacts", "Libdl", "MbedTLS_jll"]
uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8"
version = "1.10.2+0"

[[deps.Libdl]]
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"

[[deps.Logging]]
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"

[[deps.Markdown]]
deps = ["Base64"]
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"

[[deps.MbedTLS_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1"
version = "2.28.2+0"

[[deps.MozillaCACerts_jll]]
uuid = "14a3606d-f60d-562e-9121-12d972cd8159"
version = "2022.10.11"

[[deps.NetworkOptions]]
uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"
version = "1.2.0"

[[deps.Pkg]]
deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"]
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
version = "1.9.2"

[[deps.Preferences]]
deps = ["TOML"]
git-tree-sha1 = "7eb1686b4f04b82f96ed7a4ea5890a4f0c7a09f1"
uuid = "21216c6a-2e73-6563-6e65-726566657250"
version = "1.4.0"

[[deps.Printf]]
deps = ["Unicode"]
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"

[[deps.REPL]]
deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"]
uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"

[[deps.Random]]
deps = ["SHA", "Serialization"]
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[[deps.SHA]]
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
version = "0.7.0"

[[deps.Serialization]]
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"

[[deps.Sockets]]
uuid = "6462fe0b-24de-5631-8697-dd941f90decc"

[[deps.TOML]]
deps = ["Dates"]
uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
version = "1.0.3"

[[deps.Tar]]
deps = ["ArgTools", "SHA"]
uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e"
version = "1.10.0"

[[deps.UUIDs]]
deps = ["Random", "SHA"]
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"

[[deps.Unicode]]
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"

[[deps.Zlib_jll]]
deps = ["Libdl"]
uuid = "83775a58-1f1d-513f-b197-d71354ab007a"
version = "1.2.13+0"

[[deps.libLLVM_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "8f36deef-c2a5-5394-99ed-8e07531fb29a"
version = "14.0.6+3"

[[deps.nghttp2_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d"
version = "1.48.0+0"

[[deps.p7zip_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0"
version = "17.4.0+0"
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
[deps]
Clang = "40e3b903-d033-50b4-a0cc-940c62c95e31"
YAML = "ddb6d928-2868-570f-bddf-ab3f9cf99eb6"
42 changes: 42 additions & 0 deletions deps/julia_wrapper_generator/generator.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
using Clang.Generators

cd(@__DIR__)

include_dir = normpath(@__DIR__, "..", "c_wrapper")

options = load_options(joinpath(@__DIR__, "generator.toml"))

args = get_default_args()
push!(args, "-I$include_dir")

headers = [joinpath(include_dir, "torch_api.h")]

ctx = create_context(headers, args, options)

build!(ctx, BUILDSTAGE_NO_PRINTING)

function rewrite!(e::Expr)
if e.head == :function
rewrite!(e, Val(e.head))
end
end

function rewrite!(e::Expr, ::Val{:function})
rewrite!(e.args[2], Val(e.args[2].head))
end

function rewrite!(e::Expr, ::Val{:block})
e.args[1] = Expr(:macrocall, Symbol("@runtime_error_check"), nothing, e.args[1])
end

function rewrite!(dag::ExprDAG)
for node in get_nodes(dag)
for expr in get_exprs(node)
rewrite!(expr)
end
end
end

rewrite!(ctx.dag)

build!(ctx, BUILDSTAGE_PRINTING_ONLY)
7 changes: 7 additions & 0 deletions deps/julia_wrapper_generator/generator.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[general]
library_name = "libdoeye_caml"
output_file_path = "../../src/wrapper.jl"
prologue_file_path = "./prologue.jl"
module_name = "Wrapper"
jll_pkg_name = "Torch_jll"
export_symbol_prefixes = ["at"]
6 changes: 1 addition & 5 deletions src/error.jl → deps/julia_wrapper_generator/prologue.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
function get_error()
err = cglobal((:myerr, :libdoeye_caml), Cstring) |> unsafe_load
err = cglobal((:myerr, libdoeye_caml), Cstring) |> unsafe_load
unsafe_string(err)
end

function flush_error()
ccall((:flush_error, :libdoeye_caml), Cvoid, ())
end

macro runtime_error_check(ex)
quote
x = $ex
Expand Down
6 changes: 2 additions & 4 deletions src/Torch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,9 @@ using FillArrays

TURN_ON_LOGGING = false

# include("wrap2.jl")
include("error.jl")
include("wrapper.jl")

include("wrap/libtorch_common.jl")
include("wrap/libdoeye_caml_generated.jl")
using .Wrapper

# sync + clear empty cache
const clear_cache = at_empty_cache
Expand Down
6 changes: 3 additions & 3 deletions src/tensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,14 @@ Tensor(sz::Int; dev = -1) = Tensor(Float32, Int(sz), dev = dev)
# Tensor{T,N}(ptr, on(ptr))
# end

function at_dim(t::Tensor)
function Base.ndims(t::Tensor)
i = Int32[-1]
at_dim(i, t.ptr)
i[1]
Int(i[1])
end

function Base.size(t::Tensor)
dims = at_dim(t)
dims = ndims(t)
sz = zeros(Int32, dims)
at_shape(t.ptr, pointer(sz))
# s = Int.(tuple(sz...))
Expand Down
Loading

0 comments on commit c6e0070

Please sign in to comment.