Skip to content

Commit

Permalink
replaced Woodbury with specialized GradientKernelElement, decreasing …
Browse files Browse the repository at this point in the history
…memory allocations drastically
  • Loading branch information
SebastianAment committed Apr 28, 2022
1 parent c420c1c commit 53f48d5
Show file tree
Hide file tree
Showing 13 changed files with 244 additions and 130 deletions.
28 changes: 14 additions & 14 deletions Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ version = "5.0.7"

[[deps.ArrayLayouts]]
deps = ["FillArrays", "LinearAlgebra", "SparseArrays"]
git-tree-sha1 = "8b921542ad44cba67f1487e2226446597e0a90af"
git-tree-sha1 = "c23473c60476e62579c077534b9643ec400f792b"
uuid = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
version = "0.8.5"
version = "0.8.6"

[[deps.Artifacts]]
uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
Expand Down Expand Up @@ -176,9 +176,9 @@ version = "1.0.3"

[[deps.DiffRules]]
deps = ["IrrationalConstants", "LogExpFunctions", "NaNMath", "Random", "SpecialFunctions"]
git-tree-sha1 = "dd933c4ef7b4c270aacd4eb88fa64c147492acf0"
git-tree-sha1 = "28d605d9a0ac17118fe2c5e9ce0fbb76c3ceb120"
uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
version = "1.10.0"
version = "1.11.0"

[[deps.Distances]]
deps = ["LinearAlgebra", "SparseArrays", "Statistics", "StatsAPI"]
Expand Down Expand Up @@ -237,9 +237,9 @@ version = "0.1.1"

[[deps.ForwardDiff]]
deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions", "StaticArrays"]
git-tree-sha1 = "40d1546a45abd63490569695a86a2d93c2021e54"
git-tree-sha1 = "34e6147e7686a101c245f12dba43b743c7afda96"
uuid = "f6369f11-7733-5829-9624-2563aa707210"
version = "0.10.26"
version = "0.10.27"

[[deps.FunctionWrappers]]
git-tree-sha1 = "241552bc2209f0fa068b6415b1942cc0aa486bcc"
Expand Down Expand Up @@ -396,9 +396,9 @@ uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[[deps.LogExpFunctions]]
deps = ["ChainRulesCore", "ChangesOfVariables", "DocStringExtensions", "InverseFunctions", "IrrationalConstants", "LinearAlgebra"]
git-tree-sha1 = "a970d55c2ad8084ca317a4658ba6ce99b7523571"
git-tree-sha1 = "44a7b7bb7dd1afe12bac119df6a7e540fa2c96bc"
uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
version = "0.3.12"
version = "0.3.13"

[[deps.Logging]]
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
Expand All @@ -410,9 +410,9 @@ version = "0.4.10"

[[deps.MLUtils]]
deps = ["ChainRulesCore", "DelimitedFiles", "FLoops", "FoldsThreads", "Random", "ShowCases", "Statistics", "StatsBase"]
git-tree-sha1 = "32eeb46fa393ae36a4127c9442ade478c8d01117"
git-tree-sha1 = "202617a5a49a8b5f3b4abf96621f2519b1592c74"
uuid = "f1d291b0-491e-4a28-83b9-f70985020b54"
version = "0.2.3"
version = "0.2.4"

[[deps.MPC_jll]]
deps = ["Artifacts", "GMP_jll", "JLLWrappers", "Libdl", "MPFR_jll", "Pkg"]
Expand Down Expand Up @@ -533,9 +533,9 @@ version = "1.4.1"

[[deps.Parsers]]
deps = ["Dates"]
git-tree-sha1 = "3b429f37de37f1fc603cc1de4a799dc7fbe4c0b6"
git-tree-sha1 = "1285416549ccfcdf0c50d4997a94331e88d68413"
uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0"
version = "2.3.0"
version = "2.3.1"

[[deps.Pkg]]
deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"]
Expand Down Expand Up @@ -658,9 +658,9 @@ version = "0.1.14"

[[deps.Static]]
deps = ["IfElse"]
git-tree-sha1 = "2114b1d8517764a8c4625a2e97f40640c7a301a7"
git-tree-sha1 = "b1f1f60bf4f25d8b374480fb78c7b9785edf95fd"
uuid = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
version = "0.6.1"
version = "0.6.2"

[[deps.StaticArrays]]
deps = ["LinearAlgebra", "Random", "Statistics"]
Expand Down
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ BesselK = "0.3"
BlockFactorizations = "1.2.1"
DiffResults = "1.0"
FillArrays = "0.12, 0.13"
Flux = "0.13"
ForwardDiff = "0.10"
Functors = "0.2"
IterativeSolvers = "0.9"
KroneckerProducts = "1.0"
LazyArrays = "0.22"
Expand Down
Binary file removed src/.DS_Store
Binary file not shown.
32 changes: 16 additions & 16 deletions src/barneshut.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ function BarnesHutFactorization(k, x, y = x, D = nothing; θ::Real = 1/4, leafsi
# w = zeros(length(m))
# i = zeros(Bool, m)
# WT, BT = typeof(w), typeof(i)
T = gramian_eltype(k, xs, ys)
T = gramian_eltype(k, xs[1], ys[1])
BarnesHutFactorization{T, KT, XT, YT, TT, DT, RT}(k, xs, ys, Tree, D, θ) #, w, i)
end
function BarnesHutFactorization(G::Gramian, θ::Real = 1/2; leafsize::Int = BARNES_HUT_DEFAULT_LEAFSIZE)
Expand All @@ -49,7 +49,7 @@ function LinearAlgebra.mul!(y::AbstractVector, F::BarnesHutFactorization, x::Abs
taylor!(y, F, x, α, β)
end
end
function Base.:*(F::BarnesHutFactorization, x::AbstractVector)
function Base.:*(F::BarnesHutFactorization{<:Number}, x::AbstractVector{<:Number})
T = promote_type(eltype(F), eltype(x))
y = zeros(T, size(F, 1))
mul!(y, F, x)
Expand Down Expand Up @@ -148,45 +148,45 @@ end

############################# centers of mass ##################################
# this is a weighted sum, could be generalized to incorporate node_sums
function compute_centers_of_mass(x::AbstractVector, w::AbstractVector, T::BallTree)
function compute_centers_of_mass(w::AbstractVector, x::AbstractVector, T::BallTree)
D = eltype(x) <: StaticVector ? length(eltype(x)) : length(x[1]) # if x is static vector
com = [zero(MVector{D, Float64}) for _ in 1:length(T.hyper_spheres)]
compute_centers_of_mass!(com, x, w, T)
compute_centers_of_mass!(com, w, x, T)
end

function compute_centers_of_mass(F::BarnesHutFactorization, w::AbstractVector)
compute_centers_of_mass(F.y, w, F.Tree)
compute_centers_of_mass(w, F.y, F.Tree)
end

function compute_centers_of_mass!(com::AbstractVector, x::AbstractVector, w::AbstractVector, T::BallTree)
function compute_centers_of_mass!(com::AbstractVector, w::AbstractVector, x::AbstractVector, T::BallTree)
abs_w = abs.(w)
weighted_node_sums!(com, x, abs_w, T)
weighted_node_sums!(com, abs_w, x, T)
sum_w = node_sums(abs_w, T)
ε = eps(eltype(w)) # ensuring division by zero it not a problem
@. com ./= sum_w + ε
end

node_sums(x::AbstractVector, T::BallTree) = weighted_node_sums(x, Ones(length(x)), T)
node_sums(x::AbstractVector, T::BallTree) = weighted_node_sums(Ones(length(x)), x, T)
function node_sums!(sums, x::AbstractVector, T::BallTree)
weighted_node_sums!(sums, x, Ones(length(x)), T)
weighted_node_sums!(sums, Ones(length(x)), x, T)
end

function weighted_node_sums(x::AbstractVector, w::AbstractVector, T::BallTree, index::Int = 1)
function weighted_node_sums(w::AbstractVector, x::AbstractVector, T::BallTree, index::Int = 1)
length(x) == 0 && return zero(eltype(x))
sums = zeros(typeof(w[1]'x[1]), length(T.hyper_spheres))
weighted_node_sums!(sums, x, w, T)
sums = fill(zero(w[1]'x[1]), length(T.hyper_spheres))
weighted_node_sums!(sums, w, x, T)
end

# NOTE: x should either be vector of numbers or vector of static arrays
function weighted_node_sums!(sums::AbstractVector, x::AbstractVector,
w::AbstractVector{<:Number}, T::BallTree, index::Int = 1)
function weighted_node_sums!(sums::AbstractVector, w::AbstractVector,
x::AbstractVector, T::BallTree, index::Int = 1)
if isleaf(T.tree_data.n_internal_nodes, index)
i = get_leaf_range(T.tree_data, index)
wi, xi = @views w[T.indices[i]], x[T.indices[i]]
sums[index] = wi'xi
else
task = @spawn weighted_node_sums!(sums, x, w, T, getleft(index))
weighted_node_sums!(sums, x, w, T, getright(index))
task = @spawn weighted_node_sums!(sums, w, x, T, getleft(index))
weighted_node_sums!(sums, w, x, T, getright(index))
wait(task)
sums[index] = sums[getleft(index)] + sums[getright(index)]
end
Expand Down
Loading

0 comments on commit 53f48d5

Please sign in to comment.