Skip to content

Commit

Permalink
Avoid recomputing all medians for every parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
sethaxen committed Feb 11, 2024
1 parent 0fad7f9 commit 3579c55
Showing 1 changed file with 15 additions and 13 deletions.
28 changes: 15 additions & 13 deletions src/plot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -206,16 +206,10 @@ function _compute_plot_data(
show_hdii = true,
fill_q = true,
fill_hdi = false,
ordered = false
)
hdii = sort(hdi_prob; rev=true)

chain_dic = Dict(zip(quantile(chains)[2], quantile(chains)[5]))
sorted_chain = sort(collect(zip(values(chain_dic), keys(chain_dic))))
sorted_par = [sorted_chain[i][2] for i in 1:length(par_names)]
par = (ordered ? sorted_par : par_names)
hdii = sort(hdi_prob)

chain_sections = MCMCChains.group(chains, Symbol(par[i]))
chain_sections = MCMCChains.group(chains, Symbol(par_names[i]))
chain_vec = vec(chain_sections.value.data)
lower_hdi = [MCMCChains.hdi(chain_sections, prob = hdii[j])[:lower]
for j in 1:length(hdii)]
Expand All @@ -239,7 +233,7 @@ function _compute_plot_data(
min = minimum(k_density.density .+ h)
q_int = (show_qi ? [qs[1], chain_med, qs[2]] : [chain_med])

return par, hdii, lower_hdi, upper_hdi, h, qs, k_density, x_int, val, chain_med,
return par_names, hdii, lower_hdi, upper_hdi, h, qs, k_density, x_int, val, chain_med,
chain_mean, min, q_int
end

Expand All @@ -261,12 +255,16 @@ end
chn = p.args[1]
par_names = p.args[2]

if ordered
par_table_names, par_medians = summarize(chn[:, par_names, :], median)
par_names = par_table_names[sortperm(par_medians)]
end

for i in 1:length(par_names)
par, hdii, lower_hdi, upper_hdi, h, qs, k_density, x_int, val, chain_med, chain_mean,
min, q_int = _compute_plot_data(i, chn, par_names; hdi_prob = hdi_prob, q = q,
spacer = spacer, _riser = _riser, show_mean = show_mean, show_median = show_median,
show_qi = show_qi, show_hdii = show_hdii, fill_q = fill_q, fill_hdi = fill_hdi,
ordered = ordered)
show_qi = show_qi, show_hdii = show_hdii, fill_q = fill_q, fill_hdi = fill_hdi)

yticks --> (length(par_names) > 1 ?
(_riser .+ ((1:length(par_names)) .- 1) .* spacer, string.(par)) : :default)
Expand Down Expand Up @@ -350,12 +348,16 @@ end
chn = p.args[1]
par_names = p.args[2]

if ordered
par_table_names, par_medians = summarize(chn[:, par_names, :], median)
par_names = par_table_names[sortperm(par_medians)]
end

for i in 1:length(par_names)
par, hdii, lower_hdi, upper_hdi, h, qs, k_density, x_int, val, chain_med, chain_mean,
min, q_int = _compute_plot_data(i, chn, par_names; hdi_prob = hdi_prob, q = q,
spacer = spacer, _riser = _riser, show_mean = show_mean, show_median = show_median,
show_qi = show_qi, show_hdii = show_hdii, fill_q = fill_q, fill_hdi = fill_hdi,
ordered = ordered)
show_qi = show_qi, show_hdii = show_hdii, fill_q = fill_q, fill_hdi = fill_hdi)

yticks --> (length(par_names) > 1 ?
(_riser .+ ((1:length(par_names)) .- 1) .* spacer, string.(par)) : :default)
Expand Down

0 comments on commit 3579c55

Please sign in to comment.