Skip to content

Commit

Permalink
[NestedPermutedDimsArrays] Fix setindex
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman committed Nov 16, 2024
1 parent 549e7f1 commit b421ad1
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 10 deletions.
2 changes: 1 addition & 1 deletion NDTensors/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NDTensors"
uuid = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
authors = ["Matthew Fishman <[email protected]>"]
version = "0.3.71"
version = "0.3.72"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,45 @@
# Mostly copied from https://github.com/JuliaLang/julia/blob/master/base/permuteddimsarray.jl
# Like `PermutedDimsArrays` but singly nested, similar to `Adjoint` and `Transpose`
# (though those are fully recursive).
#=
TODO: Investigate replacing this with a `PermutedDimsArray` wrapped around a `MappedArrays.MappedArray`.
There are a few issues with that:
1. Just using a type alias leads to type piracy, for example the constructor is type piracy.
2. `setindex!(::NestedPermutedDimsArray, I...)` fails because no conversion is defined between `Array`
and `PermutedDimsArray`.
3. The type alias is tricky to define, ideally it would have similar type parameters to the current
`NestedPermutedDimsArrays.NestedPermutedDimsArray` definition which matches the type parameters
of `PermutedDimsArrays.PermutedDimsArray` but that seems to be difficult to achieve.
```julia
module NestedPermutedDimsArrays
using MappedArrays: MultiMappedArray, mappedarray
export NestedPermutedDimsArray
const NestedPermutedDimsArray{TT,T<:AbstractArray{TT},N,perm,iperm,AA<:AbstractArray{T}} = PermutedDimsArray{
PermutedDimsArray{TT,N,perm,iperm,T},
N,
perm,
iperm,
MultiMappedArray{
PermutedDimsArray{TT,N,perm,iperm,T},
N,
Tuple{AA},
Type{PermutedDimsArray{TT,N,perm,iperm,T}},
Type{PermutedDimsArray{TT,N,iperm,perm,T}},
},
}
function NestedPermutedDimsArray(a::AbstractArray, perm)
iperm = invperm(perm)
f = PermutedDimsArray{eltype(eltype(a)),ndims(a),perm,iperm,eltype(a)}
finv = PermutedDimsArray{eltype(eltype(a)),ndims(a),iperm,perm,eltype(a)}
return PermutedDimsArray(mappedarray(f, finv, a), perm)
end
end
```
=#
module NestedPermutedDimsArrays

import Base: permutedims, permutedims!
Expand Down Expand Up @@ -107,7 +146,7 @@ end
A::NestedPermutedDimsArray{T,N,perm,iperm}, val, I::Vararg{Int,N}
) where {T,N,perm,iperm}
@boundscheck checkbounds(A, I...)
@inbounds setindex!(A.parent, PermutedDimsArray(val, perm), genperm(I, iperm)...)
@inbounds setindex!(A.parent, PermutedDimsArray(val, iperm), genperm(I, iperm)...)
return val
end

Expand Down
16 changes: 8 additions & 8 deletions NDTensors/src/lib/NestedPermutedDimsArrays/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,19 @@ using Test: @test, @testset
Float32, Float64, Complex{Float32}, Complex{Float64}
)
a = map(_ -> randn(elt, 2, 3, 4), CartesianIndices((2, 3, 4)))
perm = (3, 2, 1)
perm = (3, 1, 2)
p = NestedPermutedDimsArray(a, perm)
T = PermutedDimsArray{elt,3,perm,invperm(perm),eltype(a)}
@test typeof(p) === NestedPermutedDimsArray{T,3,perm,invperm(perm),typeof(a)}
@test size(p) == (4, 3, 2)
@test size(p) == (4, 2, 3)
@test eltype(p) === T
for I in eachindex(p)
@test size(p[I]) == (4, 3, 2)
@test p[I] == permutedims(a[CartesianIndex(reverse(Tuple(I)))], perm)
@test size(p[I]) == (4, 2, 3)
@test p[I] == permutedims(a[CartesianIndex(map(i -> Tuple(I)[i], invperm(perm)))], perm)
end
x = randn(elt, 4, 3, 2)
p[3, 2, 1] = x
@test p[3, 2, 1] == x
@test a[1, 2, 3] == permutedims(x, perm)
x = randn(elt, 4, 2, 3)
p[3, 1, 2] = x
@test p[3, 1, 2] == x
@test a[1, 2, 3] == permutedims(x, invperm(perm))
end
end

0 comments on commit b421ad1

Please sign in to comment.