Skip to content

Commit

Permalink
Rel 0.3.0 - Fixed cmdline() and stan_run() warnings
Browse files Browse the repository at this point in the history
  • Loading branch information
goedman committed Jan 28, 2024
1 parent 35b3d11 commit 4b42ea8
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 8 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "StanPathfinder"
uuid = "e8ee4b5e-54b2-4408-8575-c3c89e582a15"
authors = ["Rob J Goedman <[email protected]>"]
version = "0.2.0"
version = "0.3.0"

[deps]
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
Expand Down
5 changes: 4 additions & 1 deletion examples/Bernoulli/bernoulli.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ data = Dict("N" => 10, "y" => [0, 1, 0, 1, 0, 0, 0, 0, 0, 1])
# Keep tmpdir across multiple runs to prevent re-compilation
tmpdir = joinpath(@__DIR__, "tmp")

sm = PathfinderModel("bernoulli", bernoulli_model, tmpdir)
sm = PathfinderModel("bernoulli", bernoulli_model)
rc = stan_pathfinder(sm; data, seed=rand(1:200000000, 1)[1], num_chains=2)

if all(success.(rc))
Expand All @@ -37,3 +37,6 @@ if all(success.(rc))
display(profile_df)

end

sm2 = PathfinderModel("bernoulli2", bernoulli_model, tmpdir)
rc2 = stan_pathfinder(sm2; data, seed=rand(1:200000000, 1)[1], num_chains=2)
2 changes: 1 addition & 1 deletion src/stanrun/cmdline.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ function cmdline(m::PathfinderModel, id)
cmd = `$cmd save_single_paths=$(m.save_single_paths)`
cmd = `$cmd max_lbfgs_iters=$(m.max_lbfgs_iters)`
cmd = `$cmd num_draws=$(m.num_draws)`
cmd = `$cmd num_draws=$(m.num_elbo_draws)`
cmd = `$cmd num_elbo_draws=$(m.num_elbo_draws)`

cmd = `$cmd id=$(id)`

Expand Down
5 changes: 2 additions & 3 deletions src/stanrun/stan_run.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,8 @@ function stan_run(m::PathfinderModel, use_json=true; kwargs...)

handle_keywords!(m, kwargs)

if m.num_chains > 1 || m.num_threads > 1
@info "Currently running StanPathfinder with either \
num_chains>1 or num_threads>1 can lead to problematic results."
if m.num_threads > 1
@info "Currently running StanPathfinder with num_threads>1 can lead to problematic results."
end

setup_profiles(m, m.num_chains)
Expand Down
11 changes: 9 additions & 2 deletions src/stansamples/read_pathfinder.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,26 @@ read_pathfinder(m::PathfinderModelodel)
"""
function read_pathfinder(m::PathfinderModel)

local a3d, index, idx, indvec
local a3d, idx, indvec

ftype = "chain"

for i in 1:m.num_chains
if isfile("$(m.output_base)_$(ftype)_$(i).csv")

file = "$(m.output_base)_$(ftype)_$(i).csv"
if !(isfile(file))
println("Found file $file not found!")
end

if isfile(file)
instream = open("$(m.output_base)_$(ftype)_$(i).csv")
skipchars(isspace, instream, linecomment='#')
line = Unicode.normalize(readline(instream), newline2lf=true)
idx = split(strip(line), ",")
index = [idx[k] for k in 1:length(idx)]
indvec = 1:length(index)
if i == 1
cnames = convert.(String, idx[indvec])
a3d = fill(0.0, m.num_draws, length(indvec), m.num_chains)
end
skipchars(isspace, instream, linecomment='#')
Expand Down

0 comments on commit 4b42ea8

Please sign in to comment.