From e97ab9b42760d33a6c88796c6541c1c8dcf684e2 Mon Sep 17 00:00:00 2001 From: Anshul Singhvi Date: Tue, 22 Oct 2024 17:17:52 -0700 Subject: [PATCH 1/2] Implement `collect_similar` like `collect` for DiskGenerators --- src/generator.jl | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/src/generator.jl b/src/generator.jl index be47e62..8619deb 100644 --- a/src/generator.jl +++ b/src/generator.jl @@ -47,6 +47,31 @@ function Base.collect(itr::DiskGenerator{<:AbstractArray{<:Any,N}}) where {N} return dest end +# Warning: this is not public API! +function Base.collect_similar(A::AbstractArray, itr::DiskGenerator{<:AbstractArray{<:Any,N}}) where {N} + y = iterate(itr) + shp = axes(itr.iter) + if y === nothing + et = Base.@default_eltype(itr) + return similar(A, et, shp) + end + v1, st = y + dest = similar(A, typeof(v1), shp) + i = y + for I in eachindex(itr.iter) + if i isa Nothing # Mainly to keep JET clean + error( + "Should not be reached: iterator is shorter than its `eachindex` iterator" + ) + else + dest[I] = first(i) + i = iterate(itr, last(i)) + end + end + return dest + +end + macro implement_generator(t) t = esc(t) quote From 9a33e9afc33904484ffe846560db578ac867e2ac Mon Sep 17 00:00:00 2001 From: Anshul Singhvi Date: Tue, 22 Oct 2024 17:26:01 -0700 Subject: [PATCH 2/2] Add a test --- test/runtests.jl | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/test/runtests.jl b/test/runtests.jl index db8dd1c..5c9a487 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -953,3 +953,17 @@ end @test getindex_count(A) == 0 end +@testset "Map over indices correctly" begin + # This is a regression test for issue #144 + # `map` should always work over the correct indices, + # especially since we overload generators to `DiskArrayGenerator`. + + data = [i+j for i in 1:200, j in 1:100] + da = AccessCountDiskArray(data, chunksize=(10,10)) + @test map(identity, da) == data + @test all(map(identity, da) .== data) + + # Make sure that type inference works + @inferred Matrix{Int} map(identity, da) + @inferred Matrix{Float64} map(x -> x * 5.0, da) +end