From e7f32d74ec4233cdc514e06ec7ed9752c4e7f293 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tam=C3=A1s=20K=2E=20Papp?= Date: Thu, 2 Nov 2023 16:45:32 +0100 Subject: [PATCH 1/3] Add transformations for bases - incidental: fix docs, remove unused function arguments --- docs/src/index.md | 15 ++++---- src/generic_api.jl | 76 ++++++++++++++++++++++++++-------------- test/test_derivatives.jl | 7 ++-- test/test_generic_api.jl | 5 +-- 4 files changed, 64 insertions(+), 39 deletions(-) diff --git a/docs/src/index.md b/docs/src/index.md index add4952..a7a6a25 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -40,7 +40,7 @@ Bases have a “canonical” domain, eg ``[-1,1]`` or ``[-1,1]^n`` for Chebyshev ### Univariate family on `[-1,1]` -```@repl +```@example using SpectralKit basis = Chebyshev(EndpointGrid(), 5) # 5 Chebyshev polynomials is_function_basis(basis) # ie we support the interface below @@ -59,18 +59,19 @@ augment_coefficients(basis, basis2, θ) # … so let's do it ### Smolyak approximation on a transformed domain -```@repl +```@example using SpectralKit, StaticArrays function f2(x) # bivariate function we approximate x1, x2 = x # takes vectors exp(x1) + exp(-abs2(x2)) end -ct = coordinate_transformations(BoundedLinear(-1, 2.0), SemiInfRational(-3.0, 3.0)) basis = smolyak_basis(Chebyshev, InteriorGrid2(), SmolyakParameters(3), 2) -x = grid(basis) -θ = collocation_matrix(basis) \ f2.(from_pm1.(ct, x)) # find the coefficients -z = (0.5, 0.7) # evaluate at this point -isapprox(f2(z), (linear_combination(basis, θ) ∘ ct)(z), rtol = 0.005) +ct = coordinate_transformations(BoundedLinear(-1, 2.0), SemiInfRational(-3.0, 3.0)) +basis_t = basis ∘ ct +x = grid(basis_t) +θ = collocation_matrix(basis_t) \ f2.(x) # find the coefficients +z = (0.5, 0.7) # evaluate at this point +isapprox(f2(z), linear_combination(basis_t, θ)(z), rtol = 0.005) ``` Note how the transformation can be combined with `∘` to a callable that evaluates a transformed linear combination at `z`. diff --git a/src/generic_api.jl b/src/generic_api.jl index c49560e..ce3b8a4 100644 --- a/src/generic_api.jl +++ b/src/generic_api.jl @@ -101,10 +101,10 @@ Internal. """ $(SIGNATURES) -Helper function for linear combinations of basis elements at `x`. When `_check`, check -that `θ` and `basis` have compatible dimensions. +Helper function for linear combinations of basis elements at `x`. Always checks that `θ` +and `basis` have compatible dimensions. """ -@inline function _linear_combination(basis, θ, x, _check) +@inline function _linear_combination(basis, θ, x) # an implementation of mapreduce, to work around # https://github.com/JuliaLang/julia/issues/50735 B = basis_at(basis, x) @@ -128,7 +128,7 @@ the given order. The length of `θ` should equal `dimension(θ)`. """ -linear_combination(basis, θ, x) = _linear_combination(basis, θ, x, true) +linear_combination(basis, θ, x) = _linear_combination(basis, θ, x) # FIXME define a nice Base.show method struct LinearCombination{B,C} @@ -140,37 +140,18 @@ struct LinearCombination{B,C} end end -(l::LinearCombination)(x) = _linear_combination(l.basis, l.θ, x, false) +(l::LinearCombination)(x) = _linear_combination(l.basis, l.θ, x) """ $(SIGNATURES) Return a callable that calculates `linear_combination(basis, θ, x)` when called with `x`. -Use `linear_combination(basis, θ) ∘ transformation` for domain transformations. +You can use `linear_combination(basis, θ) ∘ transformation` for domain transformations, +though working with `basis ∘ transformation` may be preferred. """ linear_combination(basis, θ) = LinearCombination(basis, θ) -struct TransformedLinearCombination{B,C,T} - basis::B - θ::C - transformation::T - function TransformedLinearCombination(basis::B, θ::C, transformation::T) where {B,C,T} - @argcheck dimension(basis) == length(θ) - @argcheck domain_kind(domain(basis)) ≡ domain_kind(T) - new{B,C,T}(basis, θ, transformation) - end -end - -function (l::TransformedLinearCombination)(x) - (; basis, θ, transformation) = l - _linear_combination(basis, θ, transform_to(domain(basis), transformation, x), false) -end - -function Base.:(∘)(l::LinearCombination, transformation) - TransformedLinearCombination(l.basis, l.θ, transformation) -end - """ $(TYPEDEF) @@ -227,7 +208,6 @@ for compatible vectors `y = f.(x)`. Methods are type stable. The elements of `x` can be [`derivatives`](@ref). """ function collocation_matrix(basis, x = grid(basis)) - @argcheck isconcretetype(eltype(x)) N = dimension(basis) M = length(x) C = Matrix{eltype(basis_at(basis, first(x)))}(undef, M, N) @@ -265,3 +245,45 @@ with [`augment_coefficients`](@ref). since they may be in different positions. Always use [`augment_coefficients`](@ref). """ is_subset_basis(basis1::FunctionBasis, basis2::FunctionBasis) = false + +#### +#### transformed basis +#### + +""" +Transform the domain of a basis. +""" +struct TransformedBasis{B,T} <: FunctionBasis + parent::B + transformation::T + function TransformedBasis(parent::B, transformation::T) where {B,T} + @argcheck domain_kind(domain(parent)) ≡ domain_kind(T) + new{B,T}(parent, transformation) + end +end + +function Base.:(∘)(parent::FunctionBasis, transformation) + TransformedBasis(parent, transformation) +end + +domain(basis::TransformedBasis) = domain(basis.parent) + +dimension(basis::TransformedBasis) = dimension(basis.parent) + +function basis_at(basis::TransformedBasis, x) + (; parent, transformation) = basis + basis_at(parent, transform_to(domain(parent), transformation, x)) +end + +function grid(basis::TransformedBasis) + (; parent, transformation) = basis + d = domain(parent) + Iterators.map(x -> transform_to(d, transformation, x), grid(parent)) +end + +function Base.:(∘)(linear_combination::LinearCombination, transformation) + (; basis, θ) = linear_combination + LinearCombination(basis ∘ transformation, θ) +end + +# FIXME add augmentation for transformed bases diff --git a/test/test_derivatives.jl b/test/test_derivatives.jl index e761d7f..3fc6859 100644 --- a/test/test_derivatives.jl +++ b/test/test_derivatives.jl @@ -49,7 +49,7 @@ end for i in 1:100 z = rand_pm1(i) x = transform_from(PM1(), t, z) - ℓ = linear_combination(b, θ) ∘ t + ℓ = linear_combination(b ∘ t, θ) Dx = @inferred ℓ(derivatives(x, Val(D))) @test Dx[0] == ℓ(x) for d in 1:D @@ -71,8 +71,9 @@ end (f, x) -> DD(x2 -> f((x[1], x2, x[3])), x[2]), (f, x) -> DD(x3 -> f((x[1], x[2], x3)), x[3]), (f, x) -> DD(x1 -> DD(x2 -> f((x1, x2, x[3])), x[2]), x[1])] - θ = randn(dimension(b)) - ℓ = linear_combination(b, θ) ∘ t + bt = b ∘ t + θ = randn(dimension(bt)) + ℓ = linear_combination(bt, θ) d = domain(b) for i in 1:100 z = [rand_pm1() for _ in 1:N] diff --git a/test/test_generic_api.jl b/test/test_generic_api.jl index b94dc02..18e986b 100644 --- a/test/test_generic_api.jl +++ b/test/test_generic_api.jl @@ -35,10 +35,11 @@ end θ = randn(10) t = BoundedLinear(1.0, 2.0) l1 = linear_combination(basis, θ) - l2 = l1 ∘ t + l2 = linear_combination(basis ∘ t, θ) + l3 = linear_combination(basis, θ) ∘ t for _ in 1:20 x = rand() + 1.0 - @test l1(transform_to(domain(basis), t, x)) == l2(x) + @test l1(transform_to(domain(basis), t, x)) == l2(x) == l3(x) end end From ae67f6ecb0dd289777f861ddeedcefe8a7a70fd3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tam=C3=A1s=20K=2E=20Papp?= Date: Thu, 2 Nov 2023 16:45:56 +0100 Subject: [PATCH 2/3] bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index e2f5567..b5b6dd5 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SpectralKit" uuid = "5c252ae7-b5b6-46ab-a016-b0e3d78320b7" authors = ["Tamas K. Papp "] -version = "0.13.0" +version = "0.14.0" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" From 976d12031c16023307f9ce125f22639e9eb12a4c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tam=C3=A1s=20K=2E=20Papp?= Date: Thu, 2 Nov 2023 17:05:35 +0100 Subject: [PATCH 3/3] add tests, fix grid transformation --- src/generic_api.jl | 4 ++-- test/test_generic_api.jl | 9 +++++++-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/src/generic_api.jl b/src/generic_api.jl index ce3b8a4..74432ff 100644 --- a/src/generic_api.jl +++ b/src/generic_api.jl @@ -266,7 +266,7 @@ function Base.:(∘)(parent::FunctionBasis, transformation) TransformedBasis(parent, transformation) end -domain(basis::TransformedBasis) = domain(basis.parent) +domain(basis::TransformedBasis) = domain(basis.transformation) dimension(basis::TransformedBasis) = dimension(basis.parent) @@ -278,7 +278,7 @@ end function grid(basis::TransformedBasis) (; parent, transformation) = basis d = domain(parent) - Iterators.map(x -> transform_to(d, transformation, x), grid(parent)) + Iterators.map(x -> transform_from(d, transformation, x), grid(parent)) end function Base.:(∘)(linear_combination::LinearCombination, transformation) diff --git a/test/test_generic_api.jl b/test/test_generic_api.jl index 18e986b..0812703 100644 --- a/test/test_generic_api.jl +++ b/test/test_generic_api.jl @@ -29,11 +29,16 @@ end @test_throws ArgumentError linear_combination(basis, bad_θ) end -@testset "transformed linear combinations" begin +@testset "transformed bases and linear combinations" begin N = 10 basis = Chebyshev(EndpointGrid(), N) - θ = randn(10) t = BoundedLinear(1.0, 2.0) + @test domain(basis ∘ t) == domain(t) + @test dimension(basis ∘ t) == dimension(basis) + @test collect(grid(basis ∘ t)) == + [transform_from(domain(basis), t, x) for x in grid(basis)] + + θ = randn(10) l1 = linear_combination(basis, θ) l2 = linear_combination(basis ∘ t, θ) l3 = linear_combination(basis, θ) ∘ t