Skip to content

Commit

Permalink
Merge branch 'main' into RecursivePermutedDimsArrays
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman committed Nov 15, 2024
2 parents acd2746 + 10a6563 commit 43e2b41
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,15 @@ struct BlockSparseArray{
axes::Axes
end

const BlockSparseMatrix{T,A,Blocks,Axes} = BlockSparseArray{T,2,A,Blocks,Axes}
const BlockSparseVector{T,A,Blocks,Axes} = BlockSparseArray{T,1,A,Blocks,Axes}
# TODO: Can this definition be shortened?
const BlockSparseMatrix{T,A<:AbstractMatrix{T},Blocks<:AbstractMatrix{A},Axes<:Tuple{AbstractUnitRange,AbstractUnitRange}} = BlockSparseArray{
T,2,A,Blocks,Axes
}

# TODO: Can this definition be shortened?
const BlockSparseVector{T,A<:AbstractVector{T},Blocks<:AbstractVector{A},Axes<:Tuple{AbstractUnitRange}} = BlockSparseArray{
T,1,A,Blocks,Axes
}

function BlockSparseArray(
block_data::Dictionary{<:Block{N},<:AbstractArray{<:Any,N}},
Expand Down Expand Up @@ -68,10 +75,38 @@ function BlockSparseArray{T,N,A}(
return BlockSparseArray{T,N,A}(blocks, axes)
end

function BlockSparseArray{T,N,A}(
axes::Vararg{AbstractUnitRange,N}
) where {T,N,A<:AbstractArray{T,N}}
return BlockSparseArray{T,N,A}(axes)
end

function BlockSparseArray{T,N,A}(
dims::Tuple{Vararg{Vector{Int},N}}
) where {T,N,A<:AbstractArray{T,N}}
return BlockSparseArray{T,N,A}(blockedrange.(dims))
end

# Fix ambiguity error.
function BlockSparseArray{T,0,A}(axes::Tuple{}) where {T,A<:AbstractArray{T,0}}
blocks = default_blocks(A, axes)
return BlockSparseArray{T,0,A}(blocks, axes)
end

function BlockSparseArray{T,N,A}(
dims::Vararg{Vector{Int},N}
) where {T,N,A<:AbstractArray{T,N}}
return BlockSparseArray{T,N,A}(dims)
end

function BlockSparseArray{T,N}(axes::Tuple{Vararg{AbstractUnitRange,N}}) where {T,N}
return BlockSparseArray{T,N,default_arraytype(T, axes)}(axes)
end

function BlockSparseArray{T,N}(axes::Vararg{AbstractUnitRange,N}) where {T,N}
return BlockSparseArray{T,N}(axes)
end

function BlockSparseArray{T,0}(axes::Tuple{}) where {T}
return BlockSparseArray{T,0,default_arraytype(T, axes)}(axes)
end
Expand All @@ -80,6 +115,10 @@ function BlockSparseArray{T,N}(dims::Tuple{Vararg{Vector{Int},N}}) where {T,N}
return BlockSparseArray{T,N}(blockedrange.(dims))
end

function BlockSparseArray{T,N}(dims::Vararg{Vector{Int},N}) where {T,N}
return BlockSparseArray{T,N}(dims)
end

function BlockSparseArray{T}(dims::Tuple{Vararg{Vector{Int}}}) where {T}
return BlockSparseArray{T,length(dims)}(dims)
end
Expand All @@ -104,37 +143,25 @@ function BlockSparseArray{T}() where {T}
return BlockSparseArray{T}(())
end

function BlockSparseArray{T,N,A}(
::UndefInitializer, dims::Tuple
) where {T,N,A<:AbstractArray{T,N}}
return BlockSparseArray{T,N,A}(dims)
end

# undef
function BlockSparseArray{T,N}(
::UndefInitializer, axes::Tuple{Vararg{AbstractUnitRange,N}}
) where {T,N}
return BlockSparseArray{T,N}(axes)
end

function BlockSparseArray{T,N}(
::UndefInitializer, dims::Tuple{Vararg{Vector{Int},N}}
) where {T,N}
return BlockSparseArray{T,N}(dims)
function BlockSparseArray{T,N,A,Blocks}(
::UndefInitializer, args...
) where {T,N,A<:AbstractArray{T,N},Blocks<:AbstractArray{A,N}}
return BlockSparseArray{T,N,A,Blocks}(args...)
end

function BlockSparseArray{T}(
::UndefInitializer, axes::Tuple{Vararg{AbstractUnitRange}}
) where {T}
return BlockSparseArray{T}(axes)
function BlockSparseArray{T,N,A}(
::UndefInitializer, args...
) where {T,N,A<:AbstractArray{T,N}}
return BlockSparseArray{T,N,A}(args...)
end

function BlockSparseArray{T}(::UndefInitializer, dims::Tuple{Vararg{Vector{Int}}}) where {T}
return BlockSparseArray{T}(dims)
function BlockSparseArray{T,N}(::UndefInitializer, args...) where {T,N}
return BlockSparseArray{T,N}(args...)
end

function BlockSparseArray{T}(::UndefInitializer, dims::Vararg{Vector{Int}}) where {T}
return BlockSparseArray{T}(dims...)
function BlockSparseArray{T}(::UndefInitializer, args...) where {T}
return BlockSparseArray{T}(args...)
end

# Base `AbstractArray` interface
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ using ..SparseArrayInterface: perm, iperm, stored_length, sparse_zero!

blocksparse_blocks(a::AbstractArray) = error("Not implemented")

blockstype(a::AbstractArray) = blockstype(typeof(a))

function blocksparse_getindex(a::AbstractArray{<:Any,N}, I::Vararg{Int,N}) where {N}
@boundscheck checkbounds(a, I...)
return a[findblockindex.(axes(a), I)...]
Expand Down
70 changes: 70 additions & 0 deletions NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,18 @@ using LinearAlgebra: Adjoint, dot, mul!, norm
using NDTensors.BlockSparseArrays:
@view!,
BlockSparseArray,
BlockSparseMatrix,
BlockSparseVector,
BlockView,
block_stored_length,
block_reshape,
block_stored_indices,
blockstype,
blocktype,
view!
using NDTensors.GPUArraysCoreExtensions: cpu
using NDTensors.SparseArrayInterface: stored_length
using NDTensors.SparseArrayDOKs: SparseArrayDOK, SparseMatrixDOK, SparseVectorDOK
using NDTensors.TensorAlgebra: contract
using Test: @test, @test_broken, @test_throws, @testset
include("TestBlockSparseArraysUtils.jl")
Expand Down Expand Up @@ -72,6 +77,71 @@ using .NDTensorsTestUtils: devices_list, is_supported_eltype
ah = adjoint(a)
@test_broken [ah[Block(Tuple(it))] for it in eachindex(block_stored_indices(ah))] isa Vector
end
@testset "Constructors" begin
# BlockSparseMatrix
bs = ([2, 3], [3, 4])
for T in (
BlockSparseArray{elt},
BlockSparseArray{elt,2},
BlockSparseMatrix{elt},
BlockSparseArray{elt,2,Matrix{elt}},
BlockSparseMatrix{elt,Matrix{elt}},
## BlockSparseArray{elt,2,Matrix{elt},SparseMatrixDOK{Matrix{elt}}}, # TODO
## BlockSparseMatrix{elt,Matrix{elt},SparseMatrixDOK{Matrix{elt}}}, # TODO
)
for args in (
bs,
(bs,),
blockedrange.(bs),
(blockedrange.(bs),),
(undef, bs),
(undef, bs...),
(undef, blockedrange.(bs)),
(undef, blockedrange.(bs)...),
)
a = T(args...)
@test eltype(a) == elt
@test blocktype(a) == Matrix{elt}
@test blockstype(a) <: SparseMatrixDOK{Matrix{elt}}
@test blocklengths.(axes(a)) == ([2, 3], [3, 4])
@test iszero(a)
@test iszero(block_stored_length(a))
@test iszero(stored_length(a))
end
end

# BlockSparseVector
bs = ([2, 3],)
for T in (
BlockSparseArray{elt},
BlockSparseArray{elt,1},
BlockSparseVector{elt},
BlockSparseArray{elt,1,Vector{elt}},
BlockSparseVector{elt,Vector{elt}},
## BlockSparseArray{elt,1,Vector{elt},SparseVectorDOK{Vector{elt}}}, # TODO
## BlockSparseVector{elt,Vector{elt},SparseVectorDOK{Vector{elt}}}, # TODO
)
for args in (
bs,
(bs,),
blockedrange.(bs),
(blockedrange.(bs),),
(undef, bs),
(undef, bs...),
(undef, blockedrange.(bs)),
(undef, blockedrange.(bs)...),
)
a = T(args...)
@test eltype(a) == elt
@test blocktype(a) == Vector{elt}
@test blockstype(a) <: SparseVectorDOK{Vector{elt}}
@test blocklengths.(axes(a)) == ([2, 3],)
@test iszero(a)
@test iszero(block_stored_length(a))
@test iszero(stored_length(a))
end
end
end
@testset "Basics" begin
a = dev(BlockSparseArray{elt}([2, 3], [2, 3]))
@allowscalar @test a == dev(
Expand Down

0 comments on commit 43e2b41

Please sign in to comment.