Skip to content

Commit

Permalink
added checks for target nodes of NUTS samplers
Browse files Browse the repository at this point in the history
  • Loading branch information
danielturek committed Jan 21, 2024
1 parent faa8c1f commit 7d36fee
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 21 deletions.
60 changes: 39 additions & 21 deletions nimbleHMC/R/HMC_samplers.R
Original file line number Diff line number Diff line change
Expand Up @@ -122,18 +122,38 @@ sampler_langevin <- nimbleFunction(
)


hmc_checkWarmup <- function(warmupMode, warmup, samplerName) {
if(!(warmupMode %in% c('default', 'burnin', 'fraction', 'iterations'))) stop('`warmupMode` control argument of ', samplerName, ' sampler must have value "default", "burnin", "fraction", or "iterations". The value provided was: ', warmupMode, '.', call. = FALSE)

hmc_checkTarget <- function(model, targetNodes, calcNodes, hmcType) {
## checks for:
## - target with discrete or truncated distribution
## - target with user-defined distribution (without AD support)
## - dependencies with truncated, dinterval, or dconstraint distribution
if(any(model$isDiscrete(targetNodes)))
stop(paste0(hmcType, ' sampler cannot operate on discrete-valued nodes: ', paste0(targetNodes[model$isDiscrete(targetNodes)], collapse = ', ')))
if(any(model$isTruncated(targetNodes)))
stop(paste0(hmcType, ' sampler cannot operate on nodes with truncated prior distributions: ', paste0(targetNodes[model$isTruncated(targetNodes)], collapse = ', ')))
if(any(model$isTruncated(calcNodes)))
stop(paste0(hmcType, ' sampler cannot operate since these dependent nodes have truncated distributions, which do not support AD calculations: ', paste0(calcNodes[model$isTruncated(calcNodes)], collapse = ', ')))
if(any(model$getDistribution(calcNodes) == 'dinterval'))
stop(paste0(hmcType, ' sampler cannot operate since these dependent nodes have dinterval distributions, which do not support AD calculations: ', paste0(calcNodes[which(model$getDistribution(calcNodes) == 'dinterval')], collapse = ', ')))
if(any(model$getDistribution(calcNodes) == 'dconstraint'))
stop(paste0(hmcType, ' sampler cannot operate since these dependent nodes have dconstraint distributions, which do not support AD calculations: ', paste0(calcNodes[which(model$getDistribution(calcNodes) == 'dconstraint')], collapse = ', ')))
}



hmc_checkWarmup <- function(warmupMode, warmup, hmcType) {
if(!(warmupMode %in% c('default', 'burnin', 'fraction', 'iterations'))) stop('`warmupMode` control argument of ', hmcType, ' sampler must have value "default", "burnin", "fraction", or "iterations". The value provided was: ', warmupMode, '.', call. = FALSE)
if(warmupMode == 'fraction')
if(!is.numeric(warmup) | warmup < 0 | warmup > 1) stop('When the `warmupMode` control argument of ', samplerName, ' sampler is "fraction", the `warmup` control argument must be a number between 0 and 1, which will specify the fraction of the total MCMC iterations to use as warmup. The value provided for the `warmup` control argument was: ', warmup, '.', call. = FALSE)
if(!is.numeric(warmup) | warmup < 0 | warmup > 1) stop('When the `warmupMode` control argument of ', hmcType, ' sampler is "fraction", the `warmup` control argument must be a number between 0 and 1, which will specify the fraction of the total MCMC iterations to use as warmup. The value provided for the `warmup` control argument was: ', warmup, '.', call. = FALSE)
if(warmupMode == 'iterations')
if(!is.numeric(warmup) | warmup < 0 | floor(warmup) != warmup) stop('When the `warmupMode` control argument of ', samplerName, ' sampler is "iterations", the `warmup` control argument must be a non-negative integer, which will specify the number MCMC iterations to use as warmup. The value provided for the `warmup` control argument was: ', warmup, '.', call. = FALSE)
if(!is.numeric(warmup) | warmup < 0 | floor(warmup) != warmup) stop('When the `warmupMode` control argument of ', hmcType, ' sampler is "iterations", the `warmup` control argument must be a non-negative integer, which will specify the number MCMC iterations to use as warmup. The value provided for the `warmup` control argument was: ', warmup, '.', call. = FALSE)
}



hmc_setWarmup <- nimbleFunction(
setup = function(warmupMode, warmup, messages, samplerName, targetNodesToPrint) {},
setup = function(warmupMode, warmup, messages, hmcType, targetNodesToPrint) {},
run = function(MCMCniter = double(), MCMCnburnin = double(), adaptive = logical()) {
##
## set nwarmup
Expand All @@ -148,29 +168,29 @@ hmc_setWarmup <- nimbleFunction(
## informative message
if(messages) {
if(!adaptive) { ## adaptive = FALSE
print(' [Note] ', samplerName, ' sampler (nodes: ', targetNodesToPrint, ') has adaptation turned off,\n so no warmup period will be used.')
print(' [Note] ', hmcType, ' sampler (nodes: ', targetNodesToPrint, ') has adaptation turned off,\n so no warmup period will be used.')
} else { ## adaptive = TRUE
if(warmupMode == 'default') {
if(MCMCnburnin > 0) print(" [Note] ", samplerName, " sampler (nodes: ", targetNodesToPrint, ") is using ", nwarmup, " warmup iterations.\n Since `warmupMode` is 'default' and `nburnin` > 0,\n the number of warmup iterations is equal to `nburnin`.\n The burnin samples will be discarded, and all samples returned will be post-warmup.")
else print(" [Note] ", samplerName, " sampler (nodes: ", targetNodesToPrint, ") is using ", nwarmup, " warmup iterations.\n Since `warmupMode` is 'default' and `nburnin` = 0,\n the number of warmup iterations is equal to `niter/2`.\n No samples will be discarded, so the first half of the samples returned\n are from the warmup period, and the second half of the samples are post-warmup.")
if(MCMCnburnin > 0) print(" [Note] ", hmcType, " sampler (nodes: ", targetNodesToPrint, ") is using ", nwarmup, " warmup iterations.\n Since `warmupMode` is 'default' and `nburnin` > 0,\n the number of warmup iterations is equal to `nburnin`.\n The burnin samples will be discarded, and all samples returned will be post-warmup.")
else print(" [Note] ", hmcType, " sampler (nodes: ", targetNodesToPrint, ") is using ", nwarmup, " warmup iterations.\n Since `warmupMode` is 'default' and `nburnin` = 0,\n the number of warmup iterations is equal to `niter/2`.\n No samples will be discarded, so the first half of the samples returned\n are from the warmup period, and the second half of the samples are post-warmup.")
}
if(warmupMode == 'burnin')
if(MCMCnburnin > 0) print(" [Note] ", samplerName, " sampler (nodes: ", targetNodesToPrint, ") is using ", nwarmup, " warmup iterations.\n Since `warmupMode` is 'burnin', the number of warmup iterations is equal to `nburnin`.\n The burnin samples will be discarded, and all samples returned will be post-warmup.")
if(MCMCnburnin > 0) print(" [Note] ", hmcType, " sampler (nodes: ", targetNodesToPrint, ") is using ", nwarmup, " warmup iterations.\n Since `warmupMode` is 'burnin', the number of warmup iterations is equal to `nburnin`.\n The burnin samples will be discarded, and all samples returned will be post-warmup.")
else
print(" [Note] ", samplerName, " sampler (nodes: ", targetNodesToPrint, ") is using 0 warmup iterations.\n No adaptation is being done, apart from initialization of epsilon\n (if `initializeEpsilon` is TRUE).")
print(" [Note] ", hmcType, " sampler (nodes: ", targetNodesToPrint, ") is using 0 warmup iterations.\n No adaptation is being done, apart from initialization of epsilon\n (if `initializeEpsilon` is TRUE).")
if(warmupMode == 'fraction') {
if(MCMCnburnin < nwarmup) print(" [Note] ", samplerName, " sampler (nodes: ", targetNodesToPrint, ") is using ", nwarmup, " warmup iterations.\n Since `warmupMode` is 'fraction', the number of warmup iterations is equal to\n `niter*fraction`, where `fraction` is the value of the `warmup` control argument.\n Because `nburnin` is less than the number of warmup iterations,\n some of the samples returned will be collected during the warmup period,\n and the remainder of the samples returned will be post-warmup.")
else print(" [Note] ", samplerName, " sampler (nodes: ", targetNodesToPrint, ") is using ", nwarmup, " warmup iterations.\n Since `warmupMode` is 'fraction', the number of warmup iterations is equal to\n `niter*fraction`, where `fraction` is the value of the warmup `control` argument.\n Because `nburnin` exceeds the number of warmup iterations,\n all samples returned will be post-warmup.")
if(MCMCnburnin < nwarmup) print(" [Note] ", hmcType, " sampler (nodes: ", targetNodesToPrint, ") is using ", nwarmup, " warmup iterations.\n Since `warmupMode` is 'fraction', the number of warmup iterations is equal to\n `niter*fraction`, where `fraction` is the value of the `warmup` control argument.\n Because `nburnin` is less than the number of warmup iterations,\n some of the samples returned will be collected during the warmup period,\n and the remainder of the samples returned will be post-warmup.")
else print(" [Note] ", hmcType, " sampler (nodes: ", targetNodesToPrint, ") is using ", nwarmup, " warmup iterations.\n Since `warmupMode` is 'fraction', the number of warmup iterations is equal to\n `niter*fraction`, where `fraction` is the value of the warmup `control` argument.\n Because `nburnin` exceeds the number of warmup iterations,\n all samples returned will be post-warmup.")
}
if(warmupMode == 'iterations')
if(MCMCnburnin < nwarmup) print(" [Note] ", samplerName, " sampler (nodes: ", targetNodesToPrint, ") is using ", nwarmup, " warmup iterations.\n Since `warmupMode` is 'iterations', the number of warmup iterations\n is the value of the `warmup` control argument.\n Because `nburnin` is less than the number of warmup iterations,\n some of the samples returned will be collected during the warmup period,\n and the remainder of the samples returned will be post-warmup.")
else print(" [Note] ", samplerName, " sampler (nodes: ", targetNodesToPrint, ") is using ", nwarmup, " warmup iterations.\n Since `warmupMode` is 'iterations', the number of warmup iterations\n is the value of the `warmup` control argument.\n Because `nburnin` exceeds the number of warmup iterations,\n all samples returned will be post-warmup.")
if(MCMCnburnin < nwarmup) print(" [Note] ", hmcType, " sampler (nodes: ", targetNodesToPrint, ") is using ", nwarmup, " warmup iterations.\n Since `warmupMode` is 'iterations', the number of warmup iterations\n is the value of the `warmup` control argument.\n Because `nburnin` is less than the number of warmup iterations,\n some of the samples returned will be collected during the warmup period,\n and the remainder of the samples returned will be post-warmup.")
else print(" [Note] ", hmcType, " sampler (nodes: ", targetNodesToPrint, ") is using ", nwarmup, " warmup iterations.\n Since `warmupMode` is 'iterations', the number of warmup iterations\n is the value of the `warmup` control argument.\n Because `nburnin` exceeds the number of warmup iterations,\n all samples returned will be post-warmup.")
}
}
##
## hard check that nwarmup >= 20
if(adaptive & nwarmup > 0 & nwarmup < 20) {
print(' [Error] ', samplerName, ' sampler (nodes: ', targetNodesToPrint, ') requires a minimum of 20 warmup iterations.')
print(' [Error] ', hmcType, ' sampler (nodes: ', targetNodesToPrint, ') requires a minimum of 20 warmup iterations.')
stop()
}
##
Expand Down Expand Up @@ -290,8 +310,8 @@ sampler_NUTS_classic <- nimbleFunction(
targetNodesToPrint <- paste(targetNodes, collapse = ', ')
if(nchar(targetNodesToPrint) > 100) targetNodesToPrint <- paste0(substr(targetNodesToPrint, 1, 97), '...')
calcNodes <- model$getDependencies(targetNodes)
## check for discrete nodes (early, before parameterTransform is specialized)
if(any(model$isDiscrete(targetNodesAsScalars))) stop(paste0('NUTS_classic sampler cannot operate on discrete-valued nodes: ', paste0(targetNodesAsScalars[model$isDiscrete(targetNodesAsScalars)], collapse = ', ')))
## check validity of target and dependent nodes (early, before parameterTransform is specialized)
hmc_checkTarget(model, targetNodes, calcNodes, 'NUTS_classic')
## processing of bounds and transformations
my_parameterTransform <- parameterTransform(model, targetNodesAsScalars)
d <- my_parameterTransform$getTransformedLength()
Expand Down Expand Up @@ -757,10 +777,8 @@ sampler_NUTS <- nimbleFunction(
targetNodesToPrint <- paste(targetNodes, collapse = ', ')
if(nchar(targetNodesToPrint) > 100) targetNodesToPrint <- paste0(substr(targetNodesToPrint, 1, 97), '...')
calcNodes <- model$getDependencies(targetNodes)
## check for discrete nodes (early, before parameterTransform is specialized)
if(any(model$isDiscrete(targetNodesAsScalars)))
stop(paste0('NUTS sampler cannot operate on discrete-valued nodes: ',
paste0(targetNodesAsScalars[model$isDiscrete(targetNodesAsScalars)], collapse = ', ')))
## check validity of target and dependent nodes (early, before parameterTransform is specialized)
hmc_checkTarget(model, targetNodes, calcNodes, 'NUTS')
## processing of bounds and transformations
my_parameterTransform <- parameterTransform(model, targetNodesAsScalars)
d <- my_parameterTransform$getTransformedLength()
Expand Down
38 changes: 38 additions & 0 deletions nimbleHMC/tests/testthat/test-HMC.R
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,44 @@ test_that('HMC sampler error messages for transformations with non-constant boun
})


test_that('hmc_checkTarget catches all invalid cases', {
code <- nimbleCode({
x[1] ~ dbern(0.5)
x[2] ~ dbin(size = 4, prob = 0.5)
x[3] ~ dcat(prob = p[1:3])
x[4] ~ dpois(2)
x[5:7] ~ dmulti(prob = p[1:3], size = 3)
x[8] ~ T(dnorm(0, 1), 0, )
x[9] ~ T(dnorm(0, 1), , 2)
x[10] ~ T(dnorm(0, 1), 0, 2)
##
a[1] ~ dnorm(0, 1)
b[1] ~ T(dnorm(a[1], 1), 0, 2)
##
a[2] ~ dnorm(0, 1)
b[2] ~ dconstraint(a[2] > 0)
##
a[3] ~ dnorm(0, 1)
b[3] ~ dinterval(a[3], 0)
})
constants <- list(p = rep(1/3,3))
inits <- list(x = rep(1, 10), a = rep(1, 3))
data <- list(b = rep(1, 3))
Rmodel <- nimbleModel(code, constants, data, inits)
##
conf <- configureMCMC(Rmodel, nodes = NULL, print = FALSE)
##
for(node in Rmodel$expandNodeNames(c('x', 'a'))) {
conf$setSamplers()
conf$addSampler(target = node, type = 'NUTS_classic')
expect_error(buildMCMC(conf))
conf$setSamplers()
conf$addSampler(target = node, type = 'NUTS')
expect_error(buildMCMC(conf))
}
})


test_that('HMC sampler error messages for invalid M mass matrix arguments', {
code <- nimbleCode({
for(i in 1:5) x[i] ~ dnorm(0, 1)
Expand Down

0 comments on commit 7d36fee

Please sign in to comment.