diff --git a/nimbleHMC/R/HMC_samplers.R b/nimbleHMC/R/HMC_samplers.R index 821d4bd..598ea0b 100644 --- a/nimbleHMC/R/HMC_samplers.R +++ b/nimbleHMC/R/HMC_samplers.R @@ -127,24 +127,70 @@ hmc_checkTarget <- function(model, targetNodes, hmcType) { ## checks for: ## - target with discrete or truncated distribution ## - dependencies with truncated, dinterval, or dconstraint distribution - calcNodes <- model$getDependencies(targetNodes, stochOnly = TRUE) - if(any(model$isDiscrete(targetNodes))) - stop(paste0(hmcType, ' sampler cannot operate on discrete-valued nodes: ', paste0(targetNodes[model$isDiscrete(targetNodes)], collapse = ', '))) - if(any(model$isTruncated(targetNodes))) - stop(paste0(hmcType, ' sampler cannot operate on nodes with truncated prior distributions: ', paste0(targetNodes[model$isTruncated(targetNodes)], collapse = ', '))) - if(any(model$isTruncated(calcNodes))) - stop(paste0(hmcType, ' sampler cannot operate since these dependent nodes have truncated distributions, which do not support AD calculations: ', paste0(calcNodes[model$isTruncated(calcNodes)], collapse = ', '))) - if(any(model$getDistribution(calcNodes) == 'dinterval')) - stop(paste0(hmcType, ' sampler cannot operate since these dependent nodes have dinterval distributions, which do not support AD calculations: ', paste0(calcNodes[which(model$getDistribution(calcNodes) == 'dinterval')], collapse = ', '))) - if(any(model$getDistribution(calcNodes) == 'dconstraint')) - stop(paste0(hmcType, ' sampler cannot operate since these dependent nodes have dconstraint distributions, which do not support AD calculations: ', paste0(calcNodes[which(model$getDistribution(calcNodes) == 'dconstraint')], collapse = ', '))) + ## + targetDeclIDs_unique <- unique(model$getDeclID(targetNodes)) + targetDeclInfo_unique <- model$getModelDef()$declInfo[targetDeclIDs_unique] + ## + ##if(any(model$isDiscrete(targetNodes))) + ## stop(paste0(hmcType, ' sampler cannot operate on discrete-valued nodes: ', paste0(targetNodes[model$isDiscrete(targetNodes)], collapse = ', '))) + targetDists_unique <- unique(sapply(targetDeclInfo_unique, function(x) x$getDistributionName())) + targetDiscreteBool <- sapply(targetDists_unique, isDiscrete) + if(any(targetDiscreteBool)) { + stop(paste0(hmcType, ' sampler cannot operate on nodes with discrete-valued distributions: ', paste0(targetDists_unique[targetDiscreteBool], collapse = ', '))) + } + ## + ##if(any(model$isTruncated(targetNodes))) + ## stop(paste0(hmcType, ' sampler cannot operate on nodes with truncated prior distributions: ', paste0(targetNodes[model$isTruncated(targetNodes)], collapse = ', '))) + targetTruncatedBool <- any(sapply(targetDeclInfo_unique, function(x) x$isTruncated())) + if(any(targetTruncatedBool)) { + targetExpanded <- model$expandNodeNames(targetNodes) + stop(paste0(hmcType, ' sampler cannot operate on nodes with truncated prior distributions: ', paste0(targetExpanded[model$isTruncated(targetExpanded)], collapse = ', '))) + } + ## + ##calcNodes <- model$getDependencies(targetNodes, stochOnly = TRUE) + depNodes <- model$getDependencies(targetNodes, self = FALSE, stochOnly = TRUE) + depDeclIDs_unique <- unique(model$getDeclID(depNodes)) + depDeclInfo_unique <- model$getModelDef()$declInfo[depDeclIDs_unique] + ## + ##if(any(model$isTruncated(calcNodes))) + ## stop(paste0(hmcType, ' sampler cannot operate since these dependent nodes have truncated distributions, which do not support AD calculations: ', paste0(calcNodes[model$isTruncated(calcNodes)], collapse = ', '))) + depTruncatedBool <- any(sapply(depDeclInfo_unique, function(x) x$isTruncated())) + if(any(depTruncatedBool)) { + stop(paste0(hmcType, ' sampler cannot operate since these dependent nodes have truncated distributions, which do not support AD calculations: ', paste0(depNodes[model$isTruncated(depNodes)], collapse = ', '))) + } + ## + depDists_unique <- sapply(depDeclInfo_unique, function(x) x$getDistributionName()) + ## + ##if(any(model$getDistribution(calcNodes) == 'dinterval')) + ## stop(paste0(hmcType, ' sampler cannot operate since these dependent nodes have dinterval distributions, which do not support AD calculations: ', paste0(calcNodes[which(model$getDistribution(calcNodes) == 'dinterval')], collapse = ', '))) + depIntervalBool <- (depDists_unique == 'dinterval') + if(any(depIntervalBool)) { + stop(paste0(hmcType, ' sampler cannot operate since these dependent nodes have dinterval distributions, which do not support AD calculations: ', paste0(depNodes[which(model$getDistribution(depNodes) == 'dinterval')], collapse = ', '))) + } + ## + ##if(any(model$getDistribution(calcNodes) == 'dconstraint')) + ## stop(paste0(hmcType, ' sampler cannot operate since these dependent nodes have dconstraint distributions, which do not support AD calculations: ', paste0(calcNodes[which(model$getDistribution(calcNodes) == 'dconstraint')], collapse = ', '))) + depConstraintBool <- (depDists_unique == 'dconstraint') + if(any(depConstraintBool)) { + stop(paste0(hmcType, ' sampler cannot operate since these dependent nodes have dconstraint distributions, which do not support AD calculations: ', paste0(depNodes[which(model$getDistribution(depNodes) == 'dconstraint')], collapse = ', '))) + } + ## ## next, check for: ## - target with user-defined distribution (without AD support) - dists <- model$getDistribution(targetNodes) + ## + ##dists <- model$getDistribution(targetNodes) + dists <- targetDists_unique ADok <- rep(TRUE, length(dists)) for(i in seq_along(dists)) { + ## + ## if/when modelDef$checkADsupportForDistribution() is added to core nimble, + ## change the entire body of this for-loop to instead be: + ## ADoak[i] <- model$getModelDef()$checkADsupportForDistribution(dists[i]) + ## ## these distributions get re-named to a nimble-version, and won't be found: if(dists[i] %in% c('dweib', 'dmnorm', 'dmvt', 'dwish', 'dinvwish')) next + ## + ## ## find the function or this distribution: nfObj <- get(dists[i], envir = parent.frame(4)) ## this took a bit of an investigation to make work ## is a user-defined distribution: