Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SparseArrayDOKs] Add setindex_maybe_grow! and macro @maybe_grow #1434

Merged
merged 10 commits into from
May 14, 2024
6 changes: 4 additions & 2 deletions NDTensors/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
HalfIntegers = "f0d1745a-41c9-11e9-1dd9-e5d34d218721"
InlineStrings = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
mtfishman marked this conversation as resolved.
Show resolved Hide resolved
MappedArrays = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900"
PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand All @@ -34,20 +35,20 @@ VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8"
[weakdeps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
Octavian = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4"
TBLIS = "48530278-0828-4a49-9772-0f3830dfa1e9"
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"

[extensions]
NDTensorsAMDGPUExt = "AMDGPU"
NDTensorsCUDAExt = "CUDA"
NDTensorscuTENSORExt = "cuTENSOR"
NDTensorsHDF5Ext = "HDF5"
NDTensorsMetalExt = "Metal"
NDTensorsOctavianExt = "Octavian"
NDTensorsTBLISExt = "TBLIS"
NDTensorscuTENSORExt = "cuTENSOR"

[compat]
Accessors = "0.1.33"
Expand All @@ -65,6 +66,7 @@ HDF5 = "0.14, 0.15, 0.16, 0.17"
HalfIntegers = "1"
InlineStrings = "1"
LinearAlgebra = "1.6"
MacroTools = "0.5"
mtfishman marked this conversation as resolved.
Show resolved Hide resolved
MappedArrays = "0.4"
PackageExtensionCompat = "1"
Random = "1.6"
Expand Down
46 changes: 44 additions & 2 deletions NDTensors/src/lib/SparseArrayDOKs/src/sparsearraydok.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,24 @@
using Accessors: @set
using Dictionaries: Dictionary, set!
using MacroTools: @capture
using ..SparseArrayInterface:
SparseArrayInterface, AbstractSparseArray, getindex_zero_function

# TODO: Parametrize by `data`?
struct SparseArrayDOK{T,N,Zero} <: AbstractSparseArray{T,N}
data::Dictionary{CartesianIndex{N},T}
dims::NTuple{N,Int}
dims::Ref{NTuple{N,Int}}
zero::Zero
function SparseArrayDOK{T,N,Zero}(data, dims::NTuple{N,Int}, zero) where {T,N,Zero}
return new{T,N,Zero}(data, Ref(dims), zero)
end
end

# Constructors
function SparseArrayDOK(data, dims::Tuple{Vararg{Int}}, zero)
return SparseArrayDOK{eltype(data),length(dims),typeof(zero)}(data, dims, zero)
end

function SparseArrayDOK{T,N,Zero}(dims::Tuple{Vararg{Int}}, zero) where {T,N,Zero}
return SparseArrayDOK{T,N,Zero}(default_data(T, N), dims, zero)
end
Expand Down Expand Up @@ -72,7 +80,7 @@ function SparseArrayDOK{T}(::UndefInitializer, dims::Tuple{Vararg{Int}}, zero) w
end

# Base `AbstractArray` interface
Base.size(a::SparseArrayDOK) = a.dims
Base.size(a::SparseArrayDOK) = a.dims[]

SparseArrayInterface.getindex_zero_function(a::SparseArrayDOK) = a.zero
function SparseArrayInterface.set_getindex_zero_function(a::SparseArrayDOK, f)
Expand Down Expand Up @@ -104,3 +112,37 @@ SparseArrayDOK{T}(a::AbstractArray) where {T} = SparseArrayDOK{T,ndims(a)}(a)
function SparseArrayDOK{T,N}(a::AbstractArray) where {T,N}
return SparseArrayInterface.sparse_convert(SparseArrayDOK{T,N}, a)
end

function Base.resize!(a::SparseArrayDOK{<:Any,N}, new_size::NTuple{N,Integer}) where {N}
a.dims[] = new_size
return a
end

function setindex_maybe_grow!(a::SparseArrayDOK{<:Any,N}, value, I::Vararg{Int,N}) where {N}
if any(I .> size(a))
resize!(a, max.(I, size(a)))
end
a[I...] = value
return a
end

function is_setindex!_expr(expr::Expr)
return is_assignment_expr(expr) && is_getindex_expr(first(expr.args))
end
is_setindex!_expr(x) = false

is_getindex_expr(expr::Expr) = (expr.head === :ref)
is_getindex_expr(x) = false

is_assignment_expr(expr::Expr) = (expr.head === :(=))
is_assignment_expr(expr) = false

macro maybe_grow(expr)
if !is_setindex!_expr(expr)
error(
"@maybe_grow must be used with setindex! syntax (as @maybe_grow a[i,j,...] = value)"
)
end
@capture(expr, array_[indices__] = value_)
return :(setindex_maybe_grow!($(esc(array)), $value, $indices...))
end
34 changes: 33 additions & 1 deletion NDTensors/src/lib/SparseArrayDOKs/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
# Custom zero type
# Slicing

using Dictionaries: Dictionary
using Test: @test, @testset, @test_broken
using NDTensors.SparseArrayDOKs: SparseArrayDOK, SparseMatrixDOK
using NDTensors.SparseArrayDOKs:
SparseArrayDOKs, SparseArrayDOK, SparseMatrixDOK, @maybe_grow
using NDTensors.SparseArrayInterface: storage_indices, nstored
using SparseArrays: SparseMatrixCSC, nnz
@testset "SparseArrayDOK (eltype=$elt)" for elt in
Expand Down Expand Up @@ -94,5 +96,35 @@ using SparseArrays: SparseMatrixCSC, nnz
end
end
end
@testset "Maybe Grow Feature" begin
a = SparseArrayDOK{elt,2}((0, 0))
SparseArrayDOKs.setindex_maybe_grow!(a, 230, 2, 3)
@test size(a) == (2, 3)
@test a[2, 3] == 230
# Test @maybe_grow macro
@maybe_grow a[5, 5] = 550
@test size(a) == (5, 5)
@test a[2, 3] == 230
@test a[5, 5] == 550
# Test that size remains same
# if we set at an index smaller than
# the maximum size:
@maybe_grow a[3, 4] = 340
@test size(a) == (5, 5)
@test a[2, 3] == 230
@test a[5, 5] == 550
@test a[3, 4] == 340
# Test vector case
v = SparseArrayDOK{elt,1}((0,))
@maybe_grow v[5] = 50
@test size(v) == (5,)
@test v[5] == 50
end
@testset "Test Lower Level Constructor" begin
d = Dictionary{CartesianIndex{2},elt}()
a = SparseArrayDOK(d, (2, 2), zero(elt))
a[1, 2] = 12.0
@test a[1, 2] == 12.0
end
end
end
Loading