Skip to content

Commit

Permalink
refactor: change the extension to Lux to support recursively_nillify
Browse files Browse the repository at this point in the history
  • Loading branch information
SebastianM-C committed Sep 7, 2024
1 parent 230d165 commit 22b5c9d
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 14 deletions.
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,15 @@ TermInterface = "8ea1fca8-c5ef-4a55-8b96-4e9afe9c9a3c"
[weakdeps]
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Groebner = "0b43b601-686d-58a3-8a1c-6623616c7cd4"
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
Nemo = "2edaba10-b0f1-5616-af89-8c11ac63239a"
PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46"
SymPy = "24249f21-da20-56a4-8eb1-6a02cf4ae2e6"

[extensions]
SymbolicsForwardDiffExt = "ForwardDiff"
SymbolicsGroebnerExt = "Groebner"
SymbolicsLuxCoreExt = "LuxCore"
SymbolicsLuxExt = "Lux"
SymbolicsNemoExt = "Nemo"
SymbolicsPreallocationToolsExt = ["PreallocationTools", "ForwardDiff"]
SymbolicsSymPyExt = "SymPy"
Expand All @@ -76,7 +76,7 @@ LaTeXStrings = "1.3"
LambertW = "0.4.5"
Latexify = "0.16"
LogExpFunctions = "0.3"
LuxCore = "1"
Lux = "1"
MacroTools = "0.5"
NaNMath = "1"
Nemo = "0.45, 0.46"
Expand Down
11 changes: 0 additions & 11 deletions ext/SymbolicsLuxCoreExt.jl

This file was deleted.

18 changes: 18 additions & 0 deletions ext/SymbolicsLuxExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
module SymbolicsLuxExt

using Lux
using Symbolics
using Lux.LuxCore
using Symbolics.SymbolicUtils

function Lux.NilSizePropagation.recursively_nillify(x::SymbolicUtils.BasicSymbolic{<:Vector{<:Real}})
Lux.NilSizePropagation.recursively_nillify(Symbolics.wrap(x))
end

@register_array_symbolic LuxCore.stateless_apply(
model::LuxCore.AbstractLuxLayer, x::AbstractArray, ps::Union{NamedTuple, <:AbstractVector}) begin
size = LuxCore.outputsize(model, x, LuxCore.Random.default_rng())
eltype = Real
end

end

0 comments on commit 22b5c9d

Please sign in to comment.