Skip to content

Commit

Permalink
Add a missing domain method for coordinate domains, with tests. (#49)
Browse files Browse the repository at this point in the history
* Add a missing domain method for coordinate domains, with tests.

* fix AbstractVector argument inference
  • Loading branch information
tpapp authored Jan 18, 2024
1 parent 9ec9c6a commit 9ec62f7
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 7 deletions.
17 changes: 11 additions & 6 deletions src/transformations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ end

domain_kind(::Type{<:CoordinateTransformations}) = :multivariate

function domain(coordinate_transformations::CoordinateTransformations)
coordinate_domains(map(domain, coordinate_transformations.transformations))
end

function Base.Tuple(coordinate_transformations::CoordinateTransformations)
coordinate_transformations.transformations
end
Expand Down Expand Up @@ -109,9 +113,9 @@ function transform_to(domain::CoordinateDomains, ct::CoordinateTransformations,
map((d, t, x) -> transform_to(d, t, x), domains, transformations, x)
end

function transform_to(domain::CoordinateDomains, ct::CoordinateTransformations,
x::AbstractVector)
SVector(transform_to(domain, ct, Tuple(x)))
function transform_to(domain::CoordinateDomains{T}, ct::CoordinateTransformations,
x::AbstractVector) where T
SVector(transform_to(domain, ct, _ntuple_like(T, x)))
end

function transform_to(domain::CoordinateDomains, ct::CoordinateTransformations, ∂x::∂Input)
Expand All @@ -131,9 +135,9 @@ function transform_from(domain::CoordinateDomains, ct::CoordinateTransformations
map((d, t, x) -> transform_from(d, t, x), domains, transformations, x)
end

function transform_from(domain::CoordinateDomains, ct::CoordinateTransformations,
x::AbstractVector)
SVector(transform_from(domain, ct, Tuple(x)))
function transform_from(domain::CoordinateDomains{T}, ct::CoordinateTransformations,
x::AbstractVector) where {T}
SVector(transform_from(domain, ct, _ntuple_like(T, x)))
end

####
Expand Down Expand Up @@ -266,6 +270,7 @@ end

function domain(t::SemiInfRational)
(; L, A) = t
A = float(A)
= oftype(A, Inf)
L > 0 ? UnivariateDomain(A, ∞) : UnivariateDomain(-∞, A)
end
Expand Down
13 changes: 13 additions & 0 deletions src/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,16 @@ struct SubScript
end

Base.print(io::IO, s::SubScript) = print_number(io, _SUBSCRIPT_DIGITS, s.i)

"""
$(SIGNATURES)
If `T <: NTuple{N}`, convert `v` into an `NTuple{N}`.
Used for ingesting `::AbstractVector` arguments in contexts where an `NTuple` or
`SVector` is preferred.
"""
function _ntuple_like(::Type{T}, v::AbstractVector) where {N,T<:NTuple{N}}
@argcheck length(v) == N
NTuple{N}(v)
end
21 changes: 20 additions & 1 deletion test/test_generic_api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ end
@test_throws ArgumentError linear_combination(basis, bad_θ)
end

@testset "transformed bases and linear combinations" begin
@testset "transformed bases and linear combinations (univariate)" begin
N = 10
basis = Chebyshev(EndpointGrid(), N)
t = BoundedLinear(1.0, 2.0)
Expand All @@ -48,6 +48,25 @@ end
end
end

@testset "transformed bases and linear combinations (bivariate)" begin
basis0 = smolyak_basis(Chebyshev, InteriorGrid(), SmolyakParameters(2, 2), Val(2))
t = coordinate_transformations(BoundedLinear(1.0, 2.0), SemiInfRational(0, 1))
basis = basis0 t
@test domain(basis t) == domain(t)
@test dimension(basis t) == dimension(basis)
@test collect(grid(basis)) ==
[transform_from(domain(basis0), t, x) for x in grid(basis0)]

θ = randn(dimension(basis))
l1 = linear_combination(basis0, θ)
l2 = linear_combination(basis, θ)
l3 = linear_combination(basis0, θ) t
for _ in 1:20
x = rand(2) .+ 1.0
@test l1(transform_to(domain(basis0), t, x)) == l2(x) == l3(x)
end
end

@testset "subset fallback" begin
@test !is_subset_basis(Chebyshev(InteriorGrid(), 4), # just test the fallback method
smolyak_basis(Chebyshev, InteriorGrid(), SmolyakParameters(2, 2), 2))
Expand Down

0 comments on commit 9ec62f7

Please sign in to comment.