Skip to content

Commit

Permalink
Don't use ccalls where wrappers already exist (arb mat edition) (#1918)
Browse files Browse the repository at this point in the history
* Don't ccall `acb_mat_solve`
* Don't ccall `acb_mat_solve_lu_precomp`
* Don't ccall `acb_mat_lu`
* Don't ccall `arb_mat_solve`
* Don't ccall `arb_mat_solve_lu_precomp`
* Don't ccall `arb_mat_lu`
  • Loading branch information
lgoettgens authored Oct 24, 2024
1 parent 285bae8 commit 8cade50
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 92 deletions.
33 changes: 10 additions & 23 deletions src/arb/ComplexMat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -525,17 +525,21 @@ end
#
###############################################################################

function lu!(P::Perm, x::ComplexMatrix)
function lu!(P::Perm, z::ComplexMatrix, x::ComplexMatrix)
P.d .-= 1
r = ccall((:acb_mat_lu, libflint), Cint,
(Ptr{Int}, Ref{ComplexMatrix}, Ref{ComplexMatrix}, Int),
P.d, x, x, precision(Balls))
P.d, z, x, precision(Balls))
r == 0 && error("Could not find $(nrows(x)) invertible pivot elements")
P.d .+= 1
inv!(P)
return min(nrows(x), ncols(x))
end

function lu!(P::Perm, x::ComplexMatrix)
return lu!(P, x, x)
end

function _solve!(z::ComplexMatrix, x::ComplexMatrix, y::ComplexMatrix)
r = ccall((:acb_mat_solve, libflint), Cint,
(Ref{ComplexMatrix}, Ref{ComplexMatrix}, Ref{ComplexMatrix}, Int),
Expand Down Expand Up @@ -570,10 +574,7 @@ function Solve._can_solve_internal_no_check(::Solve.LUTrait, A::ComplexMatrix, b
end

x = similar(A, ncols(A), ncols(b))
fl = ccall((:acb_mat_solve, libflint), Cint,
(Ref{ComplexMatrix}, Ref{ComplexMatrix}, Ref{ComplexMatrix}, Int),
x, A, b, precision(Balls))
fl == 0 && error("Matrix cannot be inverted numerically")
_solve!(x, A, b)
if task === :only_check || task === :with_solution
return true, x, zero(A, 0, 0)
end
Expand All @@ -598,13 +599,7 @@ function Solve._init_reduce(C::Solve.SolveCtx{ComplexFieldElem, Solve.LUTrait})
A = matrix(C)
P = Perm(nrows(C))
x = similar(A, nrows(A), ncols(A))
P.d .-= 1
fl = ccall((:acb_mat_lu, libflint), Cint,
(Ptr{Int}, Ref{ComplexMatrix}, Ref{ComplexMatrix}, Int),
P.d, x, A, precision(Balls))
fl == 0 && error("Could not find $(nrows(x)) invertible pivot elements")
P.d .+= 1
inv!(P)
lu!(P, x, A)

C.red = x
C.lu_perm = P
Expand All @@ -621,13 +616,7 @@ function Solve._init_reduce_transpose(C::Solve.SolveCtx{ComplexFieldElem, Solve.
A = transpose(matrix(C))
P = Perm(nrows(C))
x = similar(A, nrows(A), ncols(A))
P.d .-= 1
fl = ccall((:acb_mat_lu, libflint), Cint,
(Ptr{Int}, Ref{ComplexMatrix}, Ref{ComplexMatrix}, Int),
P.d, x, A, precision(Balls))
fl == 0 && error("Could not find $(nrows(x)) invertible pivot elements")
P.d .+= 1
inv!(P)
lu!(P, x, A)

C.red_transp = x
C.lu_perm_transp = P
Expand All @@ -645,9 +634,7 @@ function Solve._can_solve_internal_no_check(::Solve.LUTrait, C::Solve.SolveCtx{C
end

x = similar(b, ncols(C), ncols(b))
ccall((:acb_mat_solve_lu_precomp, libflint), Nothing,
(Ref{ComplexMatrix}, Ptr{Int}, Ref{ComplexMatrix}, Ref{ComplexMatrix}, Int),
x, inv(p).d .- 1, LU, b, precision(Balls))
_solve_lu_precomp!(x, p, LU, b)

if side === :left
x = transpose(x)
Expand Down
33 changes: 10 additions & 23 deletions src/arb/RealMat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -467,18 +467,22 @@ end
#
###############################################################################

function lu!(P::Perm, x::RealMatrix)
function lu!(P::Perm, z::RealMatrix, x::RealMatrix)
parent(P).n != nrows(x) && error("Permutation does not match matrix")
P.d .-= 1
r = ccall((:arb_mat_lu, libflint), Cint,
(Ptr{Int}, Ref{RealMatrix}, Ref{RealMatrix}, Int),
P.d, x, x, precision(Balls))
P.d, z, x, precision(Balls))
r == 0 && error("Could not find $(nrows(x)) invertible pivot elements")
P.d .+= 1
inv!(P)
return min(nrows(x), ncols(x))
end

function lu!(P::Perm, x::RealMatrix)
return lu!(P, x, x)
end

function _solve!(z::RealMatrix, x::RealMatrix, y::RealMatrix)
r = ccall((:arb_mat_solve, libflint), Cint,
(Ref{RealMatrix}, Ref{RealMatrix}, Ref{RealMatrix}, Int),
Expand Down Expand Up @@ -513,10 +517,7 @@ function Solve._can_solve_internal_no_check(::Solve.LUTrait, A::RealMatrix, b::R
end

x = similar(A, ncols(A), ncols(b))
fl = ccall((:arb_mat_solve, libflint), Cint,
(Ref{RealMatrix}, Ref{RealMatrix}, Ref{RealMatrix}, Int),
x, A, b, precision(Balls))
fl == 0 && error("Matrix cannot be inverted numerically")
_solve!(x, A, b)
if task === :only_check || task === :with_solution
return true, x, zero(A, 0, 0)
end
Expand All @@ -540,13 +541,7 @@ function Solve._init_reduce(C::Solve.SolveCtx{RealFieldElem, Solve.LUTrait})
A = matrix(C)
P = Perm(nrows(C))
x = similar(A, nrows(A), ncols(A))
P.d .-= 1
fl = ccall((:arb_mat_lu, libflint), Cint,
(Ptr{Int}, Ref{RealMatrix}, Ref{RealMatrix}, Int),
P.d, x, A, precision(Balls))
fl == 0 && error("Could not find $(nrows(x)) invertible pivot elements")
P.d .+= 1
inv!(P)
lu!(P, x, A)

C.red = x
C.lu_perm = P
Expand All @@ -563,13 +558,7 @@ function Solve._init_reduce_transpose(C::Solve.SolveCtx{RealFieldElem, Solve.LUT
A = transpose(matrix(C))
P = Perm(nrows(C))
x = similar(A, nrows(A), ncols(A))
P.d .-= 1
fl = ccall((:arb_mat_lu, libflint), Cint,
(Ptr{Int}, Ref{RealMatrix}, Ref{RealMatrix}, Int),
P.d, x, A, precision(Balls))
fl == 0 && error("Could not find $(nrows(x)) invertible pivot elements")
P.d .+= 1
inv!(P)
lu!(P, x, A)

C.red_transp = x
C.lu_perm_transp = P
Expand All @@ -587,9 +576,7 @@ function Solve._can_solve_internal_no_check(::Solve.LUTrait, C::Solve.SolveCtx{R
end

x = similar(b, ncols(C), ncols(b))
ccall((:arb_mat_solve_lu_precomp, libflint), Nothing,
(Ref{RealMatrix}, Ptr{Int}, Ref{RealMatrix}, Ref{RealMatrix}, Int),
x, inv(p).d .- 1, LU, b, precision(Balls))
_solve_lu_precomp!(x, p, LU, b)

if side === :left
x = transpose(x)
Expand Down
33 changes: 10 additions & 23 deletions src/arb/acb_mat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -528,17 +528,21 @@ end
#
###############################################################################

function lu!(P::Perm, x::AcbMatrix)
function lu!(P::Perm, z::AcbMatrix, x::AcbMatrix)
P.d .-= 1
r = ccall((:acb_mat_lu, libflint), Cint,
(Ptr{Int}, Ref{AcbMatrix}, Ref{AcbMatrix}, Int),
P.d, x, x, precision(base_ring(x)))
P.d, z, x, precision(base_ring(x)))
r == 0 && error("Could not find $(nrows(x)) invertible pivot elements")
P.d .+= 1
inv!(P)
return nrows(x)
end

function lu!(P::Perm, x::AcbMatrix)
return lu!(P, x, x)
end

function _solve!(z::AcbMatrix, x::AcbMatrix, y::AcbMatrix)
r = ccall((:acb_mat_solve, libflint), Cint,
(Ref{AcbMatrix}, Ref{AcbMatrix}, Ref{AcbMatrix}, Int),
Expand Down Expand Up @@ -573,10 +577,7 @@ function Solve._can_solve_internal_no_check(::Solve.LUTrait, A::AcbMatrix, b::Ac
end

x = similar(A, ncols(A), ncols(b))
fl = ccall((:acb_mat_solve, libflint), Cint,
(Ref{AcbMatrix}, Ref{AcbMatrix}, Ref{AcbMatrix}, Int),
x, A, b, precision(base_ring(A)))
fl == 0 && error("Matrix cannot be inverted numerically")
_solve!(x, A, b)
if task === :only_check || task === :with_solution
return true, x, zero(A, 0, 0)
end
Expand All @@ -600,13 +601,7 @@ function Solve._init_reduce(C::Solve.SolveCtx{AcbFieldElem, Solve.LUTrait})
A = matrix(C)
P = Perm(nrows(C))
x = similar(A, nrows(A), ncols(A))
P.d .-= 1
fl = ccall((:acb_mat_lu, libflint), Cint,
(Ptr{Int}, Ref{AcbMatrix}, Ref{AcbMatrix}, Int),
P.d, x, A, precision(base_ring(A)))
fl == 0 && error("Could not find $(nrows(x)) invertible pivot elements")
P.d .+= 1
inv!(P)
lu!(P, x, A)

C.red = x
C.lu_perm = P
Expand All @@ -623,13 +618,7 @@ function Solve._init_reduce_transpose(C::Solve.SolveCtx{AcbFieldElem, Solve.LUTr
A = transpose(matrix(C))
P = Perm(nrows(C))
x = similar(A, nrows(A), ncols(A))
P.d .-= 1
fl = ccall((:acb_mat_lu, libflint), Cint,
(Ptr{Int}, Ref{AcbMatrix}, Ref{AcbMatrix}, Int),
P.d, x, A, precision(base_ring(A)))
fl == 0 && error("Could not find $(nrows(x)) invertible pivot elements")
P.d .+= 1
inv!(P)
lu!(P, x, A)

C.red_transp = x
C.lu_perm_transp = P
Expand All @@ -647,9 +636,7 @@ function Solve._can_solve_internal_no_check(::Solve.LUTrait, C::Solve.SolveCtx{A
end

x = similar(b, ncols(C), ncols(b))
ccall((:acb_mat_solve_lu_precomp, libflint), Nothing,
(Ref{AcbMatrix}, Ptr{Int}, Ref{AcbMatrix}, Ref{AcbMatrix}, Int),
x, inv(p).d .- 1, LU, b, precision(base_ring(LU)))
_solve_lu_precomp!(x, p, LU, b)

if side === :left
x = transpose(x)
Expand Down
33 changes: 10 additions & 23 deletions src/arb/arb_mat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -480,18 +480,22 @@ function cholesky(x::ArbMatrix)
return z
end

function lu!(P::Perm, x::ArbMatrix)
function lu!(P::Perm, z::ArbMatrix, x::ArbMatrix)
parent(P).n != nrows(x) && error("Permutation does not match matrix")
P.d .-= 1
r = ccall((:arb_mat_lu, libflint), Cint,
(Ptr{Int}, Ref{ArbMatrix}, Ref{ArbMatrix}, Int),
P.d, x, x, precision(base_ring(x)))
P.d, z, x, precision(base_ring(x)))
r == 0 && error("Could not find $(nrows(x)) invertible pivot elements")
P.d .+= 1
inv!(P)
return nrows(x)
end

function lu!(P::Perm, x::ArbMatrix)
return lu!(P, x, x)
end

function _solve!(z::ArbMatrix, x::ArbMatrix, y::ArbMatrix)
r = ccall((:arb_mat_solve, libflint), Cint,
(Ref{ArbMatrix}, Ref{ArbMatrix}, Ref{ArbMatrix}, Int),
Expand Down Expand Up @@ -540,10 +544,7 @@ function Solve._can_solve_internal_no_check(::Solve.LUTrait, A::ArbMatrix, b::Ar
end

x = similar(A, ncols(A), ncols(b))
fl = ccall((:arb_mat_solve, libflint), Cint,
(Ref{ArbMatrix}, Ref{ArbMatrix}, Ref{ArbMatrix}, Int),
x, A, b, precision(base_ring(A)))
fl == 0 && error("Matrix cannot be inverted numerically")
_solve!(x, A, b)
if task === :only_check || task === :with_solution
return true, x, zero(A, 0, 0)
end
Expand All @@ -567,13 +568,7 @@ function Solve._init_reduce(C::Solve.SolveCtx{ArbFieldElem, Solve.LUTrait})
A = matrix(C)
P = Perm(nrows(C))
x = similar(A, nrows(A), ncols(A))
P.d .-= 1
fl = ccall((:arb_mat_lu, libflint), Cint,
(Ptr{Int}, Ref{ArbMatrix}, Ref{ArbMatrix}, Int),
P.d, x, A, precision(base_ring(A)))
fl == 0 && error("Could not find $(nrows(x)) invertible pivot elements")
P.d .+= 1
inv!(P)
lu!(P, x, A)

C.red = x
C.lu_perm = P
Expand All @@ -590,13 +585,7 @@ function Solve._init_reduce_transpose(C::Solve.SolveCtx{ArbFieldElem, Solve.LUTr
A = transpose(matrix(C))
P = Perm(nrows(C))
x = similar(A, nrows(A), ncols(A))
P.d .-= 1
fl = ccall((:arb_mat_lu, libflint), Cint,
(Ptr{Int}, Ref{ArbMatrix}, Ref{ArbMatrix}, Int),
P.d, x, A, precision(base_ring(A)))
fl == 0 && error("Could not find $(nrows(x)) invertible pivot elements")
P.d .+= 1
inv!(P)
lu!(P, x, A)

C.red_transp = x
C.lu_perm_transp = P
Expand All @@ -614,9 +603,7 @@ function Solve._can_solve_internal_no_check(::Solve.LUTrait, C::Solve.SolveCtx{A
end

x = similar(b, ncols(C), ncols(b))
ccall((:arb_mat_solve_lu_precomp, libflint), Nothing,
(Ref{ArbMatrix}, Ptr{Int}, Ref{ArbMatrix}, Ref{ArbMatrix}, Int),
x, inv(p).d .- 1, LU, b, precision(base_ring(LU)))
_solve_lu_precomp!(x, p, LU, b)

if side === :left
x = transpose(x)
Expand Down

0 comments on commit 8cade50

Please sign in to comment.