From bb7d3653b89cef4e5b4e6022f289e5dd1d4f99d9 Mon Sep 17 00:00:00 2001 From: Adrian Hill Date: Tue, 1 Oct 2024 23:46:40 +0200 Subject: [PATCH] Add ForwardDiff package extension (#200) --- Project.toml | 12 +++++++---- ext/SparseConnectivityTracerForwardDiffExt.jl | 14 +++++++++++++ src/SparseConnectivityTracer.jl | 13 +++++++----- src/overloads/utils.jl | 4 ++-- test/ext/test_ForwardDiff.jl | 20 +++++++++++++++++++ test/linting.jl | 1 + test/runtests.jl | 2 +- 7 files changed, 54 insertions(+), 12 deletions(-) create mode 100644 ext/SparseConnectivityTracerForwardDiffExt.jl create mode 100644 test/ext/test_ForwardDiff.jl diff --git a/Project.toml b/Project.toml index 6b952f8..24f178d 100644 --- a/Project.toml +++ b/Project.toml @@ -14,16 +14,18 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [weakdeps] DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" -NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" [extensions] SparseConnectivityTracerDataInterpolationsExt = "DataInterpolations" +SparseConnectivityTracerForwardDiffExt = "ForwardDiff" SparseConnectivityTracerLogExpFunctionsExt = "LogExpFunctions" -SparseConnectivityTracerNaNMathExt = "NaNMath" SparseConnectivityTracerNNlibExt = "NNlib" +SparseConnectivityTracerNaNMathExt = "NaNMath" SparseConnectivityTracerSpecialFunctionsExt = "SpecialFunctions" [compat] @@ -31,10 +33,11 @@ ADTypes = "1" DataInterpolations = "6.2" DocStringExtensions = "0.9" FillArrays = "1" +ForwardDiff = "0.10" LinearAlgebra = "<0.0.1, 1" LogExpFunctions = "0.3.28" -NaNMath = "1" NNlib = "0.8, 0.9" +NaNMath = "1" Random = "<0.0.1, 1" Requires = "1.3" SparseArrays = "<0.0.1, 1" @@ -43,7 +46,8 @@ julia = "1.6" [extras] DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" -NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" diff --git a/ext/SparseConnectivityTracerForwardDiffExt.jl b/ext/SparseConnectivityTracerForwardDiffExt.jl new file mode 100644 index 0000000..d30d215 --- /dev/null +++ b/ext/SparseConnectivityTracerForwardDiffExt.jl @@ -0,0 +1,14 @@ +module SparseConnectivityTracerForwardDiffExt + +if isdefined(Base, :get_extension) + import SparseConnectivityTracer as SCT + using ForwardDiff: ForwardDiff +else + import ..SparseConnectivityTracer as SCT + using ..ForwardDiff: ForwardDiff +end + +# Overload 2-to-1 functions on ForwardDiff.Dual +eval(SCT.generate_code_2_to_1_typed(:Base, SCT.ops_2_to_1, ForwardDiff.Dual)) + +end # module diff --git a/src/SparseConnectivityTracer.jl b/src/SparseConnectivityTracer.jl index 8a44b8b..3cf804e 100644 --- a/src/SparseConnectivityTracer.jl +++ b/src/SparseConnectivityTracer.jl @@ -44,11 +44,8 @@ export jacobian_sparsity, hessian_sparsity function __init__() @static if !isdefined(Base, :get_extension) - @require SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" include( - "../ext/SparseConnectivityTracerSpecialFunctionsExt.jl" - ) - @require NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" include( - "../ext/SparseConnectivityTracerNNlibExt.jl" + @require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" include( + "../ext/SparseConnectivityTracerForwardDiffExt.jl" ) @require LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" include( "../ext/SparseConnectivityTracerLogExpFunctionsExt.jl" @@ -56,6 +53,12 @@ function __init__() @require NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" include( "../ext/SparseConnectivityTracerNaNMathExt.jl" ) + @require NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" include( + "../ext/SparseConnectivityTracerNNlibExt.jl" + ) + @require SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" include( + "../ext/SparseConnectivityTracerSpecialFunctionsExt.jl" + ) # NOTE: SparseConnectivityTracerDataInterpolationsExt is not loaded on Julia <1.10 end end diff --git a/src/overloads/utils.jl b/src/overloads/utils.jl index fdb5372..585b712 100644 --- a/src/overloads/utils.jl +++ b/src/overloads/utils.jl @@ -6,7 +6,7 @@ for d in dims g = Symbol("generate_code_gradient_", d) h = Symbol("generate_code_hessian_", d) - @eval function $f(M::Symbol, f) + @eval function $f(M::Symbol, f::Function) expr_g = $g(M, f) expr_h = $h(M, f) return Expr(:block, expr_g, expr_h) @@ -28,7 +28,7 @@ for d in dims end # Overloads of 2-argument functions on arbitrary types -function generate_code_2_to_1_typed(M::Symbol, f, Z::Type) +function generate_code_2_to_1_typed(M::Symbol, f::Function, Z::Type) expr_g = generate_code_gradient_2_to_1_typed(M, f, Z) expr_h = generate_code_hessian_2_to_1_typed(M, f, Z) return Expr(:block, expr_g, expr_h) diff --git a/test/ext/test_ForwardDiff.jl b/test/ext/test_ForwardDiff.jl new file mode 100644 index 0000000..e94999b --- /dev/null +++ b/test/ext/test_ForwardDiff.jl @@ -0,0 +1,20 @@ +using SparseConnectivityTracer +using ForwardDiff: ForwardDiff + +using Test + +d = ForwardDiff.Dual{ForwardDiff.Tag{*,Float64}}(1.2, 3.4) +@testset "$D" for D in (TracerSparsityDetector, TracerLocalSparsityDetector) + detector = D() + # Testing on multiplication ensures that methods from Base have been overloaded, + # Since this would otherwise throw an ambiguity error: + # https://github.com/adrhill/SparseConnectivityTracer.jl/issues/196 + @testset "Jacobian" begin + @test jacobian_sparsity(x -> x * d, 1.0, detector) ≈ [1;;] + @test jacobian_sparsity(x -> d * x, 1.0, detector) ≈ [1;;] + end + @testset "Hessian" begin + @test hessian_sparsity(x -> x * d, 1.0, detector) ≈ [0;;] + @test hessian_sparsity(x -> d * x, 1.0, detector) ≈ [0;;] + end +end diff --git a/test/linting.jl b/test/linting.jl index 4ae6a26..8f57052 100644 --- a/test/linting.jl +++ b/test/linting.jl @@ -7,6 +7,7 @@ using JET: JET using ExplicitImports: ExplicitImports # Load package extensions so they get tested by ExplicitImports.jl +using ForwardDiff: ForwardDiff using DataInterpolations: DataInterpolations using NaNMath: NaNMath using NNlib: NNlib diff --git a/test/runtests.jl b/test/runtests.jl index c2eb073..3db4bb2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -72,7 +72,7 @@ GROUP = get(ENV, "JULIA_SCT_TEST_GROUP", "Core") if GROUP in ("Core", "All") @info "Testing package extensions..." @testset verbose = true "Package extensions" begin - for ext in (:NNlib, :SpecialFunctions, :LogExpFunctions, :NaNMath) + for ext in (:ForwardDiff, :LogExpFunctions, :NaNMath, :NNlib, :SpecialFunctions) @testset "$ext" begin @info "...$ext" include("ext/test_$ext.jl")