1- struct StanModel {M,R,C,N }
1+ struct StanResult {M,R,C}
22 model:: M
33 return_code:: R
44 chains:: C
5- cnames:: N
5+ end
6+
7+ function Base. show (io:: IO , mime:: MIME"text/plain" , res:: StanResult )
8+ show (io, mime, res. chains)
69end
710
811struct StanODEData
912end
1013
1114function generate_priors (n,priors)
1215 priors_string = " "
13- if priors== nothing
16+ if priors=== nothing
1417 for i in 1 : n
15- priors_string = string (priors_string," theta[ $i ] ~ normal(0, 1)" , " ; " )
18+ priors_string = string (priors_string," theta_ $i ~ normal(0, 1)" , " ; " )
1619 end
1720 else
1821 for i in 1 : n
19- priors_string = string (priors_string," theta[ $i ] ~ " ,stan_string (priors[i])," ;" )
22+ priors_string = string (priors_string," theta_ $i ~ " ,stan_string (priors[i])," ;" )
2023 end
2124 end
2225 priors_string
@@ -34,13 +37,13 @@ function generate_theta(n,priors)
3437 lower_bound = string (" lower=" ,minimum (priors[i]))
3538 end
3639 if lower_bound != " " && upper_bound != " "
37- theta = string (theta," real" ," <$lower_bound " ," ," ," $upper_bound >" ," theta $i " ," ;" )
40+ theta = string (theta," real" ," <$lower_bound " ," ," ," $upper_bound >" ," theta_ $i " ," ;" )
3841 elseif lower_bound != " "
39- theta = string (theta," real" ," <$lower_bound " ," >" ," theta $i " ," ;" )
42+ theta = string (theta," real" ," <$lower_bound " ," >" ," theta_ $i " ," ;" )
4043 elseif upper_bound != " "
41- theta = string (theta," real" ," <" ," $upper_bound >" ," theta $i " ," ;" )
44+ theta = string (theta," real" ," <" ," $upper_bound >" ," theta_ $i " ," ;" )
4245 else
43- theta = string (theta," real" ," theta $i " ," ;" )
46+ theta = string (theta," real" ," theta_ $i " ," ;" )
4447 end
4548 end
4649 return theta
@@ -50,9 +53,10 @@ function stan_inference(prob::DiffEqBase.DEProblem,t,data,priors = nothing,
5053 stanmodel = nothing ;alg= :rk45 ,
5154 num_samples= 1000 , num_warmup= 1000 , reltol= 1e-3 ,
5255 abstol= 1e-6 , maxiter= Int (1e5 ),likelihood= Normal,
53- vars= (StanODEData (),InverseGamma (3 ,3 )),nchains= 1 ,
54- sample_u0 = false , save_idxs = nothing , diffeq_string = nothing , printsummary = true )
55-
56+ vars= (StanODEData (),InverseGamma (3 ,3 )),nchains= [1 ],
57+ sample_u0 = false , save_idxs = nothing , diffeq_string = nothing ,
58+ printsummary = true , output_format = :mcmcchains )
59+
5660 save_idxs != = nothing && length (save_idxs) == 1 ? save_idxs = save_idxs[1 ] : save_idxs = save_idxs
5761 length_of_y = length (prob. u0)
5862 save_idxs = something (save_idxs, 1 : length_of_y)
@@ -63,24 +67,26 @@ function stan_inference(prob::DiffEqBase.DEProblem,t,data,priors = nothing,
6367 else
6468 length_of_parameter = length (prob. p) + sample_u0 * length (save_idxs)
6569 end
66-
70+
6771 if stanmodel === nothing
6872 if alg == :adams
69- algorithm = " integrate_ode_adams "
73+ algorithm = " ode_adams_tol "
7074 elseif alg == :rk45
71- algorithm = " integrate_ode_rk45 "
75+ algorithm = " ode_rk45_tol "
7276 elseif alg == :bdf
73- algorithm = " integrate_ode_bdf "
77+ algorithm = " ode_bdf_tol "
7478 else
7579 error (" The choices for alg are :adams, :rk45, or :bdf" )
7680 end
7781 hyper_params = " "
7882 tuple_hyper_params = " "
7983 setup_params = " "
8084 thetas = " "
85+ theta_names = " "
8186 theta_string = generate_theta (length_of_parameter,priors)
8287 for i in 1 : length_of_parameter
83- thetas = string (thetas," theta[$i ] = theta$i " ," ;" )
88+ thetas = string (thetas," real theta_$i " ," ;" )
89+ theta_names = string (theta_names," theta_$i " ," ," )
8490 end
8591 for i in 1 : length_of_params
8692 if isa (vars[i],StanODEData)
@@ -97,18 +103,18 @@ function stan_inference(prob::DiffEqBase.DEProblem,t,data,priors = nothing,
97103 stan_likelihood = stan_string (likelihood)
98104 if sample_u0
99105 nu = length (save_idxs)
106+ dv_names_ind = findfirst (" $nu " , theta_names)[1 ]
100107 if nu < length (prob. u0)
101- u0 = " { "
108+ u0 = " "
102109 for u_ in prob. u0[nu+ 1 : length (prob. u0)]
103110 u0 = u0* string (u_)
104111 end
105- u0 = u0* " }"
106- integral_string = " u_hat = $algorithm (sho, append_array(theta[1:$nu ],$u0 ), t0, ts, theta[$(nu+ 1 ) :$length_of_parameter ], x_r, x_i, $reltol , $abstol , $maxiter );"
107- else
108- integral_string = " u_hat = $algorithm (sho, theta[1:$nu ], t0, ts, theta[$(nu+ 1 ) :$length_of_parameter ], x_r, x_i, $reltol , $abstol , $maxiter );"
112+ integral_string = " u_hat = $algorithm (sho, [$(theta_names[1 : dv_names_ind]) ,$u0 ]', t0, ts, $reltol , $abstol , $maxiter , $(rstrip (theta_names[dv_names_ind+ 2 : end ],' ,' )) );"
113+ else
114+ integral_string = " u_hat = $algorithm (sho, [$(theta_names[1 : dv_names_ind]) ]', t0, ts, $reltol , $abstol , $maxiter , $(rstrip (theta_names[dv_names_ind+ 2 : end ],' ,' )) );"
109115 end
110116 else
111- integral_string = " u_hat = $algorithm (sho, u0, t0, ts, theta, x_r, x_i, $reltol , $abstol , $maxiter );"
117+ integral_string = " u_hat = $algorithm (sho, u0, t0, ts, $reltol , $abstol , $maxiter , $( rstrip (theta_names, ' , ' )) );"
112118 end
113119 binsearch_string = """
114120 int bin_search(real x, int min_val, int max_val){
@@ -120,8 +126,8 @@ function stan_inference(prob::DiffEqBase.DEProblem,t,data,priors = nothing,
120126 out = mid_pt;
121127 range = 0;
122128 } else {
123- range = (range + 1) / 2;
124- mid_pt = x > mid_pt ? mid_pt + range: mid_pt - range;
129+ range = (range + 1) / 2;
130+ mid_pt = x > mid_pt ? mid_pt + range: mid_pt - range;
125131 }
126132 }
127133 return out;
@@ -141,26 +147,18 @@ function stan_inference(prob::DiffEqBase.DEProblem,t,data,priors = nothing,
141147 $diffeq_string
142148 }
143149 data {
144- real u0 [$length_of_y ];
150+ vector [$length_of_y ] u0 ;
145151 int<lower=1> T;
146152 real internal_var___u[T,$(length (save_idxs)) ];
147153 real t0;
148154 real ts[T];
149155 }
150- transformed data {
151- real x_r[0];
152- int x_i[0];
153- }
154156 parameters {
155157 $setup_params
156158 $theta_string
157159 }
158- transformed parameters{
159- real theta[$length_of_parameter ];
160- $thetas
161- }
162160 model{
163- real u_hat[T, $length_of_y ];
161+ vector[ $length_of_y ] u_hat[T];
164162 $hyper_params
165163 $priors_string
166164 $integral_string
@@ -169,9 +167,13 @@ function stan_inference(prob::DiffEqBase.DEProblem,t,data,priors = nothing,
169167 }
170168 }
171169 "
172- stanmodel = CmdStan . Stanmodel (num_samples = num_samples, num_warmup = num_warmup, name = " parameter_estimation_model" , model = parameter_estimation_model, nchains= nchains, printsummary = printsummary)
170+ stanmodel = StanSample . SampleModel ( " parameter_estimation_model" , parameter_estimation_model, nchains; printsummary = printsummary, method = StanSample . Sample (;num_samples = num_samples, num_warmup = num_warmup) )
173171 end
174172 parameter_estimation_data = Dict (" u0" => prob. u0, " T" => length (t), " internal_var___u" => view (data, :, 1 : length (t))' , " t0" => prob. tspan[1 ], " ts" => t)
175- return_code, chains, cnames = CmdStan. stan (stanmodel, [parameter_estimation_data])
176- return StanModel (stanmodel, return_code, chains, cnames)
173+ rc = stan_sample (stanmodel; data = parameter_estimation_data)
174+ if success (rc)
175+ return StanResult (stanmodel, rc, read_samples (stanmodel; output_format= output_format))
176+ else
177+ rc. err
178+ end
177179end
0 commit comments