From 7d36feee6e38fa808e6e884ac499d63a75f12041 Mon Sep 17 00:00:00 2001 From: Daniel Turek Date: Sun, 21 Jan 2024 18:48:36 -0500 Subject: [PATCH] added checks for target nodes of NUTS samplers --- nimbleHMC/R/HMC_samplers.R | 60 +++++++++++++++++++---------- nimbleHMC/tests/testthat/test-HMC.R | 38 ++++++++++++++++++ 2 files changed, 77 insertions(+), 21 deletions(-) diff --git a/nimbleHMC/R/HMC_samplers.R b/nimbleHMC/R/HMC_samplers.R index bacec82..c722041 100644 --- a/nimbleHMC/R/HMC_samplers.R +++ b/nimbleHMC/R/HMC_samplers.R @@ -122,18 +122,38 @@ sampler_langevin <- nimbleFunction( ) -hmc_checkWarmup <- function(warmupMode, warmup, samplerName) { - if(!(warmupMode %in% c('default', 'burnin', 'fraction', 'iterations'))) stop('`warmupMode` control argument of ', samplerName, ' sampler must have value "default", "burnin", "fraction", or "iterations". The value provided was: ', warmupMode, '.', call. = FALSE) + +hmc_checkTarget <- function(model, targetNodes, calcNodes, hmcType) { + ## checks for: + ## - target with discrete or truncated distribution + ## - target with user-defined distribution (without AD support) + ## - dependencies with truncated, dinterval, or dconstraint distribution + if(any(model$isDiscrete(targetNodes))) + stop(paste0(hmcType, ' sampler cannot operate on discrete-valued nodes: ', paste0(targetNodes[model$isDiscrete(targetNodes)], collapse = ', '))) + if(any(model$isTruncated(targetNodes))) + stop(paste0(hmcType, ' sampler cannot operate on nodes with truncated prior distributions: ', paste0(targetNodes[model$isTruncated(targetNodes)], collapse = ', '))) + if(any(model$isTruncated(calcNodes))) + stop(paste0(hmcType, ' sampler cannot operate since these dependent nodes have truncated distributions, which do not support AD calculations: ', paste0(calcNodes[model$isTruncated(calcNodes)], collapse = ', '))) + if(any(model$getDistribution(calcNodes) == 'dinterval')) + stop(paste0(hmcType, ' sampler cannot operate since these dependent nodes have dinterval distributions, which do not support AD calculations: ', paste0(calcNodes[which(model$getDistribution(calcNodes) == 'dinterval')], collapse = ', '))) + if(any(model$getDistribution(calcNodes) == 'dconstraint')) + stop(paste0(hmcType, ' sampler cannot operate since these dependent nodes have dconstraint distributions, which do not support AD calculations: ', paste0(calcNodes[which(model$getDistribution(calcNodes) == 'dconstraint')], collapse = ', '))) +} + + + +hmc_checkWarmup <- function(warmupMode, warmup, hmcType) { + if(!(warmupMode %in% c('default', 'burnin', 'fraction', 'iterations'))) stop('`warmupMode` control argument of ', hmcType, ' sampler must have value "default", "burnin", "fraction", or "iterations". The value provided was: ', warmupMode, '.', call. = FALSE) if(warmupMode == 'fraction') - if(!is.numeric(warmup) | warmup < 0 | warmup > 1) stop('When the `warmupMode` control argument of ', samplerName, ' sampler is "fraction", the `warmup` control argument must be a number between 0 and 1, which will specify the fraction of the total MCMC iterations to use as warmup. The value provided for the `warmup` control argument was: ', warmup, '.', call. = FALSE) + if(!is.numeric(warmup) | warmup < 0 | warmup > 1) stop('When the `warmupMode` control argument of ', hmcType, ' sampler is "fraction", the `warmup` control argument must be a number between 0 and 1, which will specify the fraction of the total MCMC iterations to use as warmup. The value provided for the `warmup` control argument was: ', warmup, '.', call. = FALSE) if(warmupMode == 'iterations') - if(!is.numeric(warmup) | warmup < 0 | floor(warmup) != warmup) stop('When the `warmupMode` control argument of ', samplerName, ' sampler is "iterations", the `warmup` control argument must be a non-negative integer, which will specify the number MCMC iterations to use as warmup. The value provided for the `warmup` control argument was: ', warmup, '.', call. = FALSE) + if(!is.numeric(warmup) | warmup < 0 | floor(warmup) != warmup) stop('When the `warmupMode` control argument of ', hmcType, ' sampler is "iterations", the `warmup` control argument must be a non-negative integer, which will specify the number MCMC iterations to use as warmup. The value provided for the `warmup` control argument was: ', warmup, '.', call. = FALSE) } hmc_setWarmup <- nimbleFunction( - setup = function(warmupMode, warmup, messages, samplerName, targetNodesToPrint) {}, + setup = function(warmupMode, warmup, messages, hmcType, targetNodesToPrint) {}, run = function(MCMCniter = double(), MCMCnburnin = double(), adaptive = logical()) { ## ## set nwarmup @@ -148,29 +168,29 @@ hmc_setWarmup <- nimbleFunction( ## informative message if(messages) { if(!adaptive) { ## adaptive = FALSE - print(' [Note] ', samplerName, ' sampler (nodes: ', targetNodesToPrint, ') has adaptation turned off,\n so no warmup period will be used.') + print(' [Note] ', hmcType, ' sampler (nodes: ', targetNodesToPrint, ') has adaptation turned off,\n so no warmup period will be used.') } else { ## adaptive = TRUE if(warmupMode == 'default') { - if(MCMCnburnin > 0) print(" [Note] ", samplerName, " sampler (nodes: ", targetNodesToPrint, ") is using ", nwarmup, " warmup iterations.\n Since `warmupMode` is 'default' and `nburnin` > 0,\n the number of warmup iterations is equal to `nburnin`.\n The burnin samples will be discarded, and all samples returned will be post-warmup.") - else print(" [Note] ", samplerName, " sampler (nodes: ", targetNodesToPrint, ") is using ", nwarmup, " warmup iterations.\n Since `warmupMode` is 'default' and `nburnin` = 0,\n the number of warmup iterations is equal to `niter/2`.\n No samples will be discarded, so the first half of the samples returned\n are from the warmup period, and the second half of the samples are post-warmup.") + if(MCMCnburnin > 0) print(" [Note] ", hmcType, " sampler (nodes: ", targetNodesToPrint, ") is using ", nwarmup, " warmup iterations.\n Since `warmupMode` is 'default' and `nburnin` > 0,\n the number of warmup iterations is equal to `nburnin`.\n The burnin samples will be discarded, and all samples returned will be post-warmup.") + else print(" [Note] ", hmcType, " sampler (nodes: ", targetNodesToPrint, ") is using ", nwarmup, " warmup iterations.\n Since `warmupMode` is 'default' and `nburnin` = 0,\n the number of warmup iterations is equal to `niter/2`.\n No samples will be discarded, so the first half of the samples returned\n are from the warmup period, and the second half of the samples are post-warmup.") } if(warmupMode == 'burnin') - if(MCMCnburnin > 0) print(" [Note] ", samplerName, " sampler (nodes: ", targetNodesToPrint, ") is using ", nwarmup, " warmup iterations.\n Since `warmupMode` is 'burnin', the number of warmup iterations is equal to `nburnin`.\n The burnin samples will be discarded, and all samples returned will be post-warmup.") + if(MCMCnburnin > 0) print(" [Note] ", hmcType, " sampler (nodes: ", targetNodesToPrint, ") is using ", nwarmup, " warmup iterations.\n Since `warmupMode` is 'burnin', the number of warmup iterations is equal to `nburnin`.\n The burnin samples will be discarded, and all samples returned will be post-warmup.") else - print(" [Note] ", samplerName, " sampler (nodes: ", targetNodesToPrint, ") is using 0 warmup iterations.\n No adaptation is being done, apart from initialization of epsilon\n (if `initializeEpsilon` is TRUE).") + print(" [Note] ", hmcType, " sampler (nodes: ", targetNodesToPrint, ") is using 0 warmup iterations.\n No adaptation is being done, apart from initialization of epsilon\n (if `initializeEpsilon` is TRUE).") if(warmupMode == 'fraction') { - if(MCMCnburnin < nwarmup) print(" [Note] ", samplerName, " sampler (nodes: ", targetNodesToPrint, ") is using ", nwarmup, " warmup iterations.\n Since `warmupMode` is 'fraction', the number of warmup iterations is equal to\n `niter*fraction`, where `fraction` is the value of the `warmup` control argument.\n Because `nburnin` is less than the number of warmup iterations,\n some of the samples returned will be collected during the warmup period,\n and the remainder of the samples returned will be post-warmup.") - else print(" [Note] ", samplerName, " sampler (nodes: ", targetNodesToPrint, ") is using ", nwarmup, " warmup iterations.\n Since `warmupMode` is 'fraction', the number of warmup iterations is equal to\n `niter*fraction`, where `fraction` is the value of the warmup `control` argument.\n Because `nburnin` exceeds the number of warmup iterations,\n all samples returned will be post-warmup.") + if(MCMCnburnin < nwarmup) print(" [Note] ", hmcType, " sampler (nodes: ", targetNodesToPrint, ") is using ", nwarmup, " warmup iterations.\n Since `warmupMode` is 'fraction', the number of warmup iterations is equal to\n `niter*fraction`, where `fraction` is the value of the `warmup` control argument.\n Because `nburnin` is less than the number of warmup iterations,\n some of the samples returned will be collected during the warmup period,\n and the remainder of the samples returned will be post-warmup.") + else print(" [Note] ", hmcType, " sampler (nodes: ", targetNodesToPrint, ") is using ", nwarmup, " warmup iterations.\n Since `warmupMode` is 'fraction', the number of warmup iterations is equal to\n `niter*fraction`, where `fraction` is the value of the warmup `control` argument.\n Because `nburnin` exceeds the number of warmup iterations,\n all samples returned will be post-warmup.") } if(warmupMode == 'iterations') - if(MCMCnburnin < nwarmup) print(" [Note] ", samplerName, " sampler (nodes: ", targetNodesToPrint, ") is using ", nwarmup, " warmup iterations.\n Since `warmupMode` is 'iterations', the number of warmup iterations\n is the value of the `warmup` control argument.\n Because `nburnin` is less than the number of warmup iterations,\n some of the samples returned will be collected during the warmup period,\n and the remainder of the samples returned will be post-warmup.") - else print(" [Note] ", samplerName, " sampler (nodes: ", targetNodesToPrint, ") is using ", nwarmup, " warmup iterations.\n Since `warmupMode` is 'iterations', the number of warmup iterations\n is the value of the `warmup` control argument.\n Because `nburnin` exceeds the number of warmup iterations,\n all samples returned will be post-warmup.") + if(MCMCnburnin < nwarmup) print(" [Note] ", hmcType, " sampler (nodes: ", targetNodesToPrint, ") is using ", nwarmup, " warmup iterations.\n Since `warmupMode` is 'iterations', the number of warmup iterations\n is the value of the `warmup` control argument.\n Because `nburnin` is less than the number of warmup iterations,\n some of the samples returned will be collected during the warmup period,\n and the remainder of the samples returned will be post-warmup.") + else print(" [Note] ", hmcType, " sampler (nodes: ", targetNodesToPrint, ") is using ", nwarmup, " warmup iterations.\n Since `warmupMode` is 'iterations', the number of warmup iterations\n is the value of the `warmup` control argument.\n Because `nburnin` exceeds the number of warmup iterations,\n all samples returned will be post-warmup.") } } ## ## hard check that nwarmup >= 20 if(adaptive & nwarmup > 0 & nwarmup < 20) { - print(' [Error] ', samplerName, ' sampler (nodes: ', targetNodesToPrint, ') requires a minimum of 20 warmup iterations.') + print(' [Error] ', hmcType, ' sampler (nodes: ', targetNodesToPrint, ') requires a minimum of 20 warmup iterations.') stop() } ## @@ -290,8 +310,8 @@ sampler_NUTS_classic <- nimbleFunction( targetNodesToPrint <- paste(targetNodes, collapse = ', ') if(nchar(targetNodesToPrint) > 100) targetNodesToPrint <- paste0(substr(targetNodesToPrint, 1, 97), '...') calcNodes <- model$getDependencies(targetNodes) - ## check for discrete nodes (early, before parameterTransform is specialized) - if(any(model$isDiscrete(targetNodesAsScalars))) stop(paste0('NUTS_classic sampler cannot operate on discrete-valued nodes: ', paste0(targetNodesAsScalars[model$isDiscrete(targetNodesAsScalars)], collapse = ', '))) + ## check validity of target and dependent nodes (early, before parameterTransform is specialized) + hmc_checkTarget(model, targetNodes, calcNodes, 'NUTS_classic') ## processing of bounds and transformations my_parameterTransform <- parameterTransform(model, targetNodesAsScalars) d <- my_parameterTransform$getTransformedLength() @@ -757,10 +777,8 @@ sampler_NUTS <- nimbleFunction( targetNodesToPrint <- paste(targetNodes, collapse = ', ') if(nchar(targetNodesToPrint) > 100) targetNodesToPrint <- paste0(substr(targetNodesToPrint, 1, 97), '...') calcNodes <- model$getDependencies(targetNodes) - ## check for discrete nodes (early, before parameterTransform is specialized) - if(any(model$isDiscrete(targetNodesAsScalars))) - stop(paste0('NUTS sampler cannot operate on discrete-valued nodes: ', - paste0(targetNodesAsScalars[model$isDiscrete(targetNodesAsScalars)], collapse = ', '))) + ## check validity of target and dependent nodes (early, before parameterTransform is specialized) + hmc_checkTarget(model, targetNodes, calcNodes, 'NUTS') ## processing of bounds and transformations my_parameterTransform <- parameterTransform(model, targetNodesAsScalars) d <- my_parameterTransform$getTransformedLength() diff --git a/nimbleHMC/tests/testthat/test-HMC.R b/nimbleHMC/tests/testthat/test-HMC.R index 62bd010..2f0047e 100644 --- a/nimbleHMC/tests/testthat/test-HMC.R +++ b/nimbleHMC/tests/testthat/test-HMC.R @@ -107,6 +107,44 @@ test_that('HMC sampler error messages for transformations with non-constant boun }) +test_that('hmc_checkTarget catches all invalid cases', { + code <- nimbleCode({ + x[1] ~ dbern(0.5) + x[2] ~ dbin(size = 4, prob = 0.5) + x[3] ~ dcat(prob = p[1:3]) + x[4] ~ dpois(2) + x[5:7] ~ dmulti(prob = p[1:3], size = 3) + x[8] ~ T(dnorm(0, 1), 0, ) + x[9] ~ T(dnorm(0, 1), , 2) + x[10] ~ T(dnorm(0, 1), 0, 2) + ## + a[1] ~ dnorm(0, 1) + b[1] ~ T(dnorm(a[1], 1), 0, 2) + ## + a[2] ~ dnorm(0, 1) + b[2] ~ dconstraint(a[2] > 0) + ## + a[3] ~ dnorm(0, 1) + b[3] ~ dinterval(a[3], 0) + }) + constants <- list(p = rep(1/3,3)) + inits <- list(x = rep(1, 10), a = rep(1, 3)) + data <- list(b = rep(1, 3)) + Rmodel <- nimbleModel(code, constants, data, inits) + ## + conf <- configureMCMC(Rmodel, nodes = NULL, print = FALSE) + ## + for(node in Rmodel$expandNodeNames(c('x', 'a'))) { + conf$setSamplers() + conf$addSampler(target = node, type = 'NUTS_classic') + expect_error(buildMCMC(conf)) + conf$setSamplers() + conf$addSampler(target = node, type = 'NUTS') + expect_error(buildMCMC(conf)) + } +}) + + test_that('HMC sampler error messages for invalid M mass matrix arguments', { code <- nimbleCode({ for(i in 1:5) x[i] ~ dnorm(0, 1)