Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SparseArrayInterface] NestedPermutedDimsArray support #1590

Merged
merged 7 commits into from
Nov 15, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ end

# TODO: Make this into a generic definition of all `AbstractArray`?
function SparseArrayInterface.stored_indices(
a::PermutedDimsArray{<:Any,<:Any,<:Any,<:Any,<:AbstractSparseArray}
a::AnyPermutedDimsArray{<:Any,<:Any,<:Any,<:Any,<:AbstractSparseArray}
)
return Iterators.map(
I -> CartesianIndex(map(i -> I[i], perm(a))), stored_indices(parent(a))
Expand All @@ -41,7 +41,7 @@ end

# TODO: Make this into a generic definition of all `AbstractArray`?
function SparseArrayInterface.sparse_storage(
a::PermutedDimsArray{<:Any,<:Any,<:Any,<:Any,<:AbstractSparseArray}
a::AnyPermutedDimsArray{<:Any,<:Any,<:Any,<:Any,<:AbstractSparseArray}
)
return sparse_storage(parent(a))
end
Original file line number Diff line number Diff line change
@@ -1,19 +1,26 @@
using ..NestedPermutedDimsArrays: NestedPermutedDimsArray

## PermutedDimsArray

perm(::PermutedDimsArray{<:Any,<:Any,P}) where {P} = P
iperm(::PermutedDimsArray{<:Any,<:Any,<:Any,IP}) where {IP} = IP
const AnyPermutedDimsArray{T,N,perm,iperm,P} = Union{
PermutedDimsArray{T,N,perm,iperm,P},NestedPermutedDimsArray{T,N,perm,iperm,P}
}

# TODO: Use `TypeParameterAccessors`.
perm(::AnyPermutedDimsArray{<:Any,<:Any,Perm}) where {Perm} = Perm
iperm(::AnyPermutedDimsArray{<:Any,<:Any,<:Any,IPerm}) where {IPerm} = IPerm

# TODO: Use `Base.PermutedDimsArrays.genperm` or
# https://github.com/jipolanco/StaticPermutations.jl?
genperm(v, perm) = map(j -> v[j], perm)
genperm(v::CartesianIndex, perm) = CartesianIndex(map(j -> Tuple(v)[j], perm))

function storage_index_to_index(a::PermutedDimsArray, I)
function storage_index_to_index(a::AnyPermutedDimsArray, I)
return genperm(storage_index_to_index(parent(a), I), perm(a))
end

function index_to_storage_index(
a::PermutedDimsArray{<:Any,N}, I::CartesianIndex{N}
a::AnyPermutedDimsArray{<:Any,N}, I::CartesianIndex{N}
) where {N}
return index_to_storage_index(parent(a), genperm(I, perm(a)))
end
Expand Down
Original file line number Diff line number Diff line change
@@ -1,24 +1,29 @@
module AbstractSparseArrays
using ArrayLayouts: ArrayLayouts, MatMulMatAdd, MemoryLayout, MulAdd
using NDTensors.SparseArrayInterface: SparseArrayInterface, AbstractSparseArray
using NDTensors.SparseArrayInterface: SparseArrayInterface, AbstractSparseArray, Zero

struct SparseArray{T,N} <: AbstractSparseArray{T,N}
struct SparseArray{T,N,Zero} <: AbstractSparseArray{T,N}
data::Vector{T}
dims::Tuple{Vararg{Int,N}}
index_to_dataindex::Dict{CartesianIndex{N},Int}
dataindex_to_index::Vector{CartesianIndex{N}}
zero::Zero
end
function SparseArray{T,N}(dims::Tuple{Vararg{Int,N}}) where {T,N}
return SparseArray{T,N}(
T[], dims, Dict{CartesianIndex{N},Int}(), Vector{CartesianIndex{N}}()
function SparseArray{T,N}(dims::Tuple{Vararg{Int,N}}; zero=Zero()) where {T,N}
return SparseArray{T,N,typeof(zero)}(
T[], dims, Dict{CartesianIndex{N},Int}(), Vector{CartesianIndex{N}}(), zero
)
end
SparseArray{T,N}(dims::Vararg{Int,N}) where {T,N} = SparseArray{T,N}(dims)
SparseArray{T}(dims::Tuple{Vararg{Int}}) where {T} = SparseArray{T,length(dims)}(dims)
function SparseArray{T}(::UndefInitializer, dims::Tuple{Vararg{Int}}) where {T}
return SparseArray{T}(dims)
function SparseArray{T,N}(dims::Vararg{Int,N}; kwargs...) where {T,N}
return SparseArray{T,N}(dims; kwargs...)
end
SparseArray{T}(dims::Vararg{Int}) where {T} = SparseArray{T}(dims)
function SparseArray{T}(dims::Tuple{Vararg{Int}}; kwargs...) where {T}
return SparseArray{T,length(dims)}(dims; kwargs...)
end
function SparseArray{T}(::UndefInitializer, dims::Tuple{Vararg{Int}}; kwargs...) where {T}
return SparseArray{T}(dims; kwargs...)
end
SparseArray{T}(dims::Vararg{Int}; kwargs...) where {T} = SparseArray{T}(dims; kwargs...)

# ArrayLayouts interface
struct SparseLayout <: MemoryLayout end
Expand All @@ -41,6 +46,7 @@ function Base.similar(a::SparseArray, elt::Type, dims::Tuple{Vararg{Int}})
end

# Minimal interface
SparseArrayInterface.getindex_zero_function(a::SparseArray) = a.zero
SparseArrayInterface.sparse_storage(a::SparseArray) = a.data
function SparseArrayInterface.index_to_storage_index(
a::SparseArray{<:Any,N}, I::CartesianIndex{N}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,24 +1,29 @@
module SparseArrays
using LinearAlgebra: LinearAlgebra
using NDTensors.SparseArrayInterface: SparseArrayInterface
using NDTensors.SparseArrayInterface: SparseArrayInterface, Zero

struct SparseArray{T,N} <: AbstractArray{T,N}
struct SparseArray{T,N,Zero} <: AbstractArray{T,N}
data::Vector{T}
dims::Tuple{Vararg{Int,N}}
index_to_dataindex::Dict{CartesianIndex{N},Int}
dataindex_to_index::Vector{CartesianIndex{N}}
zero::Zero
end
function SparseArray{T,N}(dims::Tuple{Vararg{Int,N}}) where {T,N}
return SparseArray{T,N}(
T[], dims, Dict{CartesianIndex{N},Int}(), Vector{CartesianIndex{N}}()
function SparseArray{T,N}(dims::Tuple{Vararg{Int,N}}; zero=Zero()) where {T,N}
return SparseArray{T,N,typeof(zero)}(
T[], dims, Dict{CartesianIndex{N},Int}(), Vector{CartesianIndex{N}}(), zero
)
end
SparseArray{T,N}(dims::Vararg{Int,N}) where {T,N} = SparseArray{T,N}(dims)
SparseArray{T}(dims::Tuple{Vararg{Int}}) where {T} = SparseArray{T,length(dims)}(dims)
function SparseArray{T}(::UndefInitializer, dims::Tuple{Vararg{Int}}) where {T}
return SparseArray{T}(dims)
function SparseArray{T,N}(dims::Vararg{Int,N}; kwargs...) where {T,N}
return SparseArray{T,N}(dims; kwargs...)
end
SparseArray{T}(dims::Vararg{Int}) where {T} = SparseArray{T}(dims)
function SparseArray{T}(dims::Tuple{Vararg{Int}}; kwargs...) where {T}
return SparseArray{T,length(dims)}(dims; kwargs...)
end
function SparseArray{T}(::UndefInitializer, dims::Tuple{Vararg{Int}}; kwargs...) where {T}
return SparseArray{T}(dims; kwargs...)
end
SparseArray{T}(dims::Vararg{Int}; kwargs...) where {T} = SparseArray{T}(dims; kwargs...)

# LinearAlgebra interface
function LinearAlgebra.mul!(
Expand Down Expand Up @@ -53,6 +58,7 @@ function Base.fill!(a::SparseArray, value)
end

# Minimal interface
SparseArrayInterface.getindex_zero_function(a::SparseArray) = a.zero
SparseArrayInterface.sparse_storage(a::SparseArray) = a.data
function SparseArrayInterface.index_to_storage_index(
a::SparseArray{<:Any,N}, I::CartesianIndex{N}
Expand All @@ -79,6 +85,33 @@ function SparseArrayInterface.stored_indices(
)
end

# TODO: Make this into a generic definition of all `AbstractArray`?
using NDTensors.SparseArrayInterface: sparse_storage
function SparseArrayInterface.sparse_storage(
a::PermutedDimsArray{<:Any,<:Any,<:Any,<:Any,<:SparseArray}
)
return sparse_storage(parent(a))
end

# TODO: Make this into a generic definition of all `AbstractArray`?
using NDTensors.NestedPermutedDimsArrays: NestedPermutedDimsArray
function SparseArrayInterface.stored_indices(
a::NestedPermutedDimsArray{<:Any,<:Any,<:Any,<:Any,<:SparseArray}
)
return Iterators.map(
I -> CartesianIndex(map(i -> I[i], perm(a))), stored_indices(parent(a))
)
end

# TODO: Make this into a generic definition of all `AbstractArray`?
using NDTensors.NestedPermutedDimsArrays: NestedPermutedDimsArray
using NDTensors.SparseArrayInterface: sparse_storage
function SparseArrayInterface.sparse_storage(
a::NestedPermutedDimsArray{<:Any,<:Any,<:Any,<:Any,<:SparseArray}
)
return sparse_storage(parent(a))
end

# Empty the storage, helps with efficiency in `map!` to drop
# zeros.
function SparseArrayInterface.dropall!(a::SparseArray)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
@eval module $(gensym())
using LinearAlgebra: dot, mul!, norm
using NDTensors.SparseArrayInterface: SparseArrayInterface
using NDTensors.NestedPermutedDimsArrays: NestedPermutedDimsArray
include("SparseArrayInterfaceTestUtils/SparseArrayInterfaceTestUtils.jl")
using .SparseArrayInterfaceTestUtils.AbstractSparseArrays: AbstractSparseArrays
using .SparseArrayInterfaceTestUtils.SparseArrays: SparseArrays
Expand Down Expand Up @@ -224,6 +225,44 @@ using Test: @test, @testset
end
end

a = SparseArray{elt}(2, 3)
a[1, 2] = 12
b = PermutedDimsArray(a, (2, 1))
@test size(b) == (3, 2)
@test axes(b) == (1:3, 1:2)
@test SparseArrayInterface.sparse_storage(b) == elt[12]
@test SparseArrayInterface.stored_length(b) == 1
@test collect(SparseArrayInterface.stored_indices(b)) == [CartesianIndex(2, 1)]
@test !iszero(b)
@test !iszero(norm(b))
for I in eachindex(b)
if I == CartesianIndex(2, 1)
@test b[I] == 12
else
@test iszero(b[I])
end
end

a = SparseArray{Matrix{elt}}(
2, 3; zero=(a, I) -> (z = similar(eltype(a), 2, 3); fill!(z, false); z)
)
a[1, 2] = randn(elt, 2, 3)
b = NestedPermutedDimsArray(a, (2, 1))
@test size(b) == (3, 2)
@test axes(b) == (1:3, 1:2)
@test SparseArrayInterface.sparse_storage(b) == [a[1, 2]]
@test SparseArrayInterface.stored_length(b) == 1
@test collect(SparseArrayInterface.stored_indices(b)) == [CartesianIndex(2, 1)]
@test !iszero(b)
@test !iszero(norm(b))
for I in eachindex(b)
if I == CartesianIndex(2, 1)
@test b[I] == permutedims(a[1, 2], (2, 1))
else
@test iszero(b[I])
end
end

a = SparseArray{elt}(2, 3)
a[1, 2] = 12
b = randn(elt, 2, 3)
Expand Down
Loading