Skip to content

Commit

Permalink
Implement collect_similar like collect for DiskGenerators (#198)
Browse files Browse the repository at this point in the history
* Implement `collect_similar` like `collect` for DiskGenerators

* Add a test
  • Loading branch information
asinghvi17 authored Oct 23, 2024
1 parent c41b193 commit 1f1f075
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 0 deletions.
25 changes: 25 additions & 0 deletions src/generator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,31 @@ function Base.collect(itr::DiskGenerator{<:AbstractArray{<:Any,N}}) where {N}
return dest
end

# Warning: this is not public API!
function Base.collect_similar(A::AbstractArray, itr::DiskGenerator{<:AbstractArray{<:Any,N}}) where {N}
y = iterate(itr)
shp = axes(itr.iter)
if y === nothing
et = Base.@default_eltype(itr)
return similar(A, et, shp)
end
v1, st = y
dest = similar(A, typeof(v1), shp)
i = y
for I in eachindex(itr.iter)
if i isa Nothing # Mainly to keep JET clean
error(
"Should not be reached: iterator is shorter than its `eachindex` iterator"
)
else
dest[I] = first(i)
i = iterate(itr, last(i))
end
end
return dest

end

macro implement_generator(t)
t = esc(t)
quote
Expand Down
14 changes: 14 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -953,3 +953,17 @@ end
@test getindex_count(A) == 0
end

@testset "Map over indices correctly" begin
# This is a regression test for issue #144
# `map` should always work over the correct indices,
# especially since we overload generators to `DiskArrayGenerator`.

data = [i+j for i in 1:200, j in 1:100]
da = AccessCountDiskArray(data, chunksize=(10,10))
@test map(identity, da) == data
@test all(map(identity, da) .== data)

# Make sure that type inference works
@inferred Matrix{Int} map(identity, da)
@inferred Matrix{Float64} map(x -> x * 5.0, da)
end

0 comments on commit 1f1f075

Please sign in to comment.