Skip to content

Commit ea7fa7c

Browse files
Merge pull request #216 from SciML/compathelper/new_version/2021-06-03-02-42-43-659-1397580942
CompatHelper: bump compat for "Turing" to "0.16"
2 parents 1b95523 + 9ed7f80 commit ea7fa7c

File tree

7 files changed

+113
-127
lines changed

7 files changed

+113
-127
lines changed

Project.toml

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ authors = ["Vaibhavdixit02 <[email protected]>"]
44
version = "2.25.0"
55

66
[deps]
7-
ApproxBayes = "f5f396d3-230c-5e07-80e6-9fadf06146cc"
87
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
98
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
109
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
@@ -26,12 +25,12 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2625
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
2726
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
2827
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
28+
StanSample = "c1514b29-d3a0-5178-b312-660c88baa699"
2929
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
3030
TransformVariables = "84d833dd-6860-57f9-a1a7-6da5db126cff"
3131
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
3232

3333
[compat]
34-
ApproxBayes = "0.3"
3534
DiffEqBase = "6.36"
3635
DiffResults = "0.0.4, 1.0"
3736
Distances = "0.8, 0.9, 0.10"
@@ -46,14 +45,13 @@ Missings = "0.4, 1.0"
4645
ModelingToolkit = "5.6"
4746
Optim = "0.19, 0.20, 0.21, 0.22, 1.0"
4847
PDMats = "0.9, 0.10, 0.11"
49-
ParameterizedFunctions = "5"
5048
Parameters = "0.12"
5149
RecursiveArrayTools = "1,2"
5250
Reexport = "0.2, 1.0"
5351
Requires = "0.5, 1.0"
5452
StructArrays = "0.4, 0.5"
5553
TransformVariables = "0.3, 0.4"
56-
Turing = "0.12, 0.13, 0.14, 0.15"
54+
Turing = "0.12, 0.13, 0.14, 0.15, 0.16"
5755
julia = "1.3"
5856

5957
[extras]

src/DiffEqBayes.jl

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,29 +7,23 @@ using DocStringExtensions
77
using DiffEqBase, Distributions, Turing, MacroTools
88
using RecursiveArrayTools, ModelingToolkit
99
using Parameters, Distributions, Optim, Requires
10-
using Distances, ApproxBayes, DocStringExtensions, Random
10+
using Distances, DocStringExtensions, Random, StanSample
1111

1212
STANDARD_PROB_GENERATOR(prob,p) = remake(prob;u0=eltype(p).(prob.u0),p=p)
1313
STANDARD_PROB_GENERATOR(prob::EnsembleProblem,p) = EnsembleProblem(remake(prob.prob;u0=eltype(p).(prob.prob.u0),p=p))
1414

1515
include("turing_inference.jl")
16-
include("abc_inference.jl")
16+
# include("abc_inference.jl")
17+
include("stan_string.jl")
18+
include("stan_inference.jl")
1719

1820
function __init__()
19-
@require CmdStan="593b3428-ca2f-500c-ae53-031589ec8ddd" begin
20-
using .CmdStan
21-
include("stan_inference.jl")
22-
include("stan_string.jl")
23-
export stan_inference, stan_string
24-
end
25-
2621
@require DynamicHMC="bbc10e6e-7c05-544b-b16e-64fede858acb" begin
2722
using .DynamicHMC, TransformVariables, LogDensityProblems
2823
include("dynamichmc_inference.jl")
2924
export dynamichmc_inference
3025
end
3126
end
3227

33-
export turing_inference, abc_inference
34-
28+
export turing_inference, stan_inference ,abc_inference
3529
end # module

src/dynamichmc_inference.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,5 +103,5 @@ function dynamichmc_inference(problem::DiffEqBase.DEProblem, algorithm, t, data,
103103
= TransformedLogDensity(trans, P)
104104
∇ℓ = LogDensityProblems.ADgradient(AD_gradient_kind, ℓ)
105105
results = mcmc_with_warmup(rng, ∇ℓ, num_samples; mcmc_kwargs...)
106-
merge((posterior = transform.(Ref(trans), results.chain), ), results)
106+
merge((posterior = TransformVariables.transform.(Ref(trans), results.chain), ), results)
107107
end

src/stan_inference.jl

Lines changed: 40 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,25 @@
1-
struct StanModel{M,R,C,N}
1+
struct StanResult{M,R,C}
22
model::M
33
return_code::R
44
chains::C
5-
cnames::N
5+
end
6+
7+
function Base.show(io::IO, mime::MIME"text/plain", res::StanResult)
8+
show(io, mime, res.chains)
69
end
710

811
struct StanODEData
912
end
1013

1114
function generate_priors(n,priors)
1215
priors_string = ""
13-
if priors==nothing
16+
if priors===nothing
1417
for i in 1:n
15-
priors_string = string(priors_string,"theta[$i] ~ normal(0, 1)", " ; ")
18+
priors_string = string(priors_string,"theta_$i ~ normal(0, 1)", " ; ")
1619
end
1720
else
1821
for i in 1:n
19-
priors_string = string(priors_string,"theta[$i] ~",stan_string(priors[i]),";")
22+
priors_string = string(priors_string,"theta_$i ~ ",stan_string(priors[i]),";")
2023
end
2124
end
2225
priors_string
@@ -34,13 +37,13 @@ function generate_theta(n,priors)
3437
lower_bound = string("lower=",minimum(priors[i]))
3538
end
3639
if lower_bound != "" && upper_bound != ""
37-
theta = string(theta,"real","<$lower_bound",",","$upper_bound>"," theta$i",";")
40+
theta = string(theta,"real","<$lower_bound",",","$upper_bound>"," theta_$i",";")
3841
elseif lower_bound != ""
39-
theta = string(theta,"real","<$lower_bound",">"," theta$i",";")
42+
theta = string(theta,"real","<$lower_bound",">"," theta_$i",";")
4043
elseif upper_bound != ""
41-
theta = string(theta,"real","<","$upper_bound>"," theta$i",";")
44+
theta = string(theta,"real","<","$upper_bound>"," theta_$i",";")
4245
else
43-
theta = string(theta,"real"," theta$i",";")
46+
theta = string(theta,"real"," theta_$i",";")
4447
end
4548
end
4649
return theta
@@ -50,9 +53,10 @@ function stan_inference(prob::DiffEqBase.DEProblem,t,data,priors = nothing,
5053
stanmodel = nothing;alg=:rk45,
5154
num_samples=1000, num_warmup=1000, reltol=1e-3,
5255
abstol=1e-6, maxiter=Int(1e5),likelihood=Normal,
53-
vars=(StanODEData(),InverseGamma(3,3)),nchains=1,
54-
sample_u0 = false, save_idxs = nothing, diffeq_string = nothing, printsummary = true)
55-
56+
vars=(StanODEData(),InverseGamma(3,3)),nchains=[1],
57+
sample_u0 = false, save_idxs = nothing, diffeq_string = nothing,
58+
printsummary = true, output_format = :mcmcchains)
59+
5660
save_idxs !== nothing && length(save_idxs) == 1 ? save_idxs = save_idxs[1] : save_idxs = save_idxs
5761
length_of_y = length(prob.u0)
5862
save_idxs = something(save_idxs, 1:length_of_y)
@@ -63,24 +67,26 @@ function stan_inference(prob::DiffEqBase.DEProblem,t,data,priors = nothing,
6367
else
6468
length_of_parameter = length(prob.p) + sample_u0 * length(save_idxs)
6569
end
66-
70+
6771
if stanmodel === nothing
6872
if alg ==:adams
69-
algorithm = "integrate_ode_adams"
73+
algorithm = "ode_adams_tol"
7074
elseif alg ==:rk45
71-
algorithm = "integrate_ode_rk45"
75+
algorithm = "ode_rk45_tol"
7276
elseif alg == :bdf
73-
algorithm = "integrate_ode_bdf"
77+
algorithm = "ode_bdf_tol"
7478
else
7579
error("The choices for alg are :adams, :rk45, or :bdf")
7680
end
7781
hyper_params = ""
7882
tuple_hyper_params = ""
7983
setup_params = ""
8084
thetas = ""
85+
theta_names = ""
8186
theta_string = generate_theta(length_of_parameter,priors)
8287
for i in 1:length_of_parameter
83-
thetas = string(thetas,"theta[$i] = theta$i",";")
88+
thetas = string(thetas,"real theta_$i",";")
89+
theta_names = string(theta_names,"theta_$i",",")
8490
end
8591
for i in 1:length_of_params
8692
if isa(vars[i],StanODEData)
@@ -97,18 +103,18 @@ function stan_inference(prob::DiffEqBase.DEProblem,t,data,priors = nothing,
97103
stan_likelihood = stan_string(likelihood)
98104
if sample_u0
99105
nu = length(save_idxs)
106+
dv_names_ind = findfirst("$nu", theta_names)[1]
100107
if nu < length(prob.u0)
101-
u0 = "{"
108+
u0 = ""
102109
for u_ in prob.u0[nu+1:length(prob.u0)]
103110
u0 = u0*string(u_)
104111
end
105-
u0 = u0*"}"
106-
integral_string = "u_hat = $algorithm(sho, append_array(theta[1:$nu],$u0), t0, ts, theta[$(nu+1):$length_of_parameter], x_r, x_i, $reltol, $abstol, $maxiter);"
107-
else
108-
integral_string = "u_hat = $algorithm(sho, theta[1:$nu], t0, ts, theta[$(nu+1):$length_of_parameter], x_r, x_i, $reltol, $abstol, $maxiter);"
112+
integral_string = "u_hat = $algorithm(sho, [$(theta_names[1:dv_names_ind]),$u0]', t0, ts, $reltol, $abstol, $maxiter, $(rstrip(theta_names[dv_names_ind+2:end],',')));"
113+
else
114+
integral_string = "u_hat = $algorithm(sho, [$(theta_names[1:dv_names_ind])]', t0, ts, $reltol, $abstol, $maxiter, $(rstrip(theta_names[dv_names_ind+2:end],',')));"
109115
end
110116
else
111-
integral_string = "u_hat = $algorithm(sho, u0, t0, ts, theta, x_r, x_i, $reltol, $abstol, $maxiter);"
117+
integral_string = "u_hat = $algorithm(sho, u0, t0, ts, $reltol, $abstol, $maxiter, $(rstrip(theta_names,',')));"
112118
end
113119
binsearch_string = """
114120
int bin_search(real x, int min_val, int max_val){
@@ -120,8 +126,8 @@ function stan_inference(prob::DiffEqBase.DEProblem,t,data,priors = nothing,
120126
out = mid_pt;
121127
range = 0;
122128
} else {
123-
range = (range + 1) / 2;
124-
mid_pt = x > mid_pt ? mid_pt + range: mid_pt - range;
129+
range = (range + 1) / 2;
130+
mid_pt = x > mid_pt ? mid_pt + range: mid_pt - range;
125131
}
126132
}
127133
return out;
@@ -141,26 +147,18 @@ function stan_inference(prob::DiffEqBase.DEProblem,t,data,priors = nothing,
141147
$diffeq_string
142148
}
143149
data {
144-
real u0[$length_of_y];
150+
vector[$length_of_y] u0;
145151
int<lower=1> T;
146152
real internal_var___u[T,$(length(save_idxs))];
147153
real t0;
148154
real ts[T];
149155
}
150-
transformed data {
151-
real x_r[0];
152-
int x_i[0];
153-
}
154156
parameters {
155157
$setup_params
156158
$theta_string
157159
}
158-
transformed parameters{
159-
real theta[$length_of_parameter];
160-
$thetas
161-
}
162160
model{
163-
real u_hat[T,$length_of_y];
161+
vector[$length_of_y] u_hat[T];
164162
$hyper_params
165163
$priors_string
166164
$integral_string
@@ -169,9 +167,13 @@ function stan_inference(prob::DiffEqBase.DEProblem,t,data,priors = nothing,
169167
}
170168
}
171169
"
172-
stanmodel = CmdStan.Stanmodel(num_samples=num_samples, num_warmup=num_warmup, name="parameter_estimation_model", model=parameter_estimation_model, nchains=nchains, printsummary = printsummary)
170+
stanmodel = StanSample.SampleModel("parameter_estimation_model", parameter_estimation_model, nchains; printsummary = printsummary, method = StanSample.Sample(;num_samples = num_samples, num_warmup = num_warmup))
173171
end
174172
parameter_estimation_data = Dict("u0"=>prob.u0, "T" => length(t), "internal_var___u" => view(data, :, 1:length(t))', "t0" => prob.tspan[1], "ts" => t)
175-
return_code, chains, cnames = CmdStan.stan(stanmodel, [parameter_estimation_data])
176-
return StanModel(stanmodel, return_code, chains, cnames)
173+
rc = stan_sample(stanmodel; data = parameter_estimation_data)
174+
if success(rc)
175+
return StanResult(stanmodel, rc, read_samples(stanmodel; output_format=output_format))
176+
else
177+
rc.err
178+
end
177179
end

test/runtests.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,10 @@ const GROUP = get(ENV, "GROUP", "All")
88
if GROUP == "All" || GROUP == "Core"
99
@time @safetestset "DynamicHMC" begin include("dynamicHMC.jl") end
1010
@time @safetestset "Turing" begin include("turing.jl") end
11-
@time @safetestset "ABC" begin include("abc.jl") end
11+
# @time @safetestset "ABC" begin include("abc.jl") end
1212
end
1313

1414
if GROUP == "Stan" || GROUP == "All"
15-
using Pkg
16-
Pkg.add("CmdStan")
1715
@time @safetestset "Stan_String" begin include("stan_string.jl") end
1816
@time @safetestset "Stan" begin include("stan.jl") end
1917
end

test/stan.jl

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using CmdStan, DiffEqBayes, OrdinaryDiffEq, ParameterizedFunctions,
1+
using DiffEqBayes, OrdinaryDiffEq, ParameterizedFunctions,
22
RecursiveArrayTools, Distributions, Test
33

44
println("One parameter case")
@@ -19,24 +19,21 @@ priors = [truncated(Normal(1.5,0.1),1.0,1.8)]
1919
bayesian_result = stan_inference(prob1,t,data,priors;num_samples=300,
2020
num_warmup=500,likelihood=Normal)
2121

22-
sdf = CmdStan.read_summary(bayesian_result.model)
23-
@test sdf[sdf.parameters .== :theta1, :mean][1] 1.5 atol=3e-1
22+
@test mean(get(bayesian_result.chains,:theta_1)[1]) 1.5 atol=3e-1
2423

2524
# Test norecompile
2625
bayesian_result2 = stan_inference(prob1,t,data,priors,bayesian_result.model;
2726
num_samples=300,num_warmup=500,likelihood=Normal)
2827

29-
sdf = CmdStan.read_summary(bayesian_result.model)
30-
@test sdf[sdf.parameters .== :theta1, :mean][1] 1.5 atol=3e-1
28+
@test mean(get(bayesian_result2.chains,:theta_1)[1]) 1.5 atol=3e-1
3129

3230
priors = [truncated(Normal(1.,0.01),0.5,2.0),truncated(Normal(1.,0.01),0.5,2.0),truncated(Normal(1.5,0.01),1.0,2.0)]
3331
bayesian_result = stan_inference(prob1,t,data,priors;num_samples=300,
3432
num_warmup=500,likelihood=Normal,sample_u0=true)
3533

36-
sdf = CmdStan.read_summary(bayesian_result.model)
37-
@test sdf[sdf.parameters .== :theta1, :mean][1] 1. atol=3e-1
38-
@test sdf[sdf.parameters .== :theta2, :mean][1] 1. atol=3e-1
39-
@test sdf[sdf.parameters .== :theta3, :mean][1] 1.5 atol=3e-1
34+
@test mean(get(bayesian_result.chains,:theta_1)[1]) 1. atol=3e-1
35+
@test mean(get(bayesian_result.chains,:theta_2)[1]) 1. atol=3e-1
36+
@test mean(get(bayesian_result.chains,:theta_3)[1]) 1.5 atol=3e-1
4037

4138
sol = solve(prob1,Tsit5(),save_idxs=[1])
4239
randomized = VectorOfArray([(sol(t[i]) + .01 * randn(1)) for i in 1:length(t)])
@@ -45,17 +42,15 @@ priors = [truncated(Normal(1.5,0.1),0.5,2)]
4542
bayesian_result = stan_inference(prob1,t,data,priors;num_samples=300,
4643
num_warmup=500,likelihood=Normal,save_idxs=[1])
4744

48-
sdf = CmdStan.read_summary(bayesian_result.model)
49-
@test sdf[sdf.parameters .== :theta1, :mean][1] 1.5 atol=3e-1
45+
@test mean(get(bayesian_result.chains,:theta_1)[1]) 1.5 atol=3e-1
5046

5147

5248
priors = [truncated(Normal(1.,0.01),0.5,2),truncated(Normal(1.5,0.01),0.5,2)]
5349
bayesian_result = stan_inference(prob1,t,data,priors;num_samples=300,
5450
num_warmup=500,likelihood=Normal,save_idxs=[1],sample_u0=true)
5551

56-
sdf = CmdStan.read_summary(bayesian_result.model)
57-
@test sdf[sdf.parameters .== :theta1, :mean][1] 1. atol=3e-1
58-
@test sdf[sdf.parameters .== :theta2, :mean][1] 1.5 atol=3e-1
52+
@test mean(get(bayesian_result.chains,:theta_1)[1]) 1. atol=3e-1
53+
@test mean(get(bayesian_result.chains,:theta_2)[1]) 1.5 atol=3e-1
5954

6055
println("Four parameter case")
6156
f1 = @ode_def begin
@@ -74,8 +69,7 @@ priors = [truncated(Normal(1.5,0.01),0.5,2),truncated(Normal(1.0,0.01),0.5,1.5),
7469
truncated(Normal(3.0,0.01),0.5,4),truncated(Normal(1.0,0.01),0.5,2)]
7570

7671
bayesian_result = stan_inference(prob1,t,data,priors;num_samples=100,num_warmup=500,vars =(DiffEqBayes.StanODEData(),InverseGamma(4,1)))
77-
sdf = CmdStan.read_summary(bayesian_result.model)
78-
@test sdf[sdf.parameters .== :theta1, :mean][1] 1.5 atol=1e-1
79-
@test sdf[sdf.parameters .== :theta2, :mean][1] 1.0 atol=1e-1
80-
@test sdf[sdf.parameters .== :theta3, :mean][1] 3.0 atol=1e-1
81-
@test sdf[sdf.parameters .== :theta4, :mean][1] 1.0 atol=1e-1
72+
@test mean(get(bayesian_result.chains,:theta_1)[1]) 1.5 atol=1e-1
73+
@test mean(get(bayesian_result.chains,:theta_2)[1]) 1.0 atol=1e-1
74+
@test mean(get(bayesian_result.chains,:theta_3)[1]) 3.0 atol=1e-1
75+
@test mean(get(bayesian_result.chains,:theta_4)[1]) 1.0 atol=1e-1

0 commit comments

Comments
 (0)