Skip to content

Commit

Permalink
Check target efficiency (#47)
Browse files Browse the repository at this point in the history
* working on efficiency of hmc_checkTarget

* working on checkTarget efficiency

* fixed bug using isDiscrete

* added comments

* updaetd comments
  • Loading branch information
danielturek authored Apr 25, 2024
1 parent 001497c commit 41c782e
Showing 1 changed file with 58 additions and 12 deletions.
70 changes: 58 additions & 12 deletions nimbleHMC/R/HMC_samplers.R
Original file line number Diff line number Diff line change
Expand Up @@ -127,24 +127,70 @@ hmc_checkTarget <- function(model, targetNodes, hmcType) {
## checks for:
## - target with discrete or truncated distribution
## - dependencies with truncated, dinterval, or dconstraint distribution
calcNodes <- model$getDependencies(targetNodes, stochOnly = TRUE)
if(any(model$isDiscrete(targetNodes)))
stop(paste0(hmcType, ' sampler cannot operate on discrete-valued nodes: ', paste0(targetNodes[model$isDiscrete(targetNodes)], collapse = ', ')))
if(any(model$isTruncated(targetNodes)))
stop(paste0(hmcType, ' sampler cannot operate on nodes with truncated prior distributions: ', paste0(targetNodes[model$isTruncated(targetNodes)], collapse = ', ')))
if(any(model$isTruncated(calcNodes)))
stop(paste0(hmcType, ' sampler cannot operate since these dependent nodes have truncated distributions, which do not support AD calculations: ', paste0(calcNodes[model$isTruncated(calcNodes)], collapse = ', ')))
if(any(model$getDistribution(calcNodes) == 'dinterval'))
stop(paste0(hmcType, ' sampler cannot operate since these dependent nodes have dinterval distributions, which do not support AD calculations: ', paste0(calcNodes[which(model$getDistribution(calcNodes) == 'dinterval')], collapse = ', ')))
if(any(model$getDistribution(calcNodes) == 'dconstraint'))
stop(paste0(hmcType, ' sampler cannot operate since these dependent nodes have dconstraint distributions, which do not support AD calculations: ', paste0(calcNodes[which(model$getDistribution(calcNodes) == 'dconstraint')], collapse = ', ')))
##
targetDeclIDs_unique <- unique(model$getDeclID(targetNodes))
targetDeclInfo_unique <- model$getModelDef()$declInfo[targetDeclIDs_unique]
##
##if(any(model$isDiscrete(targetNodes)))
## stop(paste0(hmcType, ' sampler cannot operate on discrete-valued nodes: ', paste0(targetNodes[model$isDiscrete(targetNodes)], collapse = ', ')))
targetDists_unique <- unique(sapply(targetDeclInfo_unique, function(x) x$getDistributionName()))
targetDiscreteBool <- sapply(targetDists_unique, isDiscrete)
if(any(targetDiscreteBool)) {
stop(paste0(hmcType, ' sampler cannot operate on nodes with discrete-valued distributions: ', paste0(targetDists_unique[targetDiscreteBool], collapse = ', ')))
}
##
##if(any(model$isTruncated(targetNodes)))
## stop(paste0(hmcType, ' sampler cannot operate on nodes with truncated prior distributions: ', paste0(targetNodes[model$isTruncated(targetNodes)], collapse = ', ')))
targetTruncatedBool <- any(sapply(targetDeclInfo_unique, function(x) x$isTruncated()))
if(any(targetTruncatedBool)) {
targetExpanded <- model$expandNodeNames(targetNodes)
stop(paste0(hmcType, ' sampler cannot operate on nodes with truncated prior distributions: ', paste0(targetExpanded[model$isTruncated(targetExpanded)], collapse = ', ')))
}
##
##calcNodes <- model$getDependencies(targetNodes, stochOnly = TRUE)
depNodes <- model$getDependencies(targetNodes, self = FALSE, stochOnly = TRUE)
depDeclIDs_unique <- unique(model$getDeclID(depNodes))
depDeclInfo_unique <- model$getModelDef()$declInfo[depDeclIDs_unique]
##
##if(any(model$isTruncated(calcNodes)))
## stop(paste0(hmcType, ' sampler cannot operate since these dependent nodes have truncated distributions, which do not support AD calculations: ', paste0(calcNodes[model$isTruncated(calcNodes)], collapse = ', ')))
depTruncatedBool <- any(sapply(depDeclInfo_unique, function(x) x$isTruncated()))
if(any(depTruncatedBool)) {
stop(paste0(hmcType, ' sampler cannot operate since these dependent nodes have truncated distributions, which do not support AD calculations: ', paste0(depNodes[model$isTruncated(depNodes)], collapse = ', ')))
}
##
depDists_unique <- sapply(depDeclInfo_unique, function(x) x$getDistributionName())
##
##if(any(model$getDistribution(calcNodes) == 'dinterval'))
## stop(paste0(hmcType, ' sampler cannot operate since these dependent nodes have dinterval distributions, which do not support AD calculations: ', paste0(calcNodes[which(model$getDistribution(calcNodes) == 'dinterval')], collapse = ', ')))
depIntervalBool <- (depDists_unique == 'dinterval')
if(any(depIntervalBool)) {
stop(paste0(hmcType, ' sampler cannot operate since these dependent nodes have dinterval distributions, which do not support AD calculations: ', paste0(depNodes[which(model$getDistribution(depNodes) == 'dinterval')], collapse = ', ')))
}
##
##if(any(model$getDistribution(calcNodes) == 'dconstraint'))
## stop(paste0(hmcType, ' sampler cannot operate since these dependent nodes have dconstraint distributions, which do not support AD calculations: ', paste0(calcNodes[which(model$getDistribution(calcNodes) == 'dconstraint')], collapse = ', ')))
depConstraintBool <- (depDists_unique == 'dconstraint')
if(any(depConstraintBool)) {
stop(paste0(hmcType, ' sampler cannot operate since these dependent nodes have dconstraint distributions, which do not support AD calculations: ', paste0(depNodes[which(model$getDistribution(depNodes) == 'dconstraint')], collapse = ', ')))
}
##
## next, check for:
## - target with user-defined distribution (without AD support)
dists <- model$getDistribution(targetNodes)
##
##dists <- model$getDistribution(targetNodes)
dists <- targetDists_unique
ADok <- rep(TRUE, length(dists))
for(i in seq_along(dists)) {
##
## if/when modelDef$checkADsupportForDistribution() is added to core nimble,
## change the entire body of this for-loop to instead be:
## ADoak[i] <- model$getModelDef()$checkADsupportForDistribution(dists[i])
##
## these distributions get re-named to a nimble-version, and won't be found:
if(dists[i] %in% c('dweib', 'dmnorm', 'dmvt', 'dwish', 'dinvwish')) next
##
##
## find the function or this distribution:
nfObj <- get(dists[i], envir = parent.frame(4)) ## this took a bit of an investigation to make work
## is a user-defined distribution:
Expand Down

0 comments on commit 41c782e

Please sign in to comment.