Skip to content

Commit

Permalink
reorder final lines of NUTS run (#73)
Browse files Browse the repository at this point in the history
* reorder final lines of NUTS run
* change also NUTS_classic and defensively run model$calculate before initEpsilon
* changes to indentation
  • Loading branch information
perrydv authored Dec 18, 2024
1 parent cc39fea commit e2d0033
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions nimbleHMC/R/HMC_samplers.R
Original file line number Diff line number Diff line change
Expand Up @@ -455,9 +455,9 @@ sampler_NUTS_classic <- nimbleFunction(
j <- j + 1
checkInterrupt()
}
if((timesRan <= nwarmup) & adaptive) adaptiveProcedure(btNL$a, btNL$na)
inverseTransformStoreCalculate(qNew)
nimCopy(from = model, to = mvSaved, row = 1, nodes = calcNodes, logProb = TRUE)
if((timesRan <= nwarmup) & adaptive) adaptiveProcedure(btNL$a, btNL$na)
},
methods = list(
drawMomentumValues = function() {
Expand Down Expand Up @@ -580,6 +580,7 @@ sampler_NUTS_classic <- nimbleFunction(
##for(i in 1:d) M[i] <<- 1 / warmupCovRegularized[i,i]
sqrtM <<- sqrt(M)
if(adaptEpsilon) {
inverseTransformStoreCalculate(qNew) ## defensively ensure model states are up to date
initEpsilon()
epsilonAdaptCount <<- 0
mu <<- log(10 * epsilon)
Expand Down Expand Up @@ -990,20 +991,23 @@ sampler_NUTS <- nimbleFunction(
accept_prob <- sum_metropolis_prob / n_leapfrog
copy_state(state_current, state_sample) ## extraneous copy? could remove?
##
inverseTransformStoreCalculate(state_sample$q)
nimCopy(from = model, to = mvSaved, row = 1, nodes = calcNodes, logProb = TRUE)
if((timesRan <= nwarmup) & adaptive) {
if(adaptEpsilon) adapt_stepsize(accept_prob)
update <- FALSE
if(adaptM) update <- adapt_M()
if(update & adaptEpsilon) {
if(initializeEpsilon) initEpsilon()
if(initializeEpsilon) {
inverseTransformStoreCalculate(state_sample$q) ## defensively ensure model states are up to date.
initEpsilon()
}
Hbar <<- 0
logEpsilonBar <<- 0
stepsizeCounter <<- 0
mu <<- log(10*epsilon)
}
}
inverseTransformStoreCalculate(state_sample$q)
nimCopy(from = model, to = mvSaved, row = 1, nodes = calcNodes, logProb = TRUE)
},
methods = list(
copy_state = function(to = stateNL(), from = stateNL()) {
Expand Down

0 comments on commit e2d0033

Please sign in to comment.