Skip to content

Commit

Permalink
Merge pull request #46 from tpapp/tp/transformed-bases
Browse files Browse the repository at this point in the history
Transformed bases
  • Loading branch information
tpapp authored Nov 2, 2023
2 parents 5835827 + 976d120 commit eb8519a
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 42 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SpectralKit"
uuid = "5c252ae7-b5b6-46ab-a016-b0e3d78320b7"
authors = ["Tamas K. Papp <tkpapp@gmail.com>"]
version = "0.13.0"
version = "0.14.0"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
Expand Down
15 changes: 8 additions & 7 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ Bases have a “canonical” domain, eg ``[-1,1]`` or ``[-1,1]^n`` for Chebyshev

### Univariate family on `[-1,1]`

```@repl
```@example
using SpectralKit
basis = Chebyshev(EndpointGrid(), 5) # 5 Chebyshev polynomials
is_function_basis(basis) # ie we support the interface below
Expand All @@ -59,18 +59,19 @@ augment_coefficients(basis, basis2, θ) # … so let's do it

### Smolyak approximation on a transformed domain

```@repl
```@example
using SpectralKit, StaticArrays
function f2(x) # bivariate function we approximate
x1, x2 = x # takes vectors
exp(x1) + exp(-abs2(x2))
end
ct = coordinate_transformations(BoundedLinear(-1, 2.0), SemiInfRational(-3.0, 3.0))
basis = smolyak_basis(Chebyshev, InteriorGrid2(), SmolyakParameters(3), 2)
x = grid(basis)
θ = collocation_matrix(basis) \ f2.(from_pm1.(ct, x)) # find the coefficients
z = (0.5, 0.7) # evaluate at this point
isapprox(f2(z), (linear_combination(basis, θ) ∘ ct)(z), rtol = 0.005)
ct = coordinate_transformations(BoundedLinear(-1, 2.0), SemiInfRational(-3.0, 3.0))
basis_t = basis ∘ ct
x = grid(basis_t)
θ = collocation_matrix(basis_t) \ f2.(x) # find the coefficients
z = (0.5, 0.7) # evaluate at this point
isapprox(f2(z), linear_combination(basis_t, θ)(z), rtol = 0.005)
```

Note how the transformation can be combined with `` to a callable that evaluates a transformed linear combination at `z`.
Expand Down
76 changes: 49 additions & 27 deletions src/generic_api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,10 @@ Internal.
"""
$(SIGNATURES)
Helper function for linear combinations of basis elements at `x`. When `_check`, check
that `θ` and `basis` have compatible dimensions.
Helper function for linear combinations of basis elements at `x`. Always checks that `θ`
and `basis` have compatible dimensions.
"""
@inline function _linear_combination(basis, θ, x, _check)
@inline function _linear_combination(basis, θ, x)
# an implementation of mapreduce, to work around
# https://github.com/JuliaLang/julia/issues/50735
B = basis_at(basis, x)
Expand All @@ -128,7 +128,7 @@ the given order.
The length of `θ` should equal `dimension(θ)`.
"""
linear_combination(basis, θ, x) = _linear_combination(basis, θ, x, true)
linear_combination(basis, θ, x) = _linear_combination(basis, θ, x)

# FIXME define a nice Base.show method
struct LinearCombination{B,C}
Expand All @@ -140,37 +140,18 @@ struct LinearCombination{B,C}
end
end

(l::LinearCombination)(x) = _linear_combination(l.basis, l.θ, x, false)
(l::LinearCombination)(x) = _linear_combination(l.basis, l.θ, x)

"""
$(SIGNATURES)
Return a callable that calculates `linear_combination(basis, θ, x)` when called with `x`.
Use `linear_combination(basis, θ) ∘ transformation` for domain transformations.
You can use `linear_combination(basis, θ) ∘ transformation` for domain transformations,
though working with `basis ∘ transformation` may be preferred.
"""
linear_combination(basis, θ) = LinearCombination(basis, θ)

struct TransformedLinearCombination{B,C,T}
basis::B
θ::C
transformation::T
function TransformedLinearCombination(basis::B, θ::C, transformation::T) where {B,C,T}
@argcheck dimension(basis) == length(θ)
@argcheck domain_kind(domain(basis)) domain_kind(T)
new{B,C,T}(basis, θ, transformation)
end
end

function (l::TransformedLinearCombination)(x)
(; basis, θ, transformation) = l
_linear_combination(basis, θ, transform_to(domain(basis), transformation, x), false)
end

function Base.:()(l::LinearCombination, transformation)
TransformedLinearCombination(l.basis, l.θ, transformation)
end

"""
$(TYPEDEF)
Expand Down Expand Up @@ -227,7 +208,6 @@ for compatible vectors `y = f.(x)`.
Methods are type stable. The elements of `x` can be [`derivatives`](@ref).
"""
function collocation_matrix(basis, x = grid(basis))
@argcheck isconcretetype(eltype(x))
N = dimension(basis)
M = length(x)
C = Matrix{eltype(basis_at(basis, first(x)))}(undef, M, N)
Expand Down Expand Up @@ -265,3 +245,45 @@ with [`augment_coefficients`](@ref).
since they may be in different positions. Always use [`augment_coefficients`](@ref).
"""
is_subset_basis(basis1::FunctionBasis, basis2::FunctionBasis) = false

####
#### transformed basis
####

"""
Transform the domain of a basis.
"""
struct TransformedBasis{B,T} <: FunctionBasis
parent::B
transformation::T
function TransformedBasis(parent::B, transformation::T) where {B,T}
@argcheck domain_kind(domain(parent)) domain_kind(T)
new{B,T}(parent, transformation)
end
end

function Base.:()(parent::FunctionBasis, transformation)
TransformedBasis(parent, transformation)
end

domain(basis::TransformedBasis) = domain(basis.transformation)

dimension(basis::TransformedBasis) = dimension(basis.parent)

function basis_at(basis::TransformedBasis, x)
(; parent, transformation) = basis
basis_at(parent, transform_to(domain(parent), transformation, x))
end

function grid(basis::TransformedBasis)
(; parent, transformation) = basis
d = domain(parent)
Iterators.map(x -> transform_from(d, transformation, x), grid(parent))
end

function Base.:()(linear_combination::LinearCombination, transformation)
(; basis, θ) = linear_combination
LinearCombination(basis transformation, θ)
end

# FIXME add augmentation for transformed bases
7 changes: 4 additions & 3 deletions test/test_derivatives.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ end
for i in 1:100
z = rand_pm1(i)
x = transform_from(PM1(), t, z)
= linear_combination(b, θ) t
= linear_combination(b t, θ)
Dx = @inferred (derivatives(x, Val(D)))
@test Dx[0] == (x)
for d in 1:D
Expand All @@ -71,8 +71,9 @@ end
(f, x) -> DD(x2 -> f((x[1], x2, x[3])), x[2]),
(f, x) -> DD(x3 -> f((x[1], x[2], x3)), x[3]),
(f, x) -> DD(x1 -> DD(x2 -> f((x1, x2, x[3])), x[2]), x[1])]
θ = randn(dimension(b))
= linear_combination(b, θ) t
bt = b t
θ = randn(dimension(bt))
= linear_combination(bt, θ)
d = domain(b)
for i in 1:100
z = [rand_pm1() for _ in 1:N]
Expand Down
14 changes: 10 additions & 4 deletions test/test_generic_api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,22 @@ end
@test_throws ArgumentError linear_combination(basis, bad_θ)
end

@testset "transformed linear combinations" begin
@testset "transformed bases and linear combinations" begin
N = 10
basis = Chebyshev(EndpointGrid(), N)
θ = randn(10)
t = BoundedLinear(1.0, 2.0)
@test domain(basis t) == domain(t)
@test dimension(basis t) == dimension(basis)
@test collect(grid(basis t)) ==
[transform_from(domain(basis), t, x) for x in grid(basis)]

θ = randn(10)
l1 = linear_combination(basis, θ)
l2 = l1 t
l2 = linear_combination(basis t, θ)
l3 = linear_combination(basis, θ) t
for _ in 1:20
x = rand() + 1.0
@test l1(transform_to(domain(basis), t, x)) == l2(x)
@test l1(transform_to(domain(basis), t, x)) == l2(x) == l3(x)
end
end

Expand Down

2 comments on commit eb8519a

@tpapp
Copy link
Owner Author

@tpapp tpapp commented on eb8519a Nov 2, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/94630

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.14.0 -m "<description of version>" eb8519a7ef0fb036ad7a9afc519c629b915cb1db
git push origin v0.14.0

Please sign in to comment.