This repository has been archived by the owner on Sep 10, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 17
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
14 changed files
with
1,575 additions
and
251 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
name = "TuringModels" | ||
uuid = "ead7e11d-4ba5-55c3-9d74-177ea73ef1fd" | ||
authors = ["StatisticalRethinking <[email protected]>"] | ||
version = "0.2.0" | ||
version = "0.3.0" | ||
|
||
[deps] | ||
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,240 @@ | ||
|
||
### m8.1stan | ||
|
||
m8.1stan is the first model in the Statistical Rethinking book (pp. 249) using Stan. | ||
|
||
Here we will use Turing's NUTS support, which is currently (2018) the originalNUTS by [Hoffman & Gelman]( http://www.stat.columbia.edu/~gelman/research/published/nuts.pdf) and not the one that's in Stan 2.18.2, i.e., Appendix A.5 in: https://arxiv.org/abs/1701.02434 | ||
|
||
The TuringModels pkg imports modules such as CSV and DataFrames | ||
|
||
|
||
```julia | ||
using TuringModels | ||
|
||
Turing.setadbackend(:reverse_diff); | ||
Turing.turnprogress(false) | ||
``` | ||
|
||
loaded | ||
|
||
|
||
┌ Info: [Turing]: global PROGRESS is set as false | ||
└ @ Turing /Users/rob/.julia/packages/Turing/r03H1/src/Turing.jl:24 | ||
|
||
|
||
|
||
|
||
|
||
false | ||
|
||
|
||
|
||
Read in the `rugged` data as a DataFrame | ||
|
||
|
||
```julia | ||
d = CSV.read(rel_path("..", "data", "rugged.csv"), delim=';'); | ||
``` | ||
|
||
Show size of the DataFrame (should be 234x51) | ||
|
||
|
||
```julia | ||
size(d) | ||
``` | ||
|
||
|
||
|
||
|
||
(234, 51) | ||
|
||
|
||
|
||
Apply log() to each element in rgdppc_2000 column and add it as a new column | ||
|
||
|
||
```julia | ||
d = hcat(d, map(log, d[Symbol("rgdppc_2000")])); | ||
``` | ||
|
||
Rename our col x1 => log_gdp | ||
|
||
|
||
```julia | ||
rename!(d, :x1 => :log_gdp); | ||
``` | ||
|
||
Now we need to drop every row where rgdppc_2000 == missing | ||
|
||
When this (https://github.com/JuliaData/DataFrames.jl/pull/1546) hits DataFrame it'll be conceptually easier: i.e., completecases!(d, :rgdppc_2000) | ||
|
||
|
||
```julia | ||
notisnan(e) = !ismissing(e) | ||
dd = d[map(notisnan, d[:rgdppc_2000]), :]; | ||
``` | ||
|
||
Updated DataFrame dd size (should equal 170 x 52) | ||
|
||
|
||
```julia | ||
size(dd) | ||
``` | ||
|
||
|
||
|
||
|
||
(170, 52) | ||
|
||
|
||
|
||
Define the Turing model | ||
|
||
|
||
```julia | ||
@model m8_1stan(y, x₁, x₂) = begin | ||
σ ~ Truncated(Cauchy(0, 2), 0, Inf) | ||
βR ~ Normal(0, 10) | ||
βA ~ Normal(0, 10) | ||
βAR ~ Normal(0, 10) | ||
α ~ Normal(0, 100) | ||
|
||
for i ∈ 1:length(y) | ||
y[i] ~ Normal(α + βR * x₁[i] + βA * x₂[i] + βAR * x₁[i] * x₂[i], σ) | ||
end | ||
end; | ||
``` | ||
|
||
Test to see that the model is sane. Use 2000 for now, as in the book. | ||
Need to set the same stepsize and adapt_delta as in Stan... | ||
|
||
Use Turing mcmc | ||
|
||
|
||
```julia | ||
posterior = sample(m8_1stan(dd[:log_gdp], dd[:rugged], dd[:cont_africa]), | ||
Turing.NUTS(2000, 1000, 0.95)); | ||
``` | ||
|
||
┌ Info: [Turing] looking for good initial eps... | ||
└ @ Turing.Inference /Users/rob/.julia/packages/Turing/r03H1/src/inference/support/hmc_core.jl:240 | ||
[NUTS{Turing.Core.FluxTrackerAD,Union{}}] found initial ϵ: 0.1 | ||
└ @ Turing.Inference /Users/rob/.julia/packages/Turing/r03H1/src/inference/support/hmc_core.jl:235 | ||
┌ Info: Adapted ϵ = 0.026392788491123437, std = [1.0, 1.0, 1.0, 1.0, 1.0]; 1000 iterations is used for adaption. | ||
└ @ Turing.Inference /Users/rob/.julia/packages/Turing/r03H1/src/inference/adapt/adapt.jl:91 | ||
|
||
|
||
[NUTS] Finished with | ||
Running time = 223.3260227410002; | ||
#lf / sample = 0.0; | ||
#evals / sample = 46.4005; | ||
pre-cond. metric = [1.0, 1.0, 1.0, 1.0, 1.0]. | ||
|
||
|
||
Fix the inclusion of adaptation samples | ||
|
||
|
||
```julia | ||
draws = 1001:2000 | ||
posterior2 = Chains(posterior[draws,:,:], :parameters) | ||
``` | ||
|
||
|
||
|
||
|
||
Object of type Chains, with data of type 1000×5×1 Array{Union{Missing, Float64},3} | ||
|
||
Log evidence = 0.0 | ||
Iterations = 1001:2000 | ||
Thinning interval = 1 | ||
Chains = Chain1 | ||
Samples per chain = 1000 | ||
parameters = σ, βAR, βA, α, βR | ||
|
||
parameters | ||
Mean SD Naive SE MCSE ESS | ||
α 9.2258 0.1457 0.0046 0.0097 224.9403 | ||
βA -1.9580 0.2372 0.0075 0.0164 209.5267 | ||
βAR 0.3969 0.1330 0.0042 0.0100 175.8492 | ||
βR -0.2038 0.0797 0.0025 0.0055 212.8302 | ||
σ 0.9500 0.0509 0.0016 0.0017 916.8077 | ||
|
||
|
||
|
||
|
||
|
||
Example of a Turing run simulation output | ||
|
||
Here's the ulam() output from rethinking | ||
|
||
|
||
```julia | ||
m8_1s_cmdstan = " | ||
Iterations = 1:1000 | ||
Thinning interval = 1 | ||
Chains = 1,2,3,4 | ||
Samples per chain = 1000 | ||
Empirical Posterior Estimates: | ||
Mean SD Naive SE MCSE ESS | ||
a 9.22360053 0.139119116 0.0021996664 0.0034632816 1000 | ||
bR -0.20196346 0.076106388 0.0012033477 0.0018370185 1000 | ||
bA -1.94430980 0.227080488 0.0035904578 0.0057840746 1000 | ||
bAR 0.39071684 0.131889143 0.0020853505 0.0032749642 1000 | ||
sigma 0.95036370 0.052161768 0.0008247500 0.0009204073 1000 | ||
Quantiles: | ||
2.5% 25.0% 50.0% 75.0% 97.5% | ||
a 8.95307475 9.12719750 9.2237750 9.31974000 9.490234250 | ||
bR -0.35217930 -0.25334425 -0.2012855 -0.15124725 -0.054216855 | ||
bA -2.39010825 -2.09894500 -1.9432550 -1.78643000 -1.513974250 | ||
bAR 0.13496995 0.30095575 0.3916590 0.47887625 0.650244475 | ||
sigma 0.85376115 0.91363250 0.9484920 0.98405750 1.058573750 | ||
"; | ||
``` | ||
|
||
Describe the posterior samples | ||
|
||
|
||
```julia | ||
describe(posterior2) | ||
``` | ||
|
||
Log evidence = 0.0 | ||
Iterations = 1001:2000 | ||
Thinning interval = 1 | ||
Chains = Chain1 | ||
Samples per chain = 1000 | ||
parameters = σ, βAR, βA, α, βR | ||
|
||
|
||
|
||
┌ Warning: `quantile(v::AbstractArray{<:Real})` is deprecated, use `quantile(v, [0.0, 0.25, 0.5, 0.75, 1.0])` instead. | ||
│ caller = (::getfield(MCMCChains, Symbol("##102#104")){Chains{Union{Missing, Float64},Float64,NamedTuple{(:parameters,),Tuple{Array{String,1}}},NamedTuple{(:hashedsummary,),Tuple{Base.RefValue{Tuple{UInt64,MCMCChains.ChainSummaries}}}}}})(::String) at none:0 | ||
└ @ MCMCChains ./none:0 | ||
|
||
|
||
[36m[1mEmpirical Posterior Estimates[22m[39m | ||
─────────────────────────────────────────── | ||
parameters | ||
Mean SD Naive SE MCSE ESS | ||
α 9.2258 0.1457 0.0046 0.0097 224.9403 | ||
βA -1.9580 0.2372 0.0075 0.0164 209.5267 | ||
βAR 0.3969 0.1330 0.0042 0.0100 175.8492 | ||
βR -0.2038 0.0797 0.0025 0.0055 212.8302 | ||
σ 0.9500 0.0509 0.0016 0.0017 916.8077 | ||
|
||
[36m[1mQuantiles[22m[39m | ||
─────────────────────────────────────────── | ||
parameters | ||
2.5% 25.0% 50.0% 75.0% 97.5% | ||
α 8.6704 9.1249 9.2246 9.3219 9.7520 | ||
βA -2.6922 -2.1177 -1.9737 -1.7884 -1.2476 | ||
βAR -0.0845 0.3022 0.3964 0.4840 0.8435 | ||
βR -0.4696 -0.2584 -0.2024 -0.1488 0.0564 | ||
σ 0.7990 0.9121 0.9494 0.9842 1.1318 | ||
|
||
|
||
|
||
end of 08/m8.1t.jl#- | ||
*This notebook was generated using [Literate.jl](https://github.com/fredrikekre/Literate.jl).* |
Oops, something went wrong.