From 8cade506b3214aa8922e88b9deea37b19d3874b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lars=20G=C3=B6ttgens?= Date: Thu, 24 Oct 2024 23:38:03 +0200 Subject: [PATCH] Don't use ccalls where wrappers already exist (arb mat edition) (#1918) * Don't ccall `acb_mat_solve` * Don't ccall `acb_mat_solve_lu_precomp` * Don't ccall `acb_mat_lu` * Don't ccall `arb_mat_solve` * Don't ccall `arb_mat_solve_lu_precomp` * Don't ccall `arb_mat_lu` --- src/arb/ComplexMat.jl | 33 ++++++++++----------------------- src/arb/RealMat.jl | 33 ++++++++++----------------------- src/arb/acb_mat.jl | 33 ++++++++++----------------------- src/arb/arb_mat.jl | 33 ++++++++++----------------------- 4 files changed, 40 insertions(+), 92 deletions(-) diff --git a/src/arb/ComplexMat.jl b/src/arb/ComplexMat.jl index 2fe4be9a3..3cc2ab77c 100644 --- a/src/arb/ComplexMat.jl +++ b/src/arb/ComplexMat.jl @@ -525,17 +525,21 @@ end # ############################################################################### -function lu!(P::Perm, x::ComplexMatrix) +function lu!(P::Perm, z::ComplexMatrix, x::ComplexMatrix) P.d .-= 1 r = ccall((:acb_mat_lu, libflint), Cint, (Ptr{Int}, Ref{ComplexMatrix}, Ref{ComplexMatrix}, Int), - P.d, x, x, precision(Balls)) + P.d, z, x, precision(Balls)) r == 0 && error("Could not find $(nrows(x)) invertible pivot elements") P.d .+= 1 inv!(P) return min(nrows(x), ncols(x)) end +function lu!(P::Perm, x::ComplexMatrix) + return lu!(P, x, x) +end + function _solve!(z::ComplexMatrix, x::ComplexMatrix, y::ComplexMatrix) r = ccall((:acb_mat_solve, libflint), Cint, (Ref{ComplexMatrix}, Ref{ComplexMatrix}, Ref{ComplexMatrix}, Int), @@ -570,10 +574,7 @@ function Solve._can_solve_internal_no_check(::Solve.LUTrait, A::ComplexMatrix, b end x = similar(A, ncols(A), ncols(b)) - fl = ccall((:acb_mat_solve, libflint), Cint, - (Ref{ComplexMatrix}, Ref{ComplexMatrix}, Ref{ComplexMatrix}, Int), - x, A, b, precision(Balls)) - fl == 0 && error("Matrix cannot be inverted numerically") + _solve!(x, A, b) if task === :only_check || task === :with_solution return true, x, zero(A, 0, 0) end @@ -598,13 +599,7 @@ function Solve._init_reduce(C::Solve.SolveCtx{ComplexFieldElem, Solve.LUTrait}) A = matrix(C) P = Perm(nrows(C)) x = similar(A, nrows(A), ncols(A)) - P.d .-= 1 - fl = ccall((:acb_mat_lu, libflint), Cint, - (Ptr{Int}, Ref{ComplexMatrix}, Ref{ComplexMatrix}, Int), - P.d, x, A, precision(Balls)) - fl == 0 && error("Could not find $(nrows(x)) invertible pivot elements") - P.d .+= 1 - inv!(P) + lu!(P, x, A) C.red = x C.lu_perm = P @@ -621,13 +616,7 @@ function Solve._init_reduce_transpose(C::Solve.SolveCtx{ComplexFieldElem, Solve. A = transpose(matrix(C)) P = Perm(nrows(C)) x = similar(A, nrows(A), ncols(A)) - P.d .-= 1 - fl = ccall((:acb_mat_lu, libflint), Cint, - (Ptr{Int}, Ref{ComplexMatrix}, Ref{ComplexMatrix}, Int), - P.d, x, A, precision(Balls)) - fl == 0 && error("Could not find $(nrows(x)) invertible pivot elements") - P.d .+= 1 - inv!(P) + lu!(P, x, A) C.red_transp = x C.lu_perm_transp = P @@ -645,9 +634,7 @@ function Solve._can_solve_internal_no_check(::Solve.LUTrait, C::Solve.SolveCtx{C end x = similar(b, ncols(C), ncols(b)) - ccall((:acb_mat_solve_lu_precomp, libflint), Nothing, - (Ref{ComplexMatrix}, Ptr{Int}, Ref{ComplexMatrix}, Ref{ComplexMatrix}, Int), - x, inv(p).d .- 1, LU, b, precision(Balls)) + _solve_lu_precomp!(x, p, LU, b) if side === :left x = transpose(x) diff --git a/src/arb/RealMat.jl b/src/arb/RealMat.jl index 5796bdd5f..55745ea9c 100644 --- a/src/arb/RealMat.jl +++ b/src/arb/RealMat.jl @@ -467,18 +467,22 @@ end # ############################################################################### -function lu!(P::Perm, x::RealMatrix) +function lu!(P::Perm, z::RealMatrix, x::RealMatrix) parent(P).n != nrows(x) && error("Permutation does not match matrix") P.d .-= 1 r = ccall((:arb_mat_lu, libflint), Cint, (Ptr{Int}, Ref{RealMatrix}, Ref{RealMatrix}, Int), - P.d, x, x, precision(Balls)) + P.d, z, x, precision(Balls)) r == 0 && error("Could not find $(nrows(x)) invertible pivot elements") P.d .+= 1 inv!(P) return min(nrows(x), ncols(x)) end +function lu!(P::Perm, x::RealMatrix) + return lu!(P, x, x) +end + function _solve!(z::RealMatrix, x::RealMatrix, y::RealMatrix) r = ccall((:arb_mat_solve, libflint), Cint, (Ref{RealMatrix}, Ref{RealMatrix}, Ref{RealMatrix}, Int), @@ -513,10 +517,7 @@ function Solve._can_solve_internal_no_check(::Solve.LUTrait, A::RealMatrix, b::R end x = similar(A, ncols(A), ncols(b)) - fl = ccall((:arb_mat_solve, libflint), Cint, - (Ref{RealMatrix}, Ref{RealMatrix}, Ref{RealMatrix}, Int), - x, A, b, precision(Balls)) - fl == 0 && error("Matrix cannot be inverted numerically") + _solve!(x, A, b) if task === :only_check || task === :with_solution return true, x, zero(A, 0, 0) end @@ -540,13 +541,7 @@ function Solve._init_reduce(C::Solve.SolveCtx{RealFieldElem, Solve.LUTrait}) A = matrix(C) P = Perm(nrows(C)) x = similar(A, nrows(A), ncols(A)) - P.d .-= 1 - fl = ccall((:arb_mat_lu, libflint), Cint, - (Ptr{Int}, Ref{RealMatrix}, Ref{RealMatrix}, Int), - P.d, x, A, precision(Balls)) - fl == 0 && error("Could not find $(nrows(x)) invertible pivot elements") - P.d .+= 1 - inv!(P) + lu!(P, x, A) C.red = x C.lu_perm = P @@ -563,13 +558,7 @@ function Solve._init_reduce_transpose(C::Solve.SolveCtx{RealFieldElem, Solve.LUT A = transpose(matrix(C)) P = Perm(nrows(C)) x = similar(A, nrows(A), ncols(A)) - P.d .-= 1 - fl = ccall((:arb_mat_lu, libflint), Cint, - (Ptr{Int}, Ref{RealMatrix}, Ref{RealMatrix}, Int), - P.d, x, A, precision(Balls)) - fl == 0 && error("Could not find $(nrows(x)) invertible pivot elements") - P.d .+= 1 - inv!(P) + lu!(P, x, A) C.red_transp = x C.lu_perm_transp = P @@ -587,9 +576,7 @@ function Solve._can_solve_internal_no_check(::Solve.LUTrait, C::Solve.SolveCtx{R end x = similar(b, ncols(C), ncols(b)) - ccall((:arb_mat_solve_lu_precomp, libflint), Nothing, - (Ref{RealMatrix}, Ptr{Int}, Ref{RealMatrix}, Ref{RealMatrix}, Int), - x, inv(p).d .- 1, LU, b, precision(Balls)) + _solve_lu_precomp!(x, p, LU, b) if side === :left x = transpose(x) diff --git a/src/arb/acb_mat.jl b/src/arb/acb_mat.jl index 925442d07..196a47ec7 100644 --- a/src/arb/acb_mat.jl +++ b/src/arb/acb_mat.jl @@ -528,17 +528,21 @@ end # ############################################################################### -function lu!(P::Perm, x::AcbMatrix) +function lu!(P::Perm, z::AcbMatrix, x::AcbMatrix) P.d .-= 1 r = ccall((:acb_mat_lu, libflint), Cint, (Ptr{Int}, Ref{AcbMatrix}, Ref{AcbMatrix}, Int), - P.d, x, x, precision(base_ring(x))) + P.d, z, x, precision(base_ring(x))) r == 0 && error("Could not find $(nrows(x)) invertible pivot elements") P.d .+= 1 inv!(P) return nrows(x) end +function lu!(P::Perm, x::AcbMatrix) + return lu!(P, x, x) +end + function _solve!(z::AcbMatrix, x::AcbMatrix, y::AcbMatrix) r = ccall((:acb_mat_solve, libflint), Cint, (Ref{AcbMatrix}, Ref{AcbMatrix}, Ref{AcbMatrix}, Int), @@ -573,10 +577,7 @@ function Solve._can_solve_internal_no_check(::Solve.LUTrait, A::AcbMatrix, b::Ac end x = similar(A, ncols(A), ncols(b)) - fl = ccall((:acb_mat_solve, libflint), Cint, - (Ref{AcbMatrix}, Ref{AcbMatrix}, Ref{AcbMatrix}, Int), - x, A, b, precision(base_ring(A))) - fl == 0 && error("Matrix cannot be inverted numerically") + _solve!(x, A, b) if task === :only_check || task === :with_solution return true, x, zero(A, 0, 0) end @@ -600,13 +601,7 @@ function Solve._init_reduce(C::Solve.SolveCtx{AcbFieldElem, Solve.LUTrait}) A = matrix(C) P = Perm(nrows(C)) x = similar(A, nrows(A), ncols(A)) - P.d .-= 1 - fl = ccall((:acb_mat_lu, libflint), Cint, - (Ptr{Int}, Ref{AcbMatrix}, Ref{AcbMatrix}, Int), - P.d, x, A, precision(base_ring(A))) - fl == 0 && error("Could not find $(nrows(x)) invertible pivot elements") - P.d .+= 1 - inv!(P) + lu!(P, x, A) C.red = x C.lu_perm = P @@ -623,13 +618,7 @@ function Solve._init_reduce_transpose(C::Solve.SolveCtx{AcbFieldElem, Solve.LUTr A = transpose(matrix(C)) P = Perm(nrows(C)) x = similar(A, nrows(A), ncols(A)) - P.d .-= 1 - fl = ccall((:acb_mat_lu, libflint), Cint, - (Ptr{Int}, Ref{AcbMatrix}, Ref{AcbMatrix}, Int), - P.d, x, A, precision(base_ring(A))) - fl == 0 && error("Could not find $(nrows(x)) invertible pivot elements") - P.d .+= 1 - inv!(P) + lu!(P, x, A) C.red_transp = x C.lu_perm_transp = P @@ -647,9 +636,7 @@ function Solve._can_solve_internal_no_check(::Solve.LUTrait, C::Solve.SolveCtx{A end x = similar(b, ncols(C), ncols(b)) - ccall((:acb_mat_solve_lu_precomp, libflint), Nothing, - (Ref{AcbMatrix}, Ptr{Int}, Ref{AcbMatrix}, Ref{AcbMatrix}, Int), - x, inv(p).d .- 1, LU, b, precision(base_ring(LU))) + _solve_lu_precomp!(x, p, LU, b) if side === :left x = transpose(x) diff --git a/src/arb/arb_mat.jl b/src/arb/arb_mat.jl index bc934c458..edbfbb6ad 100644 --- a/src/arb/arb_mat.jl +++ b/src/arb/arb_mat.jl @@ -480,18 +480,22 @@ function cholesky(x::ArbMatrix) return z end -function lu!(P::Perm, x::ArbMatrix) +function lu!(P::Perm, z::ArbMatrix, x::ArbMatrix) parent(P).n != nrows(x) && error("Permutation does not match matrix") P.d .-= 1 r = ccall((:arb_mat_lu, libflint), Cint, (Ptr{Int}, Ref{ArbMatrix}, Ref{ArbMatrix}, Int), - P.d, x, x, precision(base_ring(x))) + P.d, z, x, precision(base_ring(x))) r == 0 && error("Could not find $(nrows(x)) invertible pivot elements") P.d .+= 1 inv!(P) return nrows(x) end +function lu!(P::Perm, x::ArbMatrix) + return lu!(P, x, x) +end + function _solve!(z::ArbMatrix, x::ArbMatrix, y::ArbMatrix) r = ccall((:arb_mat_solve, libflint), Cint, (Ref{ArbMatrix}, Ref{ArbMatrix}, Ref{ArbMatrix}, Int), @@ -540,10 +544,7 @@ function Solve._can_solve_internal_no_check(::Solve.LUTrait, A::ArbMatrix, b::Ar end x = similar(A, ncols(A), ncols(b)) - fl = ccall((:arb_mat_solve, libflint), Cint, - (Ref{ArbMatrix}, Ref{ArbMatrix}, Ref{ArbMatrix}, Int), - x, A, b, precision(base_ring(A))) - fl == 0 && error("Matrix cannot be inverted numerically") + _solve!(x, A, b) if task === :only_check || task === :with_solution return true, x, zero(A, 0, 0) end @@ -567,13 +568,7 @@ function Solve._init_reduce(C::Solve.SolveCtx{ArbFieldElem, Solve.LUTrait}) A = matrix(C) P = Perm(nrows(C)) x = similar(A, nrows(A), ncols(A)) - P.d .-= 1 - fl = ccall((:arb_mat_lu, libflint), Cint, - (Ptr{Int}, Ref{ArbMatrix}, Ref{ArbMatrix}, Int), - P.d, x, A, precision(base_ring(A))) - fl == 0 && error("Could not find $(nrows(x)) invertible pivot elements") - P.d .+= 1 - inv!(P) + lu!(P, x, A) C.red = x C.lu_perm = P @@ -590,13 +585,7 @@ function Solve._init_reduce_transpose(C::Solve.SolveCtx{ArbFieldElem, Solve.LUTr A = transpose(matrix(C)) P = Perm(nrows(C)) x = similar(A, nrows(A), ncols(A)) - P.d .-= 1 - fl = ccall((:arb_mat_lu, libflint), Cint, - (Ptr{Int}, Ref{ArbMatrix}, Ref{ArbMatrix}, Int), - P.d, x, A, precision(base_ring(A))) - fl == 0 && error("Could not find $(nrows(x)) invertible pivot elements") - P.d .+= 1 - inv!(P) + lu!(P, x, A) C.red_transp = x C.lu_perm_transp = P @@ -614,9 +603,7 @@ function Solve._can_solve_internal_no_check(::Solve.LUTrait, C::Solve.SolveCtx{A end x = similar(b, ncols(C), ncols(b)) - ccall((:arb_mat_solve_lu_precomp, libflint), Nothing, - (Ref{ArbMatrix}, Ptr{Int}, Ref{ArbMatrix}, Ref{ArbMatrix}, Int), - x, inv(p).d .- 1, LU, b, precision(base_ring(LU))) + _solve_lu_precomp!(x, p, LU, b) if side === :left x = transpose(x)